From b217c7ba41d396307c6095508593943384970f32 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Tue, 23 Sep 2025 16:53:34 -0600 Subject: [PATCH] refactor: use @with_session decorator in file pipeline tasks - Add @with_session decorator to shared tasks in main_file_pipeline.py - Update task_send_webhook_if_needed and task_pipeline_file_process to use session parameter - Refactor PipelineMainFile methods to accept session as parameter - Pass session through method calls instead of creating new sessions with get_session_factory() This improves session management consistency and follows the pattern established by other worker tasks in the codebase. --- .../reflector/pipelines/main_file_pipeline.py | 58 +++++++++---------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/server/reflector/pipelines/main_file_pipeline.py b/server/reflector/pipelines/main_file_pipeline.py index 08915f99..cf586350 100644 --- a/server/reflector/pipelines/main_file_pipeline.py +++ b/server/reflector/pipelines/main_file_pipeline.py @@ -54,6 +54,7 @@ from reflector.processors.types import ( ) from reflector.settings import settings from reflector.storage import get_transcripts_storage +from reflector.worker.session_decorator import with_session from reflector.worker.webhook import send_transcript_webhook @@ -130,6 +131,7 @@ class PipelineMainFile(PipelineMainBase): # Run parallel processing await self.run_parallel_processing( + session, audio_path, audio_url, transcript.source_language, @@ -201,6 +203,7 @@ class PipelineMainFile(PipelineMainBase): async def run_parallel_processing( self, + session, audio_path: Path, audio_url: str, source_language: str, @@ -214,7 +217,7 @@ class PipelineMainFile(PipelineMainBase): # Phase 1: Parallel processing of independent tasks transcription_task = self.transcribe_file(audio_url, source_language) diarization_task = self.diarize_file(audio_url) - waveform_task = self.generate_waveform(audio_path) + waveform_task = self.generate_waveform(session, audio_path) results = await asyncio.gather( transcription_task, diarization_task, waveform_task, return_exceptions=True @@ -262,7 +265,7 @@ class PipelineMainFile(PipelineMainBase): ) results = await asyncio.gather( self.generate_title(topics), - self.generate_summaries(topics), + self.generate_summaries(session, topics), return_exceptions=True, ) @@ -372,16 +375,13 @@ class PipelineMainFile(PipelineMainBase): await processor.flush() - async def generate_summaries(self, topics: list[TitleSummary]): + async def generate_summaries(self, session, topics: list[TitleSummary]): """Generate long and short summaries from topics""" if not topics: self.logger.warning("No topics for summary generation") return - async with get_session_factory()() as session: - transcript = await transcripts_controller.get_by_id( - session, self.transcript_id - ) + transcript = await transcripts_controller.get_by_id(session, self.transcript_id) processor = TranscriptFinalSummaryProcessor( transcript=transcript, callback=self.on_long_summary, @@ -397,35 +397,35 @@ class PipelineMainFile(PipelineMainBase): @shared_task @asynctask -async def task_send_webhook_if_needed(*, transcript_id: str): +@with_session +async def task_send_webhook_if_needed(session, *, transcript_id: str): """Send webhook if this is a room recording with webhook configured""" - async with get_session_factory()() as session: - transcript = await transcripts_controller.get_by_id(session, transcript_id) - if not transcript: - return + transcript = await transcripts_controller.get_by_id(session, transcript_id) + if not transcript: + return - if transcript.source_kind == SourceKind.ROOM and transcript.room_id: - room = await rooms_controller.get_by_id(session, transcript.room_id) - if room and room.webhook_url: - logger.info( - "Dispatching webhook", - transcript_id=transcript_id, - room_id=room.id, - webhook_url=room.webhook_url, - ) - send_transcript_webhook.delay( - transcript_id, room.id, event_id=uuid.uuid4().hex - ) + if transcript.source_kind == SourceKind.ROOM and transcript.room_id: + room = await rooms_controller.get_by_id(session, transcript.room_id) + if room and room.webhook_url: + logger.info( + "Dispatching webhook", + transcript_id=transcript_id, + room_id=room.id, + webhook_url=room.webhook_url, + ) + send_transcript_webhook.delay( + transcript_id, room.id, event_id=uuid.uuid4().hex + ) @shared_task @asynctask -async def task_pipeline_file_process(*, transcript_id: str): +@with_session +async def task_pipeline_file_process(session, *, transcript_id: str): """Celery task for file pipeline processing""" - async with get_session_factory()() as session: - transcript = await transcripts_controller.get_by_id(session, transcript_id) - if not transcript: - raise Exception(f"Transcript {transcript_id} not found") + transcript = await transcripts_controller.get_by_id(session, transcript_id) + if not transcript: + raise Exception(f"Transcript {transcript_id} not found") pipeline = PipelineMainFile(transcript_id=transcript_id) try: