From 9dfd76996f851cc52be54feea078adbc0816dc57 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Fri, 29 Aug 2025 00:58:14 -0600 Subject: [PATCH] fix: file pipeline status reporting and websocket updates (#589) * feat: use file pipeline for upload and reprocess action * fix: make file pipeline correctly report status events * fix: duplication of transcripts_controller * fix: tests * test: fix file upload test * test: fix reprocess * fix: also patch from main_file_pipeline (how patch is done is dependent of file import unfortunately) --- server/reflector/db/transcripts.py | 33 ++++++++- .../reflector/pipelines/main_file_pipeline.py | 51 +++++++++++--- .../reflector/pipelines/main_live_pipeline.py | 32 ++++----- server/reflector/views/transcripts_process.py | 4 +- server/reflector/views/transcripts_upload.py | 4 +- server/tests/conftest.py | 68 ++++++++++++++++++- .../tests/test_transcripts_audio_download.py | 2 +- server/tests/test_transcripts_process.py | 15 ++-- server/tests/test_transcripts_upload.py | 11 +-- 9 files changed, 170 insertions(+), 50 deletions(-) diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py index 9dbcba9f..47148995 100644 --- a/server/reflector/db/transcripts.py +++ b/server/reflector/db/transcripts.py @@ -122,6 +122,15 @@ def generate_transcript_name() -> str: return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}" +TranscriptStatus = Literal[ + "idle", "uploaded", "recording", "processing", "error", "ended" +] + + +class StrValue(BaseModel): + value: str + + class AudioWaveform(BaseModel): data: list[float] @@ -185,7 +194,7 @@ class Transcript(BaseModel): id: str = Field(default_factory=generate_uuid4) user_id: str | None = None name: str = Field(default_factory=generate_transcript_name) - status: str = "idle" + status: TranscriptStatus = "idle" duration: float = 0 created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) title: str | None = None @@ -732,5 +741,27 @@ class TranscriptController: transcript.delete_participant(participant_id) await self.update(transcript, {"participants": transcript.participants_dump()}) + async def set_status( + self, transcript_id: str, status: TranscriptStatus + ) -> TranscriptEvent | None: + """ + Update the status of a transcript + + Will add an event STATUS + update the status field of transcript + """ + async with self.transaction(): + transcript = await self.get_by_id(transcript_id) + if not transcript: + raise Exception(f"Transcript {transcript_id} not found") + if transcript.status == status: + return + resp = await self.append_event( + transcript=transcript, + event="STATUS", + data=StrValue(value=status), + ) + await self.update(transcript, {"status": status}) + return resp + transcripts_controller = TranscriptController() diff --git a/server/reflector/pipelines/main_file_pipeline.py b/server/reflector/pipelines/main_file_pipeline.py index f2c8fb85..f11cddca 100644 --- a/server/reflector/pipelines/main_file_pipeline.py +++ b/server/reflector/pipelines/main_file_pipeline.py @@ -15,10 +15,15 @@ from celery import shared_task from reflector.db.transcripts import ( Transcript, + TranscriptStatus, transcripts_controller, ) from reflector.logger import logger -from reflector.pipelines.main_live_pipeline import PipelineMainBase, asynctask +from reflector.pipelines.main_live_pipeline import ( + PipelineMainBase, + asynctask, + broadcast_to_sockets, +) from reflector.processors import ( AudioFileWriterProcessor, TranscriptFinalSummaryProcessor, @@ -83,12 +88,27 @@ class PipelineMainFile(PipelineMainBase): exc_info=result, ) + @broadcast_to_sockets + async def set_status(self, transcript_id: str, status: TranscriptStatus): + async with self.lock_transaction(): + return await transcripts_controller.set_status(transcript_id, status) + async def process(self, file_path: Path): """Main entry point for file processing""" self.logger.info(f"Starting file pipeline for {file_path}") transcript = await self.get_transcript() + # Clear transcript as we're going to regenerate everything + async with self.transaction(): + await transcripts_controller.update( + transcript, + { + "events": [], + "topics": [], + }, + ) + # Extract audio and write to transcript location audio_path = await self.extract_and_write_audio(file_path, transcript) @@ -105,6 +125,8 @@ class PipelineMainFile(PipelineMainBase): self.logger.info("File pipeline complete") + await transcripts_controller.set_status(transcript.id, "ended") + async def extract_and_write_audio( self, file_path: Path, transcript: Transcript ) -> Path: @@ -362,14 +384,21 @@ async def task_pipeline_file_process(*, transcript_id: str): if not transcript: raise Exception(f"Transcript {transcript_id} not found") - # Find the file to process - audio_file = next(transcript.data_path.glob("upload.*"), None) - if not audio_file: - audio_file = next(transcript.data_path.glob("audio.*"), None) - - if not audio_file: - raise Exception("No audio file found to process") - - # Run file pipeline pipeline = PipelineMainFile(transcript_id=transcript_id) - await pipeline.process(audio_file) + + try: + await pipeline.set_status(transcript_id, "processing") + + # Find the file to process + audio_file = next(transcript.data_path.glob("upload.*"), None) + if not audio_file: + audio_file = next(transcript.data_path.glob("audio.*"), None) + + if not audio_file: + raise Exception("No audio file found to process") + + await pipeline.process(audio_file) + + except Exception: + await pipeline.set_status(transcript_id, "error") + raise diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py index 812847db..30c8777b 100644 --- a/server/reflector/pipelines/main_live_pipeline.py +++ b/server/reflector/pipelines/main_live_pipeline.py @@ -32,6 +32,7 @@ from reflector.db.transcripts import ( TranscriptFinalLongSummary, TranscriptFinalShortSummary, TranscriptFinalTitle, + TranscriptStatus, TranscriptText, TranscriptTopic, TranscriptWaveform, @@ -188,8 +189,15 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage] ] @asynccontextmanager - async def transaction(self): + async def lock_transaction(self): + # This lock is to prevent multiple processor starting adding + # into event array at the same time async with self._lock: + yield + + @asynccontextmanager + async def transaction(self): + async with self.lock_transaction(): async with transcripts_controller.transaction(): yield @@ -198,14 +206,14 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage] # if it's the first part, update the status of the transcript # but do not set the ended status yet. if isinstance(self, PipelineMainLive): - status_mapping = { + status_mapping: dict[str, TranscriptStatus] = { "started": "recording", "push": "recording", "flush": "processing", "error": "error", } elif isinstance(self, PipelineMainFinalSummaries): - status_mapping = { + status_mapping: dict[str, TranscriptStatus] = { "push": "processing", "flush": "processing", "error": "error", @@ -221,22 +229,8 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage] return # when the status of the pipeline changes, update the transcript - async with self.transaction(): - transcript = await self.get_transcript() - if status == transcript.status: - return - resp = await transcripts_controller.append_event( - transcript=transcript, - event="STATUS", - data=StrValue(value=status), - ) - await transcripts_controller.update( - transcript, - { - "status": status, - }, - ) - return resp + async with self._lock: + return await transcripts_controller.set_status(self.transcript_id, status) @broadcast_to_sockets async def on_transcript(self, data): diff --git a/server/reflector/views/transcripts_process.py b/server/reflector/views/transcripts_process.py index 8f6d3ab6..0200e7f8 100644 --- a/server/reflector/views/transcripts_process.py +++ b/server/reflector/views/transcripts_process.py @@ -6,7 +6,7 @@ from pydantic import BaseModel import reflector.auth as auth from reflector.db.transcripts import transcripts_controller -from reflector.pipelines.main_live_pipeline import task_pipeline_process +from reflector.pipelines.main_file_pipeline import task_pipeline_file_process router = APIRouter() @@ -40,7 +40,7 @@ async def transcript_process( return ProcessStatus(status="already running") # schedule a background task process the file - task_pipeline_process.delay(transcript_id=transcript_id) + task_pipeline_file_process.delay(transcript_id=transcript_id) return ProcessStatus(status="ok") diff --git a/server/reflector/views/transcripts_upload.py b/server/reflector/views/transcripts_upload.py index 18e75dac..8efbc274 100644 --- a/server/reflector/views/transcripts_upload.py +++ b/server/reflector/views/transcripts_upload.py @@ -6,7 +6,7 @@ from pydantic import BaseModel import reflector.auth as auth from reflector.db.transcripts import transcripts_controller -from reflector.pipelines.main_live_pipeline import task_pipeline_process +from reflector.pipelines.main_file_pipeline import task_pipeline_file_process router = APIRouter() @@ -92,6 +92,6 @@ async def transcript_record_upload( await transcripts_controller.update(transcript, {"status": "uploaded"}) # launch a background task to process the file - task_pipeline_process.delay(transcript_id=transcript_id) + task_pipeline_file_process.delay(transcript_id=transcript_id) return UploadStatus(status="ok") diff --git a/server/tests/conftest.py b/server/tests/conftest.py index d739751d..22fe4193 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -178,6 +178,63 @@ async def dummy_diarization(): yield +@pytest.fixture +async def dummy_file_transcript(): + from reflector.processors.file_transcript import FileTranscriptProcessor + from reflector.processors.types import Transcript, Word + + class TestFileTranscriptProcessor(FileTranscriptProcessor): + async def _transcript(self, data): + return Transcript( + text="Hello world. How are you today?", + words=[ + Word(start=0.0, end=0.5, text="Hello", speaker=0), + Word(start=0.5, end=0.6, text=" ", speaker=0), + Word(start=0.6, end=1.0, text="world", speaker=0), + Word(start=1.0, end=1.1, text=".", speaker=0), + Word(start=1.1, end=1.2, text=" ", speaker=0), + Word(start=1.2, end=1.5, text="How", speaker=0), + Word(start=1.5, end=1.6, text=" ", speaker=0), + Word(start=1.6, end=1.8, text="are", speaker=0), + Word(start=1.8, end=1.9, text=" ", speaker=0), + Word(start=1.9, end=2.1, text="you", speaker=0), + Word(start=2.1, end=2.2, text=" ", speaker=0), + Word(start=2.2, end=2.5, text="today", speaker=0), + Word(start=2.5, end=2.6, text="?", speaker=0), + ], + ) + + with patch( + "reflector.processors.file_transcript_auto.FileTranscriptAutoProcessor.__new__" + ) as mock_auto: + mock_auto.return_value = TestFileTranscriptProcessor() + yield + + +@pytest.fixture +async def dummy_file_diarization(): + from reflector.processors.file_diarization import ( + FileDiarizationOutput, + FileDiarizationProcessor, + ) + from reflector.processors.types import DiarizationSegment + + class TestFileDiarizationProcessor(FileDiarizationProcessor): + async def _diarize(self, data): + return FileDiarizationOutput( + diarization=[ + DiarizationSegment(start=0.0, end=1.1, speaker=0), + DiarizationSegment(start=1.2, end=2.6, speaker=1), + ] + ) + + with patch( + "reflector.processors.file_diarization_auto.FileDiarizationAutoProcessor.__new__" + ) as mock_auto: + mock_auto.return_value = TestFileDiarizationProcessor() + yield + + @pytest.fixture async def dummy_transcript_translator(): from reflector.processors.transcript_translator import TranscriptTranslatorProcessor @@ -238,9 +295,13 @@ async def dummy_storage(): with ( patch("reflector.storage.base.Storage.get_instance") as mock_storage, patch("reflector.storage.get_transcripts_storage") as mock_get_transcripts, + patch( + "reflector.pipelines.main_file_pipeline.get_transcripts_storage" + ) as mock_get_transcripts2, ): mock_storage.return_value = dummy mock_get_transcripts.return_value = dummy + mock_get_transcripts2.return_value = dummy yield @@ -260,7 +321,10 @@ def celery_config(): @pytest.fixture(scope="session") def celery_includes(): - return ["reflector.pipelines.main_live_pipeline"] + return [ + "reflector.pipelines.main_live_pipeline", + "reflector.pipelines.main_file_pipeline", + ] @pytest.fixture @@ -302,7 +366,7 @@ async def fake_transcript_with_topics(tmpdir, client): transcript = await transcripts_controller.get_by_id(tid) assert transcript is not None - await transcripts_controller.update(transcript, {"status": "finished"}) + await transcripts_controller.update(transcript, {"status": "ended"}) # manually copy a file at the expected location audio_filename = transcript.audio_mp3_filename diff --git a/server/tests/test_transcripts_audio_download.py b/server/tests/test_transcripts_audio_download.py index 81b74def..e40d0ade 100644 --- a/server/tests/test_transcripts_audio_download.py +++ b/server/tests/test_transcripts_audio_download.py @@ -19,7 +19,7 @@ async def fake_transcript(tmpdir, client): transcript = await transcripts_controller.get_by_id(tid) assert transcript is not None - await transcripts_controller.update(transcript, {"status": "finished"}) + await transcripts_controller.update(transcript, {"status": "ended"}) # manually copy a file at the expected location audio_filename = transcript.audio_mp3_filename diff --git a/server/tests/test_transcripts_process.py b/server/tests/test_transcripts_process.py index 3551d718..5f45cf4b 100644 --- a/server/tests/test_transcripts_process.py +++ b/server/tests/test_transcripts_process.py @@ -29,10 +29,10 @@ async def client(app_lifespan): @pytest.mark.asyncio async def test_transcript_process( tmpdir, - whisper_transcript, dummy_llm, dummy_processors, - dummy_diarization, + dummy_file_transcript, + dummy_file_diarization, dummy_storage, client, ): @@ -56,8 +56,8 @@ async def test_transcript_process( assert response.status_code == 200 assert response.json()["status"] == "ok" - # wait for processing to finish (max 10 minutes) - timeout_seconds = 600 # 10 minutes + # wait for processing to finish (max 1 minute) + timeout_seconds = 60 start_time = time.monotonic() while (time.monotonic() - start_time) < timeout_seconds: # fetch the transcript and check if it is ended @@ -75,9 +75,10 @@ async def test_transcript_process( ) assert response.status_code == 200 assert response.json()["status"] == "ok" + await asyncio.sleep(2) - # wait for processing to finish (max 10 minutes) - timeout_seconds = 600 # 10 minutes + # wait for processing to finish (max 1 minute) + timeout_seconds = 60 start_time = time.monotonic() while (time.monotonic() - start_time) < timeout_seconds: # fetch the transcript and check if it is ended @@ -99,4 +100,4 @@ async def test_transcript_process( response = await client.get(f"/transcripts/{tid}/topics") assert response.status_code == 200 assert len(response.json()) == 1 - assert "want to share" in response.json()[0]["transcript"] + assert "Hello world. How are you today?" in response.json()[0]["transcript"] diff --git a/server/tests/test_transcripts_upload.py b/server/tests/test_transcripts_upload.py index ee08b1be..e9a90c7a 100644 --- a/server/tests/test_transcripts_upload.py +++ b/server/tests/test_transcripts_upload.py @@ -12,7 +12,8 @@ async def test_transcript_upload_file( tmpdir, dummy_llm, dummy_processors, - dummy_diarization, + dummy_file_transcript, + dummy_file_diarization, dummy_storage, client, ): @@ -36,8 +37,8 @@ async def test_transcript_upload_file( assert response.status_code == 200 assert response.json()["status"] == "ok" - # wait the processing to finish (max 10 minutes) - timeout_seconds = 600 # 10 minutes + # wait the processing to finish (max 1 minute) + timeout_seconds = 60 start_time = time.monotonic() while (time.monotonic() - start_time) < timeout_seconds: # fetch the transcript and check if it is ended @@ -47,7 +48,7 @@ async def test_transcript_upload_file( break await asyncio.sleep(1) else: - pytest.fail(f"Processing timed out after {timeout_seconds} seconds") + return pytest.fail(f"Processing timed out after {timeout_seconds} seconds") # check the transcript is ended transcript = resp.json() @@ -59,4 +60,4 @@ async def test_transcript_upload_file( response = await client.get(f"/transcripts/{tid}/topics") assert response.status_code == 200 assert len(response.json()) == 1 - assert "want to share" in response.json()[0]["transcript"] + assert "Hello world. How are you today?" in response.json()[0]["transcript"]