mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
refactor: create @with_session_and_transcript decorator to simplify pipeline functions
- Add new @with_session_and_transcript decorator that provides both session and transcript - Replace @get_transcript decorator with session-aware version in key pipeline functions - Remove duplicate get_session_factory() calls from cleanup_consent, pipeline_upload_mp3, and pipeline_post_to_zulip - Update task wrappers to use the new decorator pattern This eliminates redundant session creation and provides a cleaner, more consistent pattern for functions that need both database session and transcript access.
This commit is contained in:
@@ -64,6 +64,7 @@ from reflector.processors.types import (
|
||||
from reflector.processors.types import Transcript as TranscriptProcessorType
|
||||
from reflector.settings import settings
|
||||
from reflector.storage import get_transcripts_storage
|
||||
from reflector.worker.session_decorator import with_session_and_transcript
|
||||
from reflector.ws_manager import WebsocketManager, get_ws_manager
|
||||
from reflector.zulip import (
|
||||
get_zulip_message,
|
||||
@@ -532,8 +533,7 @@ async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
|
||||
logger.info("Convert to mp3 done")
|
||||
|
||||
|
||||
@get_transcript
|
||||
async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
|
||||
async def pipeline_upload_mp3(session, transcript: Transcript, logger: Logger):
|
||||
if not settings.TRANSCRIPT_STORAGE_BACKEND:
|
||||
logger.info("No storage backend configured, skipping mp3 upload")
|
||||
return
|
||||
@@ -551,8 +551,7 @@ async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
|
||||
return
|
||||
|
||||
# Upload to external storage and delete the file
|
||||
async with get_session_factory()() as session:
|
||||
await transcripts_controller.move_mp3_to_storage(session, transcript)
|
||||
await transcripts_controller.move_mp3_to_storage(session, transcript)
|
||||
|
||||
logger.info("Upload mp3 done")
|
||||
|
||||
@@ -581,28 +580,24 @@ async def pipeline_summaries(transcript: Transcript, logger: Logger):
|
||||
logger.info("Summaries done")
|
||||
|
||||
|
||||
@get_transcript
|
||||
async def cleanup_consent(transcript: Transcript, logger: Logger):
|
||||
async def cleanup_consent(session, transcript: Transcript, logger: Logger):
|
||||
logger.info("Starting consent cleanup")
|
||||
|
||||
consent_denied = False
|
||||
recording = None
|
||||
try:
|
||||
if transcript.recording_id:
|
||||
async with get_session_factory()() as session:
|
||||
recording = await recordings_controller.get_by_id(
|
||||
session, transcript.recording_id
|
||||
recording = await recordings_controller.get_by_id(
|
||||
session, transcript.recording_id
|
||||
)
|
||||
if recording and recording.meeting_id:
|
||||
meeting = await meetings_controller.get_by_id(
|
||||
session, recording.meeting_id
|
||||
)
|
||||
if recording and recording.meeting_id:
|
||||
meeting = await meetings_controller.get_by_id(
|
||||
session, recording.meeting_id
|
||||
if meeting:
|
||||
consent_denied = await meeting_consent_controller.has_any_denial(
|
||||
session, meeting.id
|
||||
)
|
||||
if meeting:
|
||||
consent_denied = (
|
||||
await meeting_consent_controller.has_any_denial(
|
||||
session, meeting.id
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get fetch consent: {e}", exc_info=e)
|
||||
consent_denied = True
|
||||
@@ -630,10 +625,7 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
|
||||
logger.error(f"Failed to delete Whereby recording: {e}", exc_info=e)
|
||||
|
||||
# non-transactional, files marked for deletion not actually deleted is possible
|
||||
async with get_session_factory()() as session:
|
||||
await transcripts_controller.update(
|
||||
session, transcript, {"audio_deleted": True}
|
||||
)
|
||||
await transcripts_controller.update(session, transcript, {"audio_deleted": True})
|
||||
# 2. Delete processed audio from transcript storage S3 bucket
|
||||
if transcript.audio_location == "storage":
|
||||
storage = get_transcripts_storage()
|
||||
@@ -657,32 +649,28 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
|
||||
logger.info("Consent cleanup done")
|
||||
|
||||
|
||||
@get_transcript
|
||||
async def pipeline_post_to_zulip(transcript: Transcript, logger: Logger):
|
||||
async def pipeline_post_to_zulip(session, transcript: Transcript, logger: Logger):
|
||||
logger.info("Starting post to zulip")
|
||||
|
||||
if not transcript.recording_id:
|
||||
logger.info("Transcript has no recording")
|
||||
return
|
||||
|
||||
async with get_session_factory()() as session:
|
||||
recording = await recordings_controller.get_by_id(
|
||||
session, transcript.recording_id
|
||||
)
|
||||
if not recording:
|
||||
logger.info("Recording not found")
|
||||
return
|
||||
recording = await recordings_controller.get_by_id(session, transcript.recording_id)
|
||||
if not recording:
|
||||
logger.info("Recording not found")
|
||||
return
|
||||
|
||||
if not recording.meeting_id:
|
||||
logger.info("Recording has no meeting")
|
||||
return
|
||||
if not recording.meeting_id:
|
||||
logger.info("Recording has no meeting")
|
||||
return
|
||||
|
||||
meeting = await meetings_controller.get_by_id(session, recording.meeting_id)
|
||||
if not meeting:
|
||||
logger.info("No meeting found for this recording")
|
||||
return
|
||||
meeting = await meetings_controller.get_by_id(session, recording.meeting_id)
|
||||
if not meeting:
|
||||
logger.info("No meeting found for this recording")
|
||||
return
|
||||
|
||||
room = await rooms_controller.get_by_id(session, meeting.room_id)
|
||||
room = await rooms_controller.get_by_id(session, meeting.room_id)
|
||||
if not room:
|
||||
logger.error(f"Missing room for a meeting {meeting.id}")
|
||||
return
|
||||
@@ -707,10 +695,9 @@ async def pipeline_post_to_zulip(transcript: Transcript, logger: Logger):
|
||||
response = await send_message_to_zulip(
|
||||
room.zulip_stream, room.zulip_topic, message
|
||||
)
|
||||
async with get_session_factory()() as session:
|
||||
await transcripts_controller.update(
|
||||
session, transcript, {"zulip_message_id": response["id"]}
|
||||
)
|
||||
await transcripts_controller.update(
|
||||
session, transcript, {"zulip_message_id": response["id"]}
|
||||
)
|
||||
|
||||
logger.info("Posted to zulip")
|
||||
|
||||
@@ -740,8 +727,11 @@ async def task_pipeline_convert_to_mp3(*, transcript_id: str):
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def task_pipeline_upload_mp3(*, transcript_id: str):
|
||||
await pipeline_upload_mp3(transcript_id=transcript_id)
|
||||
@with_session_and_transcript
|
||||
async def task_pipeline_upload_mp3(
|
||||
session, *, transcript: Transcript, logger: Logger, transcript_id: str
|
||||
):
|
||||
await pipeline_upload_mp3(session, transcript=transcript, logger=logger)
|
||||
|
||||
|
||||
@shared_task
|
||||
@@ -764,14 +754,20 @@ async def task_pipeline_final_summaries(*, transcript_id: str):
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def task_cleanup_consent(*, transcript_id: str):
|
||||
await cleanup_consent(transcript_id=transcript_id)
|
||||
@with_session_and_transcript
|
||||
async def task_cleanup_consent(
|
||||
session, *, transcript: Transcript, logger: Logger, transcript_id: str
|
||||
):
|
||||
await cleanup_consent(session, transcript=transcript, logger=logger)
|
||||
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def task_pipeline_post_to_zulip(*, transcript_id: str):
|
||||
await pipeline_post_to_zulip(transcript_id=transcript_id)
|
||||
@with_session_and_transcript
|
||||
async def task_pipeline_post_to_zulip(
|
||||
session, *, transcript: Transcript, logger: Logger, transcript_id: str
|
||||
):
|
||||
await pipeline_post_to_zulip(session, transcript=transcript, logger=logger)
|
||||
|
||||
|
||||
def pipeline_post(*, transcript_id: str):
|
||||
|
||||
@@ -8,7 +8,11 @@ that stays open for the entire duration of the task execution.
|
||||
import functools
|
||||
from typing import Any, Callable, TypeVar
|
||||
|
||||
from celery import current_task
|
||||
|
||||
from reflector.db import get_session_factory
|
||||
from reflector.db.transcripts import transcripts_controller
|
||||
from reflector.logger import logger
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
@@ -39,3 +43,67 @@ def with_session(func: F) -> F:
|
||||
return await func(session, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def with_session_and_transcript(func: F) -> F:
|
||||
"""
|
||||
Decorator that provides both an AsyncSession and a Transcript to the decorated function.
|
||||
|
||||
This decorator:
|
||||
1. Extracts transcript_id from kwargs
|
||||
2. Creates and manages a database session
|
||||
3. Fetches the transcript using the session
|
||||
4. Creates an enhanced logger with Celery task context
|
||||
5. Passes session, transcript, and logger to the decorated function
|
||||
|
||||
This should be used AFTER the @asynctask decorator on Celery tasks.
|
||||
|
||||
Example:
|
||||
@shared_task
|
||||
@asynctask
|
||||
@with_session_and_transcript
|
||||
async def my_task(session: AsyncSession, transcript: Transcript, logger: Logger, arg1: str):
|
||||
# session, transcript, and logger are automatically provided
|
||||
room = await rooms_controller.get_by_id(session, transcript.room_id)
|
||||
...
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
transcript_id = kwargs.pop("transcript_id", None)
|
||||
if not transcript_id:
|
||||
raise ValueError(
|
||||
"transcript_id is required for @with_session_and_transcript"
|
||||
)
|
||||
|
||||
session_factory = get_session_factory()
|
||||
async with session_factory() as session:
|
||||
async with session.begin():
|
||||
# Fetch the transcript
|
||||
transcript = await transcripts_controller.get_by_id(
|
||||
session, transcript_id
|
||||
)
|
||||
if not transcript:
|
||||
raise Exception(f"Transcript {transcript_id} not found")
|
||||
|
||||
# Create enhanced logger with Celery task context
|
||||
tlogger = logger.bind(transcript_id=transcript.id)
|
||||
if current_task:
|
||||
tlogger = tlogger.bind(
|
||||
task_id=current_task.request.id,
|
||||
task_name=current_task.name,
|
||||
worker_hostname=current_task.request.hostname,
|
||||
task_retries=current_task.request.retries,
|
||||
transcript_id=transcript_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# Pass session, transcript, and logger to the decorated function
|
||||
return await func(
|
||||
session, transcript=transcript, logger=tlogger, *args, **kwargs
|
||||
)
|
||||
except Exception:
|
||||
tlogger.exception("Error in task execution")
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
|
||||
Reference in New Issue
Block a user