diff --git a/server/reflector/worker/cleanup.py b/server/reflector/worker/cleanup.py
index 3b1c4b55..f55f3acd 100644
--- a/server/reflector/worker/cleanup.py
+++ b/server/reflector/worker/cleanup.py
@@ -15,11 +15,11 @@ from sqlalchemy import delete, select
 from sqlalchemy.ext.asyncio import AsyncSession
 
 from reflector.asynctask import asynctask
-from reflector.db import get_session_factory
 from reflector.db.base import MeetingModel, RecordingModel, TranscriptModel
 from reflector.db.transcripts import transcripts_controller
 from reflector.settings import settings
 from reflector.storage import get_recordings_storage
+from reflector.worker.session_decorator import with_session
 
 logger = structlog.get_logger(__name__)
 
@@ -161,8 +161,6 @@ async def cleanup_old_public_data(
     retry_kwargs={"max_retries": 3, "countdown": 300},
 )
 @asynctask
-async def cleanup_old_public_data_task(days: int | None = None):
-    session_factory = get_session_factory()
-    async with session_factory() as session:
-        async with session.begin():
-            await cleanup_old_public_data(session, days=days)
+@with_session
+async def cleanup_old_public_data_task(session: AsyncSession, days: int | None = None):
+    await cleanup_old_public_data(session, days=days)
diff --git a/server/reflector/worker/ics_sync.py b/server/reflector/worker/ics_sync.py
index a03369f2..2794e3b6 100644
--- a/server/reflector/worker/ics_sync.py
+++ b/server/reflector/worker/ics_sync.py
@@ -6,24 +6,23 @@ from celery.utils.log import get_task_logger
 from sqlalchemy.ext.asyncio import AsyncSession
 
 from reflector.asynctask import asynctask
-from reflector.db import get_session_factory
 from reflector.db.calendar_events import calendar_events_controller
 from reflector.db.meetings import meetings_controller
 from reflector.db.rooms import rooms_controller
 from reflector.redis_cache import RedisAsyncLock
 from reflector.services.ics_sync import SyncStatus, ics_sync_service
 from reflector.whereby import create_meeting, upload_logo
+from reflector.worker.session_decorator import with_session
 
 logger = structlog.wrap_logger(get_task_logger(__name__))
 
 
 @shared_task
 @asynctask
-async def sync_room_ics(room_id: str):
+@with_session
+async def sync_room_ics(session: AsyncSession, room_id: str):
     try:
-        session_factory = get_session_factory()
-        async with session_factory() as session:
-            room = await rooms_controller.get_by_id(session, room_id)
+        room = await rooms_controller.get_by_id(session, room_id)
         if not room:
             logger.warning("Room not found for ICS sync", room_id=room_id)
             return
@@ -59,13 +58,12 @@ async def sync_room_ics(room_id: str):
 
 @shared_task
 @asynctask
-async def sync_all_ics_calendars():
+@with_session
+async def sync_all_ics_calendars(session: AsyncSession):
     try:
         logger.info("Starting sync for all ICS-enabled rooms")
 
-        session_factory = get_session_factory()
-        async with session_factory() as session:
-            ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
+        ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
 
         logger.info(f"Found {len(ics_enabled_rooms)} rooms with ICS enabled")
 
         for room in ics_enabled_rooms:
@@ -155,7 +153,8 @@ async def create_upcoming_meetings_for_event(
 
 @shared_task
 @asynctask
-async def create_upcoming_meetings():
+@with_session
+async def create_upcoming_meetings(session: AsyncSession):
     async with RedisAsyncLock("create_upcoming_meetings", skip_if_locked=True) as lock:
         if not lock.acquired:
             logger.warning(
@@ -166,24 +165,21 @@ async def create_upcoming_meetings():
 
         try:
             logger.info("Starting creation of upcoming meetings")
meetings") - session_factory = get_session_factory() - async with session_factory() as session: - async with session.begin(): - ics_enabled_rooms = await rooms_controller.get_ics_enabled(session) - now = datetime.now(timezone.utc) - create_window = now - timedelta(minutes=6) + ics_enabled_rooms = await rooms_controller.get_ics_enabled(session) + now = datetime.now(timezone.utc) + create_window = now - timedelta(minutes=6) - for room in ics_enabled_rooms: - events = await calendar_events_controller.get_upcoming( - session, - room.id, - minutes_ahead=7, - ) + for room in ics_enabled_rooms: + events = await calendar_events_controller.get_upcoming( + session, + room.id, + minutes_ahead=7, + ) - for event in events: - await create_upcoming_meetings_for_event( - session, event, create_window, room.id, room - ) + for event in events: + await create_upcoming_meetings_for_event( + session, event, create_window, room.id, room + ) logger.info("Completed pre-creation check for upcoming meetings") except Exception as e: diff --git a/server/reflector/worker/process.py b/server/reflector/worker/process.py index db988390..7284b5e8 100644 --- a/server/reflector/worker/process.py +++ b/server/reflector/worker/process.py @@ -10,8 +10,8 @@ from celery import shared_task from celery.utils.log import get_task_logger from pydantic import ValidationError from redis.exceptions import LockError +from sqlalchemy.ext.asyncio import AsyncSession -from reflector.db import get_session_factory from reflector.db.meetings import meetings_controller from reflector.db.recordings import Recording, recordings_controller from reflector.db.rooms import rooms_controller @@ -21,6 +21,7 @@ from reflector.pipelines.main_live_pipeline import asynctask from reflector.redis_cache import get_redis_client from reflector.settings import settings from reflector.whereby import get_room_sessions +from reflector.worker.session_decorator import with_session logger = structlog.wrap_logger(get_task_logger(__name__)) @@ -76,92 +77,91 @@ def process_messages(): @shared_task @asynctask -async def process_recording(bucket_name: str, object_key: str): +@with_session +async def process_recording(session: AsyncSession, bucket_name: str, object_key: str): logger.info("Processing recording: %s/%s", bucket_name, object_key) # extract a guid and a datetime from the object key room_name = f"/{object_key[:36]}" recorded_at = parse_datetime_with_timezone(object_key[37:57]) - session_factory = get_session_factory() - async with session_factory() as session: - async with session.begin(): - meeting = await meetings_controller.get_by_room_name(session, room_name) - room = await rooms_controller.get_by_id(session, meeting.room_id) + meeting = await meetings_controller.get_by_room_name(session, room_name) + if not meeting: + logger.warning("Room not found, may be deleted ?", room_name=room_name) + return - recording = await recordings_controller.get_by_object_key( - session, bucket_name, object_key - ) - if not recording: - recording = await recordings_controller.create( - session, - Recording( - bucket_name=bucket_name, - object_key=object_key, - recorded_at=recorded_at, - meeting_id=meeting.id, - ), - ) + room = await rooms_controller.get_by_id(session, meeting.room_id) - transcript = await transcripts_controller.get_by_recording_id( - session, recording.id - ) - if transcript: - await transcripts_controller.update( - session, - transcript, - { - "topics": [], - }, - ) - else: - transcript = await transcripts_controller.add( - session, - "", - 
-                    source_kind=SourceKind.ROOM,
-                    source_language="en",
-                    target_language="en",
-                    user_id=room.user_id,
-                    recording_id=recording.id,
-                    share_mode="public",
-                    meeting_id=meeting.id,
-                    room_id=room.id,
-                )
+    recording = await recordings_controller.get_by_object_key(
+        session, bucket_name, object_key
+    )
+    if not recording:
+        recording = await recordings_controller.create(
+            session,
+            Recording(
+                bucket_name=bucket_name,
+                object_key=object_key,
+                recorded_at=recorded_at,
+                meeting_id=meeting.id,
+            ),
+        )
 
-            _, extension = os.path.splitext(object_key)
-            upload_filename = transcript.data_path / f"upload{extension}"
-            upload_filename.parent.mkdir(parents=True, exist_ok=True)
+    transcript = await transcripts_controller.get_by_recording_id(session, recording.id)
+    if transcript:
+        await transcripts_controller.update(
+            session,
+            transcript,
+            {
+                "topics": [],
+            },
+        )
+    else:
+        transcript = await transcripts_controller.add(
+            session,
+            "",
+            source_kind=SourceKind.ROOM,
+            source_language="en",
+            target_language="en",
+            user_id=room.user_id,
+            recording_id=recording.id,
+            share_mode="public",
+            meeting_id=meeting.id,
+            room_id=room.id,
+        )
 
-            s3 = boto3.client(
-                "s3",
-                region_name=settings.TRANSCRIPT_STORAGE_AWS_REGION,
-                aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
-                aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
-            )
+    _, extension = os.path.splitext(object_key)
+    upload_filename = transcript.data_path / f"upload{extension}"
+    upload_filename.parent.mkdir(parents=True, exist_ok=True)
 
-            with open(upload_filename, "wb") as f:
-                s3.download_fileobj(bucket_name, object_key, f)
+    s3 = boto3.client(
+        "s3",
+        region_name=settings.TRANSCRIPT_STORAGE_AWS_REGION,
+        aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
+        aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
+    )
 
-            container = av.open(upload_filename.as_posix())
-            try:
-                if not len(container.streams.audio):
-                    raise Exception("File has no audio stream")
-            except Exception:
-                upload_filename.unlink()
-                raise
-            finally:
-                container.close()
+    with open(upload_filename, "wb") as f:
+        s3.download_fileobj(bucket_name, object_key, f)
 
-            await transcripts_controller.update(
-                session, transcript, {"status": "uploaded"}
-            )
+    container = av.open(upload_filename.as_posix())
+    try:
+        if not len(container.streams.audio):
+            raise Exception("File has no audio stream")
+    except Exception:
+        upload_filename.unlink()
+        raise
+    finally:
+        container.close()
+
+    await transcripts_controller.update(session, transcript, {"status": "uploaded"})
 
     task_pipeline_file_process.delay(transcript_id=transcript.id)
 
 
 @shared_task
 @asynctask
-async def process_meetings():
+@with_session
+async def process_meetings(session: AsyncSession):
     """
     Checks which meetings are still active and deactivates those that have
     ended.
@@ -178,10 +178,7 @@ async def process_meetings():
     process the same meeting simultaneously.
""" logger.info("Processing meetings") - session_factory = get_session_factory() - async with session_factory() as session: - async with session.begin(): - meetings = await meetings_controller.get_all_active(session) + meetings = await meetings_controller.get_all_active(session) current_time = datetime.now(timezone.utc) redis_client = get_redis_client() processed_count = 0 @@ -258,7 +255,8 @@ async def process_meetings(): @shared_task @asynctask -async def reprocess_failed_recordings(): +@with_session +async def reprocess_failed_recordings(session: AsyncSession): """ Find recordings in the S3 bucket and check if they have proper transcriptions. If not, requeue them for processing. @@ -278,44 +276,42 @@ async def reprocess_failed_recordings(): bucket_name = settings.RECORDING_STORAGE_AWS_BUCKET_NAME pages = paginator.paginate(Bucket=bucket_name) - session_factory = get_session_factory() - async with session_factory() as session: - for page in pages: - if "Contents" not in page: + for page in pages: + if "Contents" not in page: + continue + + for obj in page["Contents"]: + object_key = obj["Key"] + + if not (object_key.endswith(".mp4")): continue - for obj in page["Contents"]: - object_key = obj["Key"] + recording = await recordings_controller.get_by_object_key( + session, bucket_name, object_key + ) + if not recording: + logger.info(f"Queueing recording for processing: {object_key}") + process_recording.delay(bucket_name, object_key) + reprocessed_count += 1 + continue - if not (object_key.endswith(".mp4")): - continue - - recording = await recordings_controller.get_by_object_key( - session, bucket_name, object_key + transcript = None + try: + transcript = await transcripts_controller.get_by_recording_id( + session, recording.id + ) + except ValidationError: + await transcripts_controller.remove_by_recording_id( + session, recording.id + ) + logger.warning( + f"Removed invalid transcript for recording: {recording.id}" ) - if not recording: - logger.info(f"Queueing recording for processing: {object_key}") - process_recording.delay(bucket_name, object_key) - reprocessed_count += 1 - continue - transcript = None - try: - transcript = await transcripts_controller.get_by_recording_id( - session, recording.id - ) - except ValidationError: - await transcripts_controller.remove_by_recording_id( - session, recording.id - ) - logger.warning( - f"Removed invalid transcript for recording: {recording.id}" - ) - - if transcript is None or transcript.status == "error": - logger.info(f"Queueing recording for processing: {object_key}") - process_recording.delay(bucket_name, object_key) - reprocessed_count += 1 + if transcript is None or transcript.status == "error": + logger.info(f"Queueing recording for processing: {object_key}") + process_recording.delay(bucket_name, object_key) + reprocessed_count += 1 except Exception as e: logger.error(f"Error checking S3 bucket: {str(e)}") diff --git a/server/reflector/worker/session_decorator.py b/server/reflector/worker/session_decorator.py new file mode 100644 index 00000000..e01b3104 --- /dev/null +++ b/server/reflector/worker/session_decorator.py @@ -0,0 +1,41 @@ +""" +Session management decorator for async worker tasks. + +This decorator ensures that all worker tasks have a properly managed database session +that stays open for the entire duration of the task execution. 
+""" + +import functools +from typing import Any, Callable, TypeVar + +from reflector.db import get_session_factory + +F = TypeVar("F", bound=Callable[..., Any]) + + +def with_session(func: F) -> F: + """ + Decorator that provides an AsyncSession as the first argument to the decorated function. + + This should be used AFTER the @asynctask decorator on Celery tasks to ensure + proper session management throughout the task execution. + + Example: + @shared_task + @asynctask + @with_session + async def my_task(session: AsyncSession, arg1: str, arg2: int): + # session is automatically provided and managed + result = await some_controller.get_by_id(session, arg1) + ... + """ + + @functools.wraps(func) + async def wrapper(*args, **kwargs): + session_factory = get_session_factory() + async with session_factory() as session: + async with session.begin(): + # Pass session as first argument to the decorated function + return await func(session, *args, **kwargs) + + return wrapper