refactor: use @with_session decorator in file pipeline tasks

- Add @with_session decorator to shared tasks in main_file_pipeline.py
- Update task_send_webhook_if_needed and task_pipeline_file_process to use session parameter
- Refactor PipelineMainFile methods to accept session as parameter
- Pass session through method calls instead of creating new sessions with get_session_factory()

This improves session management consistency and follows the pattern established
by other worker tasks in the codebase.
This commit is contained in:
2025-09-23 16:53:34 -06:00
parent 0b2152ea75
commit b217c7ba41

View File

@@ -54,6 +54,7 @@ from reflector.processors.types import (
) )
from reflector.settings import settings from reflector.settings import settings
from reflector.storage import get_transcripts_storage from reflector.storage import get_transcripts_storage
from reflector.worker.session_decorator import with_session
from reflector.worker.webhook import send_transcript_webhook from reflector.worker.webhook import send_transcript_webhook
@@ -130,6 +131,7 @@ class PipelineMainFile(PipelineMainBase):
# Run parallel processing # Run parallel processing
await self.run_parallel_processing( await self.run_parallel_processing(
session,
audio_path, audio_path,
audio_url, audio_url,
transcript.source_language, transcript.source_language,
@@ -201,6 +203,7 @@ class PipelineMainFile(PipelineMainBase):
async def run_parallel_processing( async def run_parallel_processing(
self, self,
session,
audio_path: Path, audio_path: Path,
audio_url: str, audio_url: str,
source_language: str, source_language: str,
@@ -214,7 +217,7 @@ class PipelineMainFile(PipelineMainBase):
# Phase 1: Parallel processing of independent tasks # Phase 1: Parallel processing of independent tasks
transcription_task = self.transcribe_file(audio_url, source_language) transcription_task = self.transcribe_file(audio_url, source_language)
diarization_task = self.diarize_file(audio_url) diarization_task = self.diarize_file(audio_url)
waveform_task = self.generate_waveform(audio_path) waveform_task = self.generate_waveform(session, audio_path)
results = await asyncio.gather( results = await asyncio.gather(
transcription_task, diarization_task, waveform_task, return_exceptions=True transcription_task, diarization_task, waveform_task, return_exceptions=True
@@ -262,7 +265,7 @@ class PipelineMainFile(PipelineMainBase):
) )
results = await asyncio.gather( results = await asyncio.gather(
self.generate_title(topics), self.generate_title(topics),
self.generate_summaries(topics), self.generate_summaries(session, topics),
return_exceptions=True, return_exceptions=True,
) )
@@ -372,16 +375,13 @@ class PipelineMainFile(PipelineMainBase):
await processor.flush() await processor.flush()
async def generate_summaries(self, topics: list[TitleSummary]): async def generate_summaries(self, session, topics: list[TitleSummary]):
"""Generate long and short summaries from topics""" """Generate long and short summaries from topics"""
if not topics: if not topics:
self.logger.warning("No topics for summary generation") self.logger.warning("No topics for summary generation")
return return
async with get_session_factory()() as session: transcript = await transcripts_controller.get_by_id(session, self.transcript_id)
transcript = await transcripts_controller.get_by_id(
session, self.transcript_id
)
processor = TranscriptFinalSummaryProcessor( processor = TranscriptFinalSummaryProcessor(
transcript=transcript, transcript=transcript,
callback=self.on_long_summary, callback=self.on_long_summary,
@@ -397,35 +397,35 @@ class PipelineMainFile(PipelineMainBase):
@shared_task @shared_task
@asynctask @asynctask
async def task_send_webhook_if_needed(*, transcript_id: str): @with_session
async def task_send_webhook_if_needed(session, *, transcript_id: str):
"""Send webhook if this is a room recording with webhook configured""" """Send webhook if this is a room recording with webhook configured"""
async with get_session_factory()() as session: transcript = await transcripts_controller.get_by_id(session, transcript_id)
transcript = await transcripts_controller.get_by_id(session, transcript_id) if not transcript:
if not transcript: return
return
if transcript.source_kind == SourceKind.ROOM and transcript.room_id: if transcript.source_kind == SourceKind.ROOM and transcript.room_id:
room = await rooms_controller.get_by_id(session, transcript.room_id) room = await rooms_controller.get_by_id(session, transcript.room_id)
if room and room.webhook_url: if room and room.webhook_url:
logger.info( logger.info(
"Dispatching webhook", "Dispatching webhook",
transcript_id=transcript_id, transcript_id=transcript_id,
room_id=room.id, room_id=room.id,
webhook_url=room.webhook_url, webhook_url=room.webhook_url,
) )
send_transcript_webhook.delay( send_transcript_webhook.delay(
transcript_id, room.id, event_id=uuid.uuid4().hex transcript_id, room.id, event_id=uuid.uuid4().hex
) )
@shared_task @shared_task
@asynctask @asynctask
async def task_pipeline_file_process(*, transcript_id: str): @with_session
async def task_pipeline_file_process(session, *, transcript_id: str):
"""Celery task for file pipeline processing""" """Celery task for file pipeline processing"""
async with get_session_factory()() as session: transcript = await transcripts_controller.get_by_id(session, transcript_id)
transcript = await transcripts_controller.get_by_id(session, transcript_id) if not transcript:
if not transcript: raise Exception(f"Transcript {transcript_id} not found")
raise Exception(f"Transcript {transcript_id} not found")
pipeline = PipelineMainFile(transcript_id=transcript_id) pipeline = PipelineMainFile(transcript_id=transcript_id)
try: try: