From f51dae8da36a644808cdd97b63eaa62796c6d6b2 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Tue, 23 Sep 2025 17:01:09 -0600
Subject: [PATCH] refactor: create @with_session_and_transcript decorator to
 simplify pipeline functions

- Add new @with_session_and_transcript decorator that provides both session
  and transcript
- Replace @get_transcript decorator with session-aware version in key
  pipeline functions
- Remove duplicate get_session_factory() calls from cleanup_consent,
  pipeline_upload_mp3, and pipeline_post_to_zulip
- Update task wrappers to use the new decorator pattern

This eliminates redundant session creation and provides a cleaner, more
consistent pattern for functions that need both database session and
transcript access.
---
 .../reflector/pipelines/main_live_pipeline.py | 94 +++++++++----------
 server/reflector/worker/session_decorator.py  | 68 ++++++++++++++
 2 files changed, 113 insertions(+), 49 deletions(-)

diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py
index 00cc47d8..5f2647e7 100644
--- a/server/reflector/pipelines/main_live_pipeline.py
+++ b/server/reflector/pipelines/main_live_pipeline.py
@@ -64,6 +64,7 @@ from reflector.processors.types import (
 from reflector.processors.types import Transcript as TranscriptProcessorType
 from reflector.settings import settings
 from reflector.storage import get_transcripts_storage
+from reflector.worker.session_decorator import with_session_and_transcript
 from reflector.ws_manager import WebsocketManager, get_ws_manager
 from reflector.zulip import (
     get_zulip_message,
@@ -532,8 +533,7 @@ async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
     logger.info("Convert to mp3 done")
 
 
-@get_transcript
-async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
+async def pipeline_upload_mp3(session, transcript: Transcript, logger: Logger):
     if not settings.TRANSCRIPT_STORAGE_BACKEND:
         logger.info("No storage backend configured, skipping mp3 upload")
         return
@@ -551,8 +551,7 @@ async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
         return
 
     # Upload to external storage and delete the file
-    async with get_session_factory()() as session:
-        await transcripts_controller.move_mp3_to_storage(session, transcript)
+    await transcripts_controller.move_mp3_to_storage(session, transcript)
 
     logger.info("Upload mp3 done")
 
@@ -581,28 +580,24 @@ async def pipeline_summaries(transcript: Transcript, logger: Logger):
     logger.info("Summaries done")
 
 
-@get_transcript
-async def cleanup_consent(transcript: Transcript, logger: Logger):
+async def cleanup_consent(session, transcript: Transcript, logger: Logger):
     logger.info("Starting consent cleanup")
     consent_denied = False
     recording = None
 
     try:
         if transcript.recording_id:
-            async with get_session_factory()() as session:
-                recording = await recordings_controller.get_by_id(
-                    session, transcript.recording_id
+            recording = await recordings_controller.get_by_id(
+                session, transcript.recording_id
+            )
+            if recording and recording.meeting_id:
+                meeting = await meetings_controller.get_by_id(
+                    session, recording.meeting_id
                 )
-                if recording and recording.meeting_id:
-                    meeting = await meetings_controller.get_by_id(
-                        session, recording.meeting_id
+                if meeting:
+                    consent_denied = await meeting_consent_controller.has_any_denial(
+                        session, meeting.id
                     )
-                    if meeting:
-                        consent_denied = (
-                            await meeting_consent_controller.has_any_denial(
-                                session, meeting.id
-                            )
-                        )
     except Exception as e:
         logger.error(f"Failed to get fetch consent: {e}", exc_info=e)
         consent_denied = True
@@ -630,10 +625,7 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
             logger.error(f"Failed to delete Whereby recording: {e}", exc_info=e)
 
     # non-transactional, files marked for deletion not actually deleted is possible
-    async with get_session_factory()() as session:
-        await transcripts_controller.update(
-            session, transcript, {"audio_deleted": True}
-        )
+    await transcripts_controller.update(session, transcript, {"audio_deleted": True})
     # 2. Delete processed audio from transcript storage S3 bucket
     if transcript.audio_location == "storage":
         storage = get_transcripts_storage()
@@ -657,32 +649,28 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
     logger.info("Consent cleanup done")
 
 
-@get_transcript
-async def pipeline_post_to_zulip(transcript: Transcript, logger: Logger):
+async def pipeline_post_to_zulip(session, transcript: Transcript, logger: Logger):
     logger.info("Starting post to zulip")
 
     if not transcript.recording_id:
         logger.info("Transcript has no recording")
         return
 
-    async with get_session_factory()() as session:
-        recording = await recordings_controller.get_by_id(
-            session, transcript.recording_id
-        )
-        if not recording:
-            logger.info("Recording not found")
-            return
+    recording = await recordings_controller.get_by_id(session, transcript.recording_id)
+    if not recording:
+        logger.info("Recording not found")
+        return
 
-        if not recording.meeting_id:
-            logger.info("Recording has no meeting")
-            return
+    if not recording.meeting_id:
+        logger.info("Recording has no meeting")
+        return
 
-        meeting = await meetings_controller.get_by_id(session, recording.meeting_id)
-        if not meeting:
-            logger.info("No meeting found for this recording")
-            return
+    meeting = await meetings_controller.get_by_id(session, recording.meeting_id)
+    if not meeting:
+        logger.info("No meeting found for this recording")
+        return
 
-        room = await rooms_controller.get_by_id(session, meeting.room_id)
+    room = await rooms_controller.get_by_id(session, meeting.room_id)
     if not room:
         logger.error(f"Missing room for a meeting {meeting.id}")
         return
@@ -707,10 +695,9 @@ async def pipeline_post_to_zulip(transcript: Transcript, logger: Logger):
     response = await send_message_to_zulip(
         room.zulip_stream, room.zulip_topic, message
     )
-    async with get_session_factory()() as session:
-        await transcripts_controller.update(
-            session, transcript, {"zulip_message_id": response["id"]}
-        )
+    await transcripts_controller.update(
+        session, transcript, {"zulip_message_id": response["id"]}
+    )
 
     logger.info("Posted to zulip")
 
@@ -740,8 +727,11 @@ async def task_pipeline_convert_to_mp3(*, transcript_id: str):
 
 @shared_task
 @asynctask
-async def task_pipeline_upload_mp3(*, transcript_id: str):
-    await pipeline_upload_mp3(transcript_id=transcript_id)
+@with_session_and_transcript
+async def task_pipeline_upload_mp3(
+    session, *, transcript: Transcript, logger: Logger, transcript_id: str
+):
+    await pipeline_upload_mp3(session, transcript=transcript, logger=logger)
 
 
 @shared_task
@@ -764,14 +754,20 @@ async def task_pipeline_final_summaries(*, transcript_id: str):
 
 @shared_task
 @asynctask
-async def task_cleanup_consent(*, transcript_id: str):
-    await cleanup_consent(transcript_id=transcript_id)
+@with_session_and_transcript
+async def task_cleanup_consent(
+    session, *, transcript: Transcript, logger: Logger, transcript_id: str
+):
+    await cleanup_consent(session, transcript=transcript, logger=logger)
 
 
 @shared_task
 @asynctask
-async def task_pipeline_post_to_zulip(*, transcript_id: str):
-    await pipeline_post_to_zulip(transcript_id=transcript_id)
+@with_session_and_transcript
+async def task_pipeline_post_to_zulip(
+    session, *, transcript: Transcript, logger: Logger, transcript_id: str
+):
+    await pipeline_post_to_zulip(session, transcript=transcript, logger=logger)
 
 
 def pipeline_post(*, transcript_id: str):
diff --git a/server/reflector/worker/session_decorator.py b/server/reflector/worker/session_decorator.py
index e01b3104..70d47b75 100644
--- a/server/reflector/worker/session_decorator.py
+++ b/server/reflector/worker/session_decorator.py
@@ -8,7 +8,11 @@ that stays open for the entire duration of the task execution.
 
 import functools
 from typing import Any, Callable, TypeVar
 
+from celery import current_task
+
 from reflector.db import get_session_factory
+from reflector.db.transcripts import transcripts_controller
+from reflector.logger import logger
 
 F = TypeVar("F", bound=Callable[..., Any])
@@ -39,3 +43,67 @@ def with_session(func: F) -> F:
                 return await func(session, *args, **kwargs)
 
     return wrapper
+
+
+def with_session_and_transcript(func: F) -> F:
+    """
+    Decorator that provides both an AsyncSession and a Transcript to the decorated function.
+
+    This decorator:
+    1. Extracts transcript_id from kwargs
+    2. Creates and manages a database session
+    3. Fetches the transcript using the session
+    4. Creates an enhanced logger with Celery task context
+    5. Passes session, transcript, and logger to the decorated function
+
+    This should be used AFTER the @asynctask decorator on Celery tasks.
+
+    Example:
+        @shared_task
+        @asynctask
+        @with_session_and_transcript
+        async def my_task(session: AsyncSession, transcript: Transcript, logger: Logger, arg1: str):
+            # session, transcript, and logger are automatically provided
+            room = await rooms_controller.get_by_id(session, transcript.room_id)
+            ...
+    """
+
+    @functools.wraps(func)
+    async def wrapper(*args, **kwargs):
+        transcript_id = kwargs.pop("transcript_id", None)
+        if not transcript_id:
+            raise ValueError(
+                "transcript_id is required for @with_session_and_transcript"
+            )
+
+        session_factory = get_session_factory()
+        async with session_factory() as session:
+            async with session.begin():
+                # Fetch the transcript
+                transcript = await transcripts_controller.get_by_id(
+                    session, transcript_id
+                )
+                if not transcript:
+                    raise Exception(f"Transcript {transcript_id} not found")
+
+                # Create enhanced logger with Celery task context
+                tlogger = logger.bind(transcript_id=transcript.id)
+                if current_task:
+                    tlogger = tlogger.bind(
+                        task_id=current_task.request.id,
+                        task_name=current_task.name,
+                        worker_hostname=current_task.request.hostname,
+                        task_retries=current_task.request.retries,
+                        transcript_id=transcript_id,
+                    )
+
+                try:
+                    # Pass session, transcript, and logger to the decorated function
+                    return await func(
+                        session, transcript=transcript, logger=tlogger, *args, **kwargs
+                    )
+                except Exception:
+                    tlogger.exception("Error in task execution")
+                    raise
+
+    return wrapper
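
Illustrative usage (not part of the patch): a minimal sketch of how another
Celery task could adopt @with_session_and_transcript, mirroring the task
wrappers updated above. The task name and body are hypothetical, and the
Transcript import path is an assumption that should match whatever
main_live_pipeline.py already imports.

    from celery import shared_task

    from reflector.db.transcripts import Transcript  # import path assumed
    from reflector.worker.session_decorator import with_session_and_transcript

    # asynctask comes from the project's worker utilities; its module is not
    # shown in this patch, so its import is omitted here.


    @shared_task
    @asynctask
    @with_session_and_transcript
    async def task_example_step(
        session, *, transcript: Transcript, logger, transcript_id: str
    ):
        # session stays open for the whole task, transcript is already
        # fetched, and logger carries transcript_id plus Celery task context.
        logger.info("Example step")

Callers keep passing only transcript_id, exactly as before; the decorator
resolves the session, transcript, and logger:

    task_example_step.delay(transcript_id="...")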