mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-21 12:49:06 +00:00
refactor: create @with_session_and_transcript decorator to simplify pipeline functions
- Add new @with_session_and_transcript decorator that provides both session and transcript - Replace @get_transcript decorator with session-aware version in key pipeline functions - Remove duplicate get_session_factory() calls from cleanup_consent, pipeline_upload_mp3, and pipeline_post_to_zulip - Update task wrappers to use the new decorator pattern This eliminates redundant session creation and provides a cleaner, more consistent pattern for functions that need both database session and transcript access.
This commit is contained in:
@@ -64,6 +64,7 @@ from reflector.processors.types import (
|
|||||||
from reflector.processors.types import Transcript as TranscriptProcessorType
|
from reflector.processors.types import Transcript as TranscriptProcessorType
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
from reflector.storage import get_transcripts_storage
|
from reflector.storage import get_transcripts_storage
|
||||||
|
from reflector.worker.session_decorator import with_session_and_transcript
|
||||||
from reflector.ws_manager import WebsocketManager, get_ws_manager
|
from reflector.ws_manager import WebsocketManager, get_ws_manager
|
||||||
from reflector.zulip import (
|
from reflector.zulip import (
|
||||||
get_zulip_message,
|
get_zulip_message,
|
||||||
@@ -532,8 +533,7 @@ async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
|
|||||||
logger.info("Convert to mp3 done")
|
logger.info("Convert to mp3 done")
|
||||||
|
|
||||||
|
|
||||||
@get_transcript
|
async def pipeline_upload_mp3(session, transcript: Transcript, logger: Logger):
|
||||||
async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
|
|
||||||
if not settings.TRANSCRIPT_STORAGE_BACKEND:
|
if not settings.TRANSCRIPT_STORAGE_BACKEND:
|
||||||
logger.info("No storage backend configured, skipping mp3 upload")
|
logger.info("No storage backend configured, skipping mp3 upload")
|
||||||
return
|
return
|
||||||
@@ -551,7 +551,6 @@ async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Upload to external storage and delete the file
|
# Upload to external storage and delete the file
|
||||||
async with get_session_factory()() as session:
|
|
||||||
await transcripts_controller.move_mp3_to_storage(session, transcript)
|
await transcripts_controller.move_mp3_to_storage(session, transcript)
|
||||||
|
|
||||||
logger.info("Upload mp3 done")
|
logger.info("Upload mp3 done")
|
||||||
@@ -581,15 +580,13 @@ async def pipeline_summaries(transcript: Transcript, logger: Logger):
|
|||||||
logger.info("Summaries done")
|
logger.info("Summaries done")
|
||||||
|
|
||||||
|
|
||||||
@get_transcript
|
async def cleanup_consent(session, transcript: Transcript, logger: Logger):
|
||||||
async def cleanup_consent(transcript: Transcript, logger: Logger):
|
|
||||||
logger.info("Starting consent cleanup")
|
logger.info("Starting consent cleanup")
|
||||||
|
|
||||||
consent_denied = False
|
consent_denied = False
|
||||||
recording = None
|
recording = None
|
||||||
try:
|
try:
|
||||||
if transcript.recording_id:
|
if transcript.recording_id:
|
||||||
async with get_session_factory()() as session:
|
|
||||||
recording = await recordings_controller.get_by_id(
|
recording = await recordings_controller.get_by_id(
|
||||||
session, transcript.recording_id
|
session, transcript.recording_id
|
||||||
)
|
)
|
||||||
@@ -598,11 +595,9 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
|
|||||||
session, recording.meeting_id
|
session, recording.meeting_id
|
||||||
)
|
)
|
||||||
if meeting:
|
if meeting:
|
||||||
consent_denied = (
|
consent_denied = await meeting_consent_controller.has_any_denial(
|
||||||
await meeting_consent_controller.has_any_denial(
|
|
||||||
session, meeting.id
|
session, meeting.id
|
||||||
)
|
)
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get fetch consent: {e}", exc_info=e)
|
logger.error(f"Failed to get fetch consent: {e}", exc_info=e)
|
||||||
consent_denied = True
|
consent_denied = True
|
||||||
@@ -630,10 +625,7 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
|
|||||||
logger.error(f"Failed to delete Whereby recording: {e}", exc_info=e)
|
logger.error(f"Failed to delete Whereby recording: {e}", exc_info=e)
|
||||||
|
|
||||||
# non-transactional, files marked for deletion not actually deleted is possible
|
# non-transactional, files marked for deletion not actually deleted is possible
|
||||||
async with get_session_factory()() as session:
|
await transcripts_controller.update(session, transcript, {"audio_deleted": True})
|
||||||
await transcripts_controller.update(
|
|
||||||
session, transcript, {"audio_deleted": True}
|
|
||||||
)
|
|
||||||
# 2. Delete processed audio from transcript storage S3 bucket
|
# 2. Delete processed audio from transcript storage S3 bucket
|
||||||
if transcript.audio_location == "storage":
|
if transcript.audio_location == "storage":
|
||||||
storage = get_transcripts_storage()
|
storage = get_transcripts_storage()
|
||||||
@@ -657,18 +649,14 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
|
|||||||
logger.info("Consent cleanup done")
|
logger.info("Consent cleanup done")
|
||||||
|
|
||||||
|
|
||||||
@get_transcript
|
async def pipeline_post_to_zulip(session, transcript: Transcript, logger: Logger):
|
||||||
async def pipeline_post_to_zulip(transcript: Transcript, logger: Logger):
|
|
||||||
logger.info("Starting post to zulip")
|
logger.info("Starting post to zulip")
|
||||||
|
|
||||||
if not transcript.recording_id:
|
if not transcript.recording_id:
|
||||||
logger.info("Transcript has no recording")
|
logger.info("Transcript has no recording")
|
||||||
return
|
return
|
||||||
|
|
||||||
async with get_session_factory()() as session:
|
recording = await recordings_controller.get_by_id(session, transcript.recording_id)
|
||||||
recording = await recordings_controller.get_by_id(
|
|
||||||
session, transcript.recording_id
|
|
||||||
)
|
|
||||||
if not recording:
|
if not recording:
|
||||||
logger.info("Recording not found")
|
logger.info("Recording not found")
|
||||||
return
|
return
|
||||||
@@ -707,7 +695,6 @@ async def pipeline_post_to_zulip(transcript: Transcript, logger: Logger):
|
|||||||
response = await send_message_to_zulip(
|
response = await send_message_to_zulip(
|
||||||
room.zulip_stream, room.zulip_topic, message
|
room.zulip_stream, room.zulip_topic, message
|
||||||
)
|
)
|
||||||
async with get_session_factory()() as session:
|
|
||||||
await transcripts_controller.update(
|
await transcripts_controller.update(
|
||||||
session, transcript, {"zulip_message_id": response["id"]}
|
session, transcript, {"zulip_message_id": response["id"]}
|
||||||
)
|
)
|
||||||
@@ -740,8 +727,11 @@ async def task_pipeline_convert_to_mp3(*, transcript_id: str):
|
|||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
@asynctask
|
@asynctask
|
||||||
async def task_pipeline_upload_mp3(*, transcript_id: str):
|
@with_session_and_transcript
|
||||||
await pipeline_upload_mp3(transcript_id=transcript_id)
|
async def task_pipeline_upload_mp3(
|
||||||
|
session, *, transcript: Transcript, logger: Logger, transcript_id: str
|
||||||
|
):
|
||||||
|
await pipeline_upload_mp3(session, transcript=transcript, logger=logger)
|
||||||
|
|
||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
@@ -764,14 +754,20 @@ async def task_pipeline_final_summaries(*, transcript_id: str):
|
|||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
@asynctask
|
@asynctask
|
||||||
async def task_cleanup_consent(*, transcript_id: str):
|
@with_session_and_transcript
|
||||||
await cleanup_consent(transcript_id=transcript_id)
|
async def task_cleanup_consent(
|
||||||
|
session, *, transcript: Transcript, logger: Logger, transcript_id: str
|
||||||
|
):
|
||||||
|
await cleanup_consent(session, transcript=transcript, logger=logger)
|
||||||
|
|
||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
@asynctask
|
@asynctask
|
||||||
async def task_pipeline_post_to_zulip(*, transcript_id: str):
|
@with_session_and_transcript
|
||||||
await pipeline_post_to_zulip(transcript_id=transcript_id)
|
async def task_pipeline_post_to_zulip(
|
||||||
|
session, *, transcript: Transcript, logger: Logger, transcript_id: str
|
||||||
|
):
|
||||||
|
await pipeline_post_to_zulip(session, transcript=transcript, logger=logger)
|
||||||
|
|
||||||
|
|
||||||
def pipeline_post(*, transcript_id: str):
|
def pipeline_post(*, transcript_id: str):
|
||||||
|
|||||||
@@ -8,7 +8,11 @@ that stays open for the entire duration of the task execution.
|
|||||||
import functools
|
import functools
|
||||||
from typing import Any, Callable, TypeVar
|
from typing import Any, Callable, TypeVar
|
||||||
|
|
||||||
|
from celery import current_task
|
||||||
|
|
||||||
from reflector.db import get_session_factory
|
from reflector.db import get_session_factory
|
||||||
|
from reflector.db.transcripts import transcripts_controller
|
||||||
|
from reflector.logger import logger
|
||||||
|
|
||||||
F = TypeVar("F", bound=Callable[..., Any])
|
F = TypeVar("F", bound=Callable[..., Any])
|
||||||
|
|
||||||
@@ -39,3 +43,67 @@ def with_session(func: F) -> F:
|
|||||||
return await func(session, *args, **kwargs)
|
return await func(session, *args, **kwargs)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def with_session_and_transcript(func: F) -> F:
|
||||||
|
"""
|
||||||
|
Decorator that provides both an AsyncSession and a Transcript to the decorated function.
|
||||||
|
|
||||||
|
This decorator:
|
||||||
|
1. Extracts transcript_id from kwargs
|
||||||
|
2. Creates and manages a database session
|
||||||
|
3. Fetches the transcript using the session
|
||||||
|
4. Creates an enhanced logger with Celery task context
|
||||||
|
5. Passes session, transcript, and logger to the decorated function
|
||||||
|
|
||||||
|
This should be used AFTER the @asynctask decorator on Celery tasks.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
@shared_task
|
||||||
|
@asynctask
|
||||||
|
@with_session_and_transcript
|
||||||
|
async def my_task(session: AsyncSession, transcript: Transcript, logger: Logger, arg1: str):
|
||||||
|
# session, transcript, and logger are automatically provided
|
||||||
|
room = await rooms_controller.get_by_id(session, transcript.room_id)
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
|
||||||
|
@functools.wraps(func)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
transcript_id = kwargs.pop("transcript_id", None)
|
||||||
|
if not transcript_id:
|
||||||
|
raise ValueError(
|
||||||
|
"transcript_id is required for @with_session_and_transcript"
|
||||||
|
)
|
||||||
|
|
||||||
|
session_factory = get_session_factory()
|
||||||
|
async with session_factory() as session:
|
||||||
|
async with session.begin():
|
||||||
|
# Fetch the transcript
|
||||||
|
transcript = await transcripts_controller.get_by_id(
|
||||||
|
session, transcript_id
|
||||||
|
)
|
||||||
|
if not transcript:
|
||||||
|
raise Exception(f"Transcript {transcript_id} not found")
|
||||||
|
|
||||||
|
# Create enhanced logger with Celery task context
|
||||||
|
tlogger = logger.bind(transcript_id=transcript.id)
|
||||||
|
if current_task:
|
||||||
|
tlogger = tlogger.bind(
|
||||||
|
task_id=current_task.request.id,
|
||||||
|
task_name=current_task.name,
|
||||||
|
worker_hostname=current_task.request.hostname,
|
||||||
|
task_retries=current_task.request.retries,
|
||||||
|
transcript_id=transcript_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Pass session, transcript, and logger to the decorated function
|
||||||
|
return await func(
|
||||||
|
session, transcript=transcript, logger=tlogger, *args, **kwargs
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
tlogger.exception("Error in task execution")
|
||||||
|
raise
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|||||||
Reference in New Issue
Block a user