From a846e38fbdeb77aad23282a8068b08e4cd6b3921 Mon Sep 17 00:00:00 2001 From: Sara Date: Fri, 17 Nov 2023 13:38:32 +0100 Subject: [PATCH] fix waveform in pipeline --- .../reflector/pipelines/main_live_pipeline.py | 15 +++++++++------ .../processors/audio_waveform_processor.py | 5 ++++- server/tests/test_transcripts_audio_download.py | 12 ------------ server/tests/test_transcripts_rtc_ws.py | 17 ++++++++++++----- 4 files changed, 25 insertions(+), 24 deletions(-) diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py index b0576b92..fece6da5 100644 --- a/server/reflector/pipelines/main_live_pipeline.py +++ b/server/reflector/pipelines/main_live_pipeline.py @@ -233,6 +233,7 @@ class PipelineMainBase(PipelineRunner): data=final_short_summary, ) + @broadcast_to_sockets async def on_duration(self, data): async with self.transaction(): duration = TranscriptDuration(duration=data) @@ -248,14 +249,16 @@ class PipelineMainBase(PipelineRunner): transcript=transcript, event="DURATION", data=duration ) + @broadcast_to_sockets async def on_waveform(self, data): - waveform = TranscriptWaveform(waveform=data) + async with self.transaction(): + waveform = TranscriptWaveform(waveform=data) - transcript = await self.get_transcript() + transcript = await self.get_transcript() - return await transcripts_controller.append_event( - transcript=transcript, event="WAVEFORM", data=waveform - ) + return await transcripts_controller.append_event( + transcript=transcript, event="WAVEFORM", data=waveform + ) class PipelineMainLive(PipelineMainBase): @@ -283,7 +286,7 @@ class PipelineMainLive(PipelineMainBase): BroadcastProcessor( processors=[ TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title), - AudioWaveformProcessor( + AudioWaveformProcessor.as_threaded( audio_path=transcript.audio_mp3_filename, waveform_path=transcript.audio_waveform_filename, on_waveform=self.on_waveform, diff --git a/server/reflector/processors/audio_waveform_processor.py b/server/reflector/processors/audio_waveform_processor.py index acce904a..f1a24ffd 100644 --- a/server/reflector/processors/audio_waveform_processor.py +++ b/server/reflector/processors/audio_waveform_processor.py @@ -22,7 +22,7 @@ class AudioWaveformProcessor(Processor): self.audio_path = audio_path self.waveform_path = waveform_path - async def _push(self, _data): + async def _flush(self): self.waveform_path.parent.mkdir(parents=True, exist_ok=True) self.logger.info("Waveform Processing Started") waveform = get_audio_waveform(path=self.audio_path, segments_count=255) @@ -31,3 +31,6 @@ class AudioWaveformProcessor(Processor): json.dump(waveform, fd) self.logger.info("Waveform Processing Finished") await self.emit(waveform, name="waveform") + + async def _push(_self, _data): + return diff --git a/server/tests/test_transcripts_audio_download.py b/server/tests/test_transcripts_audio_download.py index 69ae5f65..28f83fff 100644 --- a/server/tests/test_transcripts_audio_download.py +++ b/server/tests/test_transcripts_audio_download.py @@ -118,15 +118,3 @@ async def test_transcript_audio_download_range_with_seek( assert response.status_code == 206 assert response.headers["content-type"] == content_type assert response.headers["content-range"].startswith("bytes 100-") - - -@pytest.mark.asyncio -async def test_transcript_audio_download_waveform(fake_transcript): - from reflector.app import app - - ac = AsyncClient(app=app, base_url="http://test/v1") - response = await ac.get(f"/transcripts/{fake_transcript.id}/audio/waveform") - assert response.status_code == 200 - assert response.headers["content-type"] == "application/json" - assert isinstance(response.json()["data"], list) - assert len(response.json()["data"]) >= 255 diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py index 65660a5e..b33b1db5 100644 --- a/server/tests/test_transcripts_rtc_ws.py +++ b/server/tests/test_transcripts_rtc_ws.py @@ -183,8 +183,14 @@ async def test_transcript_rtc_and_websocket( assert ev["data"]["title"] == "LLM TITLE" assert "WAVEFORM" in eventnames - ev = events[eventnames.index("FINAL_TITLE")] - assert ev["data"]["title"] == "LLM TITLE" + ev = events[eventnames.index("WAVEFORM")] + assert isinstance(ev["data"]["waveform"], list) + assert len(ev["data"]["waveform"]) >= 250 + waveform_resp = await ac.get(f"/transcripts/{tid}/audio/waveform") + assert waveform_resp.status_code == 200 + assert waveform_resp.headers["content-type"] == "application/json" + assert isinstance(waveform_resp.json()["data"], list) + assert len(waveform_resp.json()["data"]) >= 250 # check status order statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"] @@ -197,11 +203,12 @@ async def test_transcript_rtc_and_websocket( # check on the latest response that the audio duration is > 0 assert resp.json()["duration"] > 0 + assert "DURATION" in eventnames # check that audio/mp3 is available - resp = await ac.get(f"/transcripts/{tid}/audio/mp3") - assert resp.status_code == 200 - assert resp.headers["Content-Type"] == "audio/mpeg" + audio_resp = await ac.get(f"/transcripts/{tid}/audio/mp3") + assert audio_resp.status_code == 200 + assert audio_resp.headers["Content-Type"] == "audio/mpeg" @pytest.mark.usefixtures("celery_session_app")