mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
refactor: use @with_session decorator in file pipeline tasks
- Add @with_session decorator to shared tasks in main_file_pipeline.py - Update task_send_webhook_if_needed and task_pipeline_file_process to use session parameter - Refactor PipelineMainFile methods to accept session as parameter - Pass session through method calls instead of creating new sessions with get_session_factory() This improves session management consistency and follows the pattern established by other worker tasks in the codebase.
This commit is contained in:
@@ -54,6 +54,7 @@ from reflector.processors.types import (
|
||||
)
|
||||
from reflector.settings import settings
|
||||
from reflector.storage import get_transcripts_storage
|
||||
from reflector.worker.session_decorator import with_session
|
||||
from reflector.worker.webhook import send_transcript_webhook
|
||||
|
||||
|
||||
@@ -130,6 +131,7 @@ class PipelineMainFile(PipelineMainBase):
|
||||
|
||||
# Run parallel processing
|
||||
await self.run_parallel_processing(
|
||||
session,
|
||||
audio_path,
|
||||
audio_url,
|
||||
transcript.source_language,
|
||||
@@ -201,6 +203,7 @@ class PipelineMainFile(PipelineMainBase):
|
||||
|
||||
async def run_parallel_processing(
|
||||
self,
|
||||
session,
|
||||
audio_path: Path,
|
||||
audio_url: str,
|
||||
source_language: str,
|
||||
@@ -214,7 +217,7 @@ class PipelineMainFile(PipelineMainBase):
|
||||
# Phase 1: Parallel processing of independent tasks
|
||||
transcription_task = self.transcribe_file(audio_url, source_language)
|
||||
diarization_task = self.diarize_file(audio_url)
|
||||
waveform_task = self.generate_waveform(audio_path)
|
||||
waveform_task = self.generate_waveform(session, audio_path)
|
||||
|
||||
results = await asyncio.gather(
|
||||
transcription_task, diarization_task, waveform_task, return_exceptions=True
|
||||
@@ -262,7 +265,7 @@ class PipelineMainFile(PipelineMainBase):
|
||||
)
|
||||
results = await asyncio.gather(
|
||||
self.generate_title(topics),
|
||||
self.generate_summaries(topics),
|
||||
self.generate_summaries(session, topics),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
@@ -372,16 +375,13 @@ class PipelineMainFile(PipelineMainBase):
|
||||
|
||||
await processor.flush()
|
||||
|
||||
async def generate_summaries(self, topics: list[TitleSummary]):
|
||||
async def generate_summaries(self, session, topics: list[TitleSummary]):
|
||||
"""Generate long and short summaries from topics"""
|
||||
if not topics:
|
||||
self.logger.warning("No topics for summary generation")
|
||||
return
|
||||
|
||||
async with get_session_factory()() as session:
|
||||
transcript = await transcripts_controller.get_by_id(
|
||||
session, self.transcript_id
|
||||
)
|
||||
transcript = await transcripts_controller.get_by_id(session, self.transcript_id)
|
||||
processor = TranscriptFinalSummaryProcessor(
|
||||
transcript=transcript,
|
||||
callback=self.on_long_summary,
|
||||
@@ -397,35 +397,35 @@ class PipelineMainFile(PipelineMainBase):
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def task_send_webhook_if_needed(*, transcript_id: str):
|
||||
@with_session
|
||||
async def task_send_webhook_if_needed(session, *, transcript_id: str):
|
||||
"""Send webhook if this is a room recording with webhook configured"""
|
||||
async with get_session_factory()() as session:
|
||||
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
||||
if not transcript:
|
||||
return
|
||||
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
||||
if not transcript:
|
||||
return
|
||||
|
||||
if transcript.source_kind == SourceKind.ROOM and transcript.room_id:
|
||||
room = await rooms_controller.get_by_id(session, transcript.room_id)
|
||||
if room and room.webhook_url:
|
||||
logger.info(
|
||||
"Dispatching webhook",
|
||||
transcript_id=transcript_id,
|
||||
room_id=room.id,
|
||||
webhook_url=room.webhook_url,
|
||||
)
|
||||
send_transcript_webhook.delay(
|
||||
transcript_id, room.id, event_id=uuid.uuid4().hex
|
||||
)
|
||||
if transcript.source_kind == SourceKind.ROOM and transcript.room_id:
|
||||
room = await rooms_controller.get_by_id(session, transcript.room_id)
|
||||
if room and room.webhook_url:
|
||||
logger.info(
|
||||
"Dispatching webhook",
|
||||
transcript_id=transcript_id,
|
||||
room_id=room.id,
|
||||
webhook_url=room.webhook_url,
|
||||
)
|
||||
send_transcript_webhook.delay(
|
||||
transcript_id, room.id, event_id=uuid.uuid4().hex
|
||||
)
|
||||
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def task_pipeline_file_process(*, transcript_id: str):
|
||||
@with_session
|
||||
async def task_pipeline_file_process(session, *, transcript_id: str):
|
||||
"""Celery task for file pipeline processing"""
|
||||
async with get_session_factory()() as session:
|
||||
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
||||
if not transcript:
|
||||
raise Exception(f"Transcript {transcript_id} not found")
|
||||
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
||||
if not transcript:
|
||||
raise Exception(f"Transcript {transcript_id} not found")
|
||||
|
||||
pipeline = PipelineMainFile(transcript_id=transcript_id)
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user