refactor: create @with_session_and_transcript decorator to simplify pipeline functions

- Add new @with_session_and_transcript decorator that provides both session and transcript
- Replace @get_transcript decorator with session-aware version in key pipeline functions
- Remove duplicate get_session_factory() calls from cleanup_consent, pipeline_upload_mp3, and pipeline_post_to_zulip
- Update task wrappers to use the new decorator pattern

This eliminates redundant session creation and provides a cleaner, more consistent
pattern for functions that need both database session and transcript access.
This commit is contained in:
2025-09-23 17:01:09 -06:00
parent b217c7ba41
commit f51dae8da3
2 changed files with 113 additions and 49 deletions

View File

@@ -64,6 +64,7 @@ from reflector.processors.types import (
from reflector.processors.types import Transcript as TranscriptProcessorType
from reflector.settings import settings
from reflector.storage import get_transcripts_storage
from reflector.worker.session_decorator import with_session_and_transcript
from reflector.ws_manager import WebsocketManager, get_ws_manager
from reflector.zulip import (
get_zulip_message,
@@ -532,8 +533,7 @@ async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
logger.info("Convert to mp3 done")
@get_transcript
async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
async def pipeline_upload_mp3(session, transcript: Transcript, logger: Logger):
if not settings.TRANSCRIPT_STORAGE_BACKEND:
logger.info("No storage backend configured, skipping mp3 upload")
return
@@ -551,8 +551,7 @@ async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
return
# Upload to external storage and delete the file
async with get_session_factory()() as session:
await transcripts_controller.move_mp3_to_storage(session, transcript)
await transcripts_controller.move_mp3_to_storage(session, transcript)
logger.info("Upload mp3 done")
@@ -581,28 +580,24 @@ async def pipeline_summaries(transcript: Transcript, logger: Logger):
logger.info("Summaries done")
@get_transcript
async def cleanup_consent(transcript: Transcript, logger: Logger):
async def cleanup_consent(session, transcript: Transcript, logger: Logger):
logger.info("Starting consent cleanup")
consent_denied = False
recording = None
try:
if transcript.recording_id:
async with get_session_factory()() as session:
recording = await recordings_controller.get_by_id(
session, transcript.recording_id
recording = await recordings_controller.get_by_id(
session, transcript.recording_id
)
if recording and recording.meeting_id:
meeting = await meetings_controller.get_by_id(
session, recording.meeting_id
)
if recording and recording.meeting_id:
meeting = await meetings_controller.get_by_id(
session, recording.meeting_id
if meeting:
consent_denied = await meeting_consent_controller.has_any_denial(
session, meeting.id
)
if meeting:
consent_denied = (
await meeting_consent_controller.has_any_denial(
session, meeting.id
)
)
except Exception as e:
logger.error(f"Failed to get fetch consent: {e}", exc_info=e)
consent_denied = True
@@ -630,10 +625,7 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
logger.error(f"Failed to delete Whereby recording: {e}", exc_info=e)
# non-transactional, files marked for deletion not actually deleted is possible
async with get_session_factory()() as session:
await transcripts_controller.update(
session, transcript, {"audio_deleted": True}
)
await transcripts_controller.update(session, transcript, {"audio_deleted": True})
# 2. Delete processed audio from transcript storage S3 bucket
if transcript.audio_location == "storage":
storage = get_transcripts_storage()
@@ -657,32 +649,28 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
logger.info("Consent cleanup done")
@get_transcript
async def pipeline_post_to_zulip(transcript: Transcript, logger: Logger):
async def pipeline_post_to_zulip(session, transcript: Transcript, logger: Logger):
logger.info("Starting post to zulip")
if not transcript.recording_id:
logger.info("Transcript has no recording")
return
async with get_session_factory()() as session:
recording = await recordings_controller.get_by_id(
session, transcript.recording_id
)
if not recording:
logger.info("Recording not found")
return
recording = await recordings_controller.get_by_id(session, transcript.recording_id)
if not recording:
logger.info("Recording not found")
return
if not recording.meeting_id:
logger.info("Recording has no meeting")
return
if not recording.meeting_id:
logger.info("Recording has no meeting")
return
meeting = await meetings_controller.get_by_id(session, recording.meeting_id)
if not meeting:
logger.info("No meeting found for this recording")
return
meeting = await meetings_controller.get_by_id(session, recording.meeting_id)
if not meeting:
logger.info("No meeting found for this recording")
return
room = await rooms_controller.get_by_id(session, meeting.room_id)
room = await rooms_controller.get_by_id(session, meeting.room_id)
if not room:
logger.error(f"Missing room for a meeting {meeting.id}")
return
@@ -707,10 +695,9 @@ async def pipeline_post_to_zulip(transcript: Transcript, logger: Logger):
response = await send_message_to_zulip(
room.zulip_stream, room.zulip_topic, message
)
async with get_session_factory()() as session:
await transcripts_controller.update(
session, transcript, {"zulip_message_id": response["id"]}
)
await transcripts_controller.update(
session, transcript, {"zulip_message_id": response["id"]}
)
logger.info("Posted to zulip")
@@ -740,8 +727,11 @@ async def task_pipeline_convert_to_mp3(*, transcript_id: str):
@shared_task
@asynctask
async def task_pipeline_upload_mp3(*, transcript_id: str):
await pipeline_upload_mp3(transcript_id=transcript_id)
@with_session_and_transcript
async def task_pipeline_upload_mp3(
session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
await pipeline_upload_mp3(session, transcript=transcript, logger=logger)
@shared_task
@@ -764,14 +754,20 @@ async def task_pipeline_final_summaries(*, transcript_id: str):
@shared_task
@asynctask
async def task_cleanup_consent(*, transcript_id: str):
await cleanup_consent(transcript_id=transcript_id)
@with_session_and_transcript
async def task_cleanup_consent(
session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
await cleanup_consent(session, transcript=transcript, logger=logger)
@shared_task
@asynctask
async def task_pipeline_post_to_zulip(*, transcript_id: str):
await pipeline_post_to_zulip(transcript_id=transcript_id)
@with_session_and_transcript
async def task_pipeline_post_to_zulip(
session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
await pipeline_post_to_zulip(session, transcript=transcript, logger=logger)
def pipeline_post(*, transcript_id: str):

View File

@@ -8,7 +8,11 @@ that stays open for the entire duration of the task execution.
import functools
from typing import Any, Callable, TypeVar
from celery import current_task
from reflector.db import get_session_factory
from reflector.db.transcripts import transcripts_controller
from reflector.logger import logger
F = TypeVar("F", bound=Callable[..., Any])
@@ -39,3 +43,67 @@ def with_session(func: F) -> F:
return await func(session, *args, **kwargs)
return wrapper
def with_session_and_transcript(func: F) -> F:
"""
Decorator that provides both an AsyncSession and a Transcript to the decorated function.
This decorator:
1. Extracts transcript_id from kwargs
2. Creates and manages a database session
3. Fetches the transcript using the session
4. Creates an enhanced logger with Celery task context
5. Passes session, transcript, and logger to the decorated function
This should be used AFTER the @asynctask decorator on Celery tasks.
Example:
@shared_task
@asynctask
@with_session_and_transcript
async def my_task(session: AsyncSession, transcript: Transcript, logger: Logger, arg1: str):
# session, transcript, and logger are automatically provided
room = await rooms_controller.get_by_id(session, transcript.room_id)
...
"""
@functools.wraps(func)
async def wrapper(*args, **kwargs):
transcript_id = kwargs.pop("transcript_id", None)
if not transcript_id:
raise ValueError(
"transcript_id is required for @with_session_and_transcript"
)
session_factory = get_session_factory()
async with session_factory() as session:
async with session.begin():
# Fetch the transcript
transcript = await transcripts_controller.get_by_id(
session, transcript_id
)
if not transcript:
raise Exception(f"Transcript {transcript_id} not found")
# Create enhanced logger with Celery task context
tlogger = logger.bind(transcript_id=transcript.id)
if current_task:
tlogger = tlogger.bind(
task_id=current_task.request.id,
task_name=current_task.name,
worker_hostname=current_task.request.hostname,
task_retries=current_task.request.retries,
transcript_id=transcript_id,
)
try:
# Pass session, transcript, and logger to the decorated function
return await func(
session, transcript=transcript, logger=tlogger, *args, **kwargs
)
except Exception:
tlogger.exception("Error in task execution")
raise
return wrapper