mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-21 04:39:06 +00:00
refactor: use @with_session decorator in file pipeline tasks
- Add @with_session decorator to shared tasks in main_file_pipeline.py - Update task_send_webhook_if_needed and task_pipeline_file_process to use session parameter - Refactor PipelineMainFile methods to accept session as parameter - Pass session through method calls instead of creating new sessions with get_session_factory() This improves session management consistency and follows the pattern established by other worker tasks in the codebase.
This commit is contained in:
@@ -54,6 +54,7 @@ from reflector.processors.types import (
|
|||||||
)
|
)
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
from reflector.storage import get_transcripts_storage
|
from reflector.storage import get_transcripts_storage
|
||||||
|
from reflector.worker.session_decorator import with_session
|
||||||
from reflector.worker.webhook import send_transcript_webhook
|
from reflector.worker.webhook import send_transcript_webhook
|
||||||
|
|
||||||
|
|
||||||
@@ -130,6 +131,7 @@ class PipelineMainFile(PipelineMainBase):
|
|||||||
|
|
||||||
# Run parallel processing
|
# Run parallel processing
|
||||||
await self.run_parallel_processing(
|
await self.run_parallel_processing(
|
||||||
|
session,
|
||||||
audio_path,
|
audio_path,
|
||||||
audio_url,
|
audio_url,
|
||||||
transcript.source_language,
|
transcript.source_language,
|
||||||
@@ -201,6 +203,7 @@ class PipelineMainFile(PipelineMainBase):
|
|||||||
|
|
||||||
async def run_parallel_processing(
|
async def run_parallel_processing(
|
||||||
self,
|
self,
|
||||||
|
session,
|
||||||
audio_path: Path,
|
audio_path: Path,
|
||||||
audio_url: str,
|
audio_url: str,
|
||||||
source_language: str,
|
source_language: str,
|
||||||
@@ -214,7 +217,7 @@ class PipelineMainFile(PipelineMainBase):
|
|||||||
# Phase 1: Parallel processing of independent tasks
|
# Phase 1: Parallel processing of independent tasks
|
||||||
transcription_task = self.transcribe_file(audio_url, source_language)
|
transcription_task = self.transcribe_file(audio_url, source_language)
|
||||||
diarization_task = self.diarize_file(audio_url)
|
diarization_task = self.diarize_file(audio_url)
|
||||||
waveform_task = self.generate_waveform(audio_path)
|
waveform_task = self.generate_waveform(session, audio_path)
|
||||||
|
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(
|
||||||
transcription_task, diarization_task, waveform_task, return_exceptions=True
|
transcription_task, diarization_task, waveform_task, return_exceptions=True
|
||||||
@@ -262,7 +265,7 @@ class PipelineMainFile(PipelineMainBase):
|
|||||||
)
|
)
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(
|
||||||
self.generate_title(topics),
|
self.generate_title(topics),
|
||||||
self.generate_summaries(topics),
|
self.generate_summaries(session, topics),
|
||||||
return_exceptions=True,
|
return_exceptions=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -372,16 +375,13 @@ class PipelineMainFile(PipelineMainBase):
|
|||||||
|
|
||||||
await processor.flush()
|
await processor.flush()
|
||||||
|
|
||||||
async def generate_summaries(self, topics: list[TitleSummary]):
|
async def generate_summaries(self, session, topics: list[TitleSummary]):
|
||||||
"""Generate long and short summaries from topics"""
|
"""Generate long and short summaries from topics"""
|
||||||
if not topics:
|
if not topics:
|
||||||
self.logger.warning("No topics for summary generation")
|
self.logger.warning("No topics for summary generation")
|
||||||
return
|
return
|
||||||
|
|
||||||
async with get_session_factory()() as session:
|
transcript = await transcripts_controller.get_by_id(session, self.transcript_id)
|
||||||
transcript = await transcripts_controller.get_by_id(
|
|
||||||
session, self.transcript_id
|
|
||||||
)
|
|
||||||
processor = TranscriptFinalSummaryProcessor(
|
processor = TranscriptFinalSummaryProcessor(
|
||||||
transcript=transcript,
|
transcript=transcript,
|
||||||
callback=self.on_long_summary,
|
callback=self.on_long_summary,
|
||||||
@@ -397,35 +397,35 @@ class PipelineMainFile(PipelineMainBase):
|
|||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
@asynctask
|
@asynctask
|
||||||
async def task_send_webhook_if_needed(*, transcript_id: str):
|
@with_session
|
||||||
|
async def task_send_webhook_if_needed(session, *, transcript_id: str):
|
||||||
"""Send webhook if this is a room recording with webhook configured"""
|
"""Send webhook if this is a room recording with webhook configured"""
|
||||||
async with get_session_factory()() as session:
|
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
||||||
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
if not transcript:
|
||||||
if not transcript:
|
return
|
||||||
return
|
|
||||||
|
|
||||||
if transcript.source_kind == SourceKind.ROOM and transcript.room_id:
|
if transcript.source_kind == SourceKind.ROOM and transcript.room_id:
|
||||||
room = await rooms_controller.get_by_id(session, transcript.room_id)
|
room = await rooms_controller.get_by_id(session, transcript.room_id)
|
||||||
if room and room.webhook_url:
|
if room and room.webhook_url:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Dispatching webhook",
|
"Dispatching webhook",
|
||||||
transcript_id=transcript_id,
|
transcript_id=transcript_id,
|
||||||
room_id=room.id,
|
room_id=room.id,
|
||||||
webhook_url=room.webhook_url,
|
webhook_url=room.webhook_url,
|
||||||
)
|
)
|
||||||
send_transcript_webhook.delay(
|
send_transcript_webhook.delay(
|
||||||
transcript_id, room.id, event_id=uuid.uuid4().hex
|
transcript_id, room.id, event_id=uuid.uuid4().hex
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
@asynctask
|
@asynctask
|
||||||
async def task_pipeline_file_process(*, transcript_id: str):
|
@with_session
|
||||||
|
async def task_pipeline_file_process(session, *, transcript_id: str):
|
||||||
"""Celery task for file pipeline processing"""
|
"""Celery task for file pipeline processing"""
|
||||||
async with get_session_factory()() as session:
|
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
||||||
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
if not transcript:
|
||||||
if not transcript:
|
raise Exception(f"Transcript {transcript_id} not found")
|
||||||
raise Exception(f"Transcript {transcript_id} not found")
|
|
||||||
|
|
||||||
pipeline = PipelineMainFile(transcript_id=transcript_id)
|
pipeline = PipelineMainFile(transcript_id=transcript_id)
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user