refactor: use @with_session decorator in file pipeline tasks

- Add @with_session decorator to shared tasks in main_file_pipeline.py
- Update task_send_webhook_if_needed and task_pipeline_file_process to use the session parameter
- Refactor PipelineMainFile methods to accept the session as a parameter
- Pass session through method calls instead of creating new sessions with get_session_factory()

This improves session management consistency and follows the pattern established
by other worker tasks in the codebase.
2025-09-23 16:53:34 -06:00
parent 0b2152ea75
commit b217c7ba41
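
A minimal sketch of what a decorator along these lines could look like (the actual implementation lives in reflector/worker/session_decorator.py and is not shown in this diff; the get_session_factory import path is assumed here):

import functools

# Import path assumed for this sketch; the diff only shows that
# get_session_factory()() yields an async session context manager.
from reflector.db import get_session_factory


def with_session(func):
    """Open a database session and pass it to the wrapped coroutine as its first argument."""

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        async with get_session_factory()() as session:
            return await func(session, *args, **kwargs)

    return wrapper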

main_file_pipeline.py

@@ -54,6 +54,7 @@ from reflector.processors.types import (
 )
 from reflector.settings import settings
 from reflector.storage import get_transcripts_storage
+from reflector.worker.session_decorator import with_session
 from reflector.worker.webhook import send_transcript_webhook
@@ -130,6 +131,7 @@ class PipelineMainFile(PipelineMainBase):
         # Run parallel processing
         await self.run_parallel_processing(
+            session,
             audio_path,
             audio_url,
             transcript.source_language,
@@ -201,6 +203,7 @@ class PipelineMainFile(PipelineMainBase):
     async def run_parallel_processing(
         self,
+        session,
         audio_path: Path,
         audio_url: str,
         source_language: str,
@@ -214,7 +217,7 @@ class PipelineMainFile(PipelineMainBase):
         # Phase 1: Parallel processing of independent tasks
         transcription_task = self.transcribe_file(audio_url, source_language)
         diarization_task = self.diarize_file(audio_url)
-        waveform_task = self.generate_waveform(audio_path)
+        waveform_task = self.generate_waveform(session, audio_path)
         results = await asyncio.gather(
             transcription_task, diarization_task, waveform_task, return_exceptions=True
@@ -262,7 +265,7 @@ class PipelineMainFile(PipelineMainBase):
         )
         results = await asyncio.gather(
             self.generate_title(topics),
-            self.generate_summaries(topics),
+            self.generate_summaries(session, topics),
             return_exceptions=True,
         )
@@ -372,16 +375,13 @@ class PipelineMainFile(PipelineMainBase):
         await processor.flush()

-    async def generate_summaries(self, topics: list[TitleSummary]):
+    async def generate_summaries(self, session, topics: list[TitleSummary]):
         """Generate long and short summaries from topics"""
         if not topics:
             self.logger.warning("No topics for summary generation")
             return

-        async with get_session_factory()() as session:
-            transcript = await transcripts_controller.get_by_id(
-                session, self.transcript_id
-            )
+        transcript = await transcripts_controller.get_by_id(session, self.transcript_id)
         processor = TranscriptFinalSummaryProcessor(
             transcript=transcript,
             callback=self.on_long_summary,
@@ -397,9 +397,9 @@ class PipelineMainFile(PipelineMainBase):
 @shared_task
 @asynctask
-async def task_send_webhook_if_needed(*, transcript_id: str):
+@with_session
+async def task_send_webhook_if_needed(session, *, transcript_id: str):
     """Send webhook if this is a room recording with webhook configured"""
-    async with get_session_factory()() as session:
     transcript = await transcripts_controller.get_by_id(session, transcript_id)
     if not transcript:
         return
@@ -420,9 +420,9 @@ async def task_send_webhook_if_needed(*, transcript_id: str):
 @shared_task
 @asynctask
-async def task_pipeline_file_process(*, transcript_id: str):
+@with_session
+async def task_pipeline_file_process(session, *, transcript_id: str):
     """Celery task for file pipeline processing"""
-    async with get_session_factory()() as session:
     transcript = await transcripts_controller.get_by_id(session, transcript_id)
     if not transcript:
         raise Exception(f"Transcript {transcript_id} not found")