diff --git a/server/reflector/pipelines/main_file_pipeline.py b/server/reflector/pipelines/main_file_pipeline.py
index 8d644e19..fb1f925c 100644
--- a/server/reflector/pipelines/main_file_pipeline.py
+++ b/server/reflector/pipelines/main_file_pipeline.py
@@ -87,15 +87,9 @@ class PipelineMainFile(PipelineMainBase):
         self.logger = logger.bind(transcript_id=self.transcript_id)
         self.empty_pipeline = EmptyPipeline(logger=self.logger)
 
-    async def get_transcript(self, session: AsyncSession = None) -> Transcript:
+    async def get_transcript(self, session: AsyncSession) -> Transcript:
         """Get transcript with session"""
-        if session:
-            result = await transcripts_controller.get_by_id(session, self.transcript_id)
-        else:
-            async with get_session_factory()() as session:
-                result = await transcripts_controller.get_by_id(
-                    session, self.transcript_id
-                )
+        result = await transcripts_controller.get_by_id(session, self.transcript_id)
         if not result:
             raise Exception("Transcript not found")
         return result
diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py
index ecba1e9f..00cc47d8 100644
--- a/server/reflector/pipelines/main_live_pipeline.py
+++ b/server/reflector/pipelines/main_live_pipeline.py
@@ -142,15 +142,9 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
             self._ws_manager = get_ws_manager()
         return self._ws_manager
 
-    async def get_transcript(self, session: AsyncSession = None) -> Transcript:
+    async def get_transcript(self, session: AsyncSession) -> Transcript:
         # fetch the transcript
-        if session:
-            result = await transcripts_controller.get_by_id(session, self.transcript_id)
-        else:
-            async with get_session_factory()() as session:
-                result = await transcripts_controller.get_by_id(
-                    session, self.transcript_id
-                )
+        result = await transcripts_controller.get_by_id(session, self.transcript_id)
         if not result:
             raise Exception("Transcript not found")
         return result
@@ -349,7 +343,8 @@
     async def create(self) -> Pipeline:
         # create a context for the whole rtc transaction
         # add a customised logger to the context
-        transcript = await self.get_transcript()
+        async with get_session_factory()() as session:
+            transcript = await self.get_transcript(session)
 
         processors = [
             AudioFileWriterProcessor(
@@ -397,7 +392,8 @@
         # now let's start the pipeline by pushing information to the
         # first processor diarization processor
         # XXX translation is lost when converting our data model to the processor model
-        transcript = await self.get_transcript()
+        async with get_session_factory()() as session:
+            transcript = await self.get_transcript(session)
 
         # diarization works only if the file is uploaded to an external storage
         if transcript.audio_location == "local":
@@ -430,7 +426,8 @@ class PipelineMainFromTopics(PipelineMainBase[TitleSummaryWithIdProcessorType]):
 
     async def create(self) -> Pipeline:
         # get transcript
-        self._transcript = transcript = await self.get_transcript()
+        async with get_session_factory()() as session:
+            self._transcript = transcript = await self.get_transcript(session)
 
         # create pipeline
         processors = self.get_processors()
diff --git a/server/reflector/worker/cleanup.py b/server/reflector/worker/cleanup.py
index 965a44f9..3b1c4b55 100644
--- a/server/reflector/worker/cleanup.py
+++ b/server/reflector/worker/cleanup.py
@@ -12,6 +12,7 @@ import structlog
 from celery import shared_task
 from pydantic.types import PositiveInt
 from sqlalchemy import delete, select
+from sqlalchemy.ext.asyncio import AsyncSession
 
 from reflector.asynctask import asynctask
 from reflector.db import get_session_factory
@@ -33,104 +34,49 @@ class CleanupStats(TypedDict):
 
 
 async def delete_single_transcript(
-    session_factory, transcript_data: dict, stats: CleanupStats, session=None
+    session: AsyncSession, transcript_data: dict, stats: CleanupStats
 ):
     transcript_id = transcript_data["id"]
     meeting_id = transcript_data["meeting_id"]
     recording_id = transcript_data["recording_id"]
 
     try:
-        if session:
-            # Use provided session for testing - don't start new transaction
-            if meeting_id:
-                await session.execute(
-                    delete(MeetingModel).where(MeetingModel.id == meeting_id)
-                )
-                stats["meetings_deleted"] += 1
-                logger.info("Deleted associated meeting", meeting_id=meeting_id)
-
-            if recording_id:
-                result = await session.execute(
-                    select(RecordingModel).where(RecordingModel.id == recording_id)
-                )
-                recording = result.mappings().first()
-                if recording:
-                    try:
-                        await get_recordings_storage().delete_file(
-                            recording["object_key"]
-                        )
-                    except Exception as storage_error:
-                        logger.warning(
-                            "Failed to delete recording from storage",
-                            recording_id=recording_id,
-                            object_key=recording["object_key"],
-                            error=str(storage_error),
-                        )
-
-                await session.execute(
-                    delete(RecordingModel).where(RecordingModel.id == recording_id)
-                )
-                stats["recordings_deleted"] += 1
-                logger.info(
-                    "Deleted associated recording", recording_id=recording_id
-                )
-
-            await transcripts_controller.remove_by_id(session, transcript_id)
-            stats["transcripts_deleted"] += 1
-            logger.info(
-                "Deleted transcript",
-                transcript_id=transcript_id,
-                created_at=transcript_data["created_at"].isoformat(),
+        if meeting_id:
+            await session.execute(
+                delete(MeetingModel).where(MeetingModel.id == meeting_id)
             )
-        else:
-            # Use session factory for production
-            async with session_factory() as session:
-                async with session.begin():
-                    if meeting_id:
-                        await session.execute(
-                            delete(MeetingModel).where(MeetingModel.id == meeting_id)
-                        )
-                        stats["meetings_deleted"] += 1
-                        logger.info("Deleted associated meeting", meeting_id=meeting_id)
+            stats["meetings_deleted"] += 1
+            logger.info("Deleted associated meeting", meeting_id=meeting_id)
 
-                    if recording_id:
-                        result = await session.execute(
-                            select(RecordingModel).where(
-                                RecordingModel.id == recording_id
-                            )
-                        )
-                        recording = result.mappings().first()
-                        if recording:
-                            try:
-                                await get_recordings_storage().delete_file(
-                                    recording["object_key"]
-                                )
-                            except Exception as storage_error:
-                                logger.warning(
-                                    "Failed to delete recording from storage",
-                                    recording_id=recording_id,
-                                    object_key=recording["object_key"],
-                                    error=str(storage_error),
-                                )
+        if recording_id:
+            result = await session.execute(
+                select(RecordingModel).where(RecordingModel.id == recording_id)
+            )
+            recording = result.mappings().first()
+            if recording:
+                try:
+                    await get_recordings_storage().delete_file(recording["object_key"])
+                except Exception as storage_error:
+                    logger.warning(
+                        "Failed to delete recording from storage",
+                        recording_id=recording_id,
+                        object_key=recording["object_key"],
+                        error=str(storage_error),
+                    )
 
-                        await session.execute(
-                            delete(RecordingModel).where(
-                                RecordingModel.id == recording_id
-                            )
-                        )
-                        stats["recordings_deleted"] += 1
-                        logger.info(
-                            "Deleted associated recording",
-                            recording_id=recording_id,
-                        )
-
-                    await transcripts_controller.remove_by_id(session, transcript_id)
-                    stats["transcripts_deleted"] += 1
-                    logger.info(
-                        "Deleted transcript",
-                        transcript_id=transcript_id,
-                        created_at=transcript_data["created_at"].isoformat(),
+            await session.execute(
+                delete(RecordingModel).where(RecordingModel.id == recording_id)
             )
+            stats["recordings_deleted"] += 1
+            logger.info("Deleted associated recording", recording_id=recording_id)
+
+        await transcripts_controller.remove_by_id(session, transcript_id)
+        stats["transcripts_deleted"] += 1
+        logger.info(
+            "Deleted transcript",
+            transcript_id=transcript_id,
+            created_at=transcript_data["created_at"].isoformat(),
+        )
     except Exception as e:
         error_msg = f"Failed to delete transcript {transcript_id}: {str(e)}"
         logger.error(error_msg, exc_info=e)
@@ -138,7 +84,7 @@ async def delete_single_transcript(
 
 
 async def cleanup_old_transcripts(
-    session_factory, cutoff_date: datetime, stats: CleanupStats, session=None
+    session: AsyncSession, cutoff_date: datetime, stats: CleanupStats
 ):
     """Delete old anonymous transcripts and their associated recordings/meetings."""
     query = select(
@@ -150,23 +96,14 @@
         (TranscriptModel.created_at < cutoff_date)
         & (TranscriptModel.user_id.is_(None))
     )
-    if session:
-        # Use provided session for testing
-        result = await session.execute(query)
-        old_transcripts = result.mappings().all()
-    else:
-        # Use session factory for production
-        async with session_factory() as session:
-            result = await session.execute(query)
-            old_transcripts = result.mappings().all()
+    result = await session.execute(query)
+    old_transcripts = result.mappings().all()
 
     logger.info(f"Found {len(old_transcripts)} old transcripts to delete")
 
     for transcript_data in old_transcripts:
         try:
-            await delete_single_transcript(
-                session_factory, transcript_data, stats, session
-            )
+            await delete_single_transcript(session, transcript_data, stats)
         except Exception as e:
             error_msg = f"Failed to delete transcript {transcript_data['id']}: {str(e)}"
             logger.error(error_msg, exc_info=e)
@@ -190,8 +127,8 @@ def log_cleanup_results(stats: CleanupStats):
 
 
 async def cleanup_old_public_data(
+    session: AsyncSession,
     days: PositiveInt | None = None,
-    session=None,
 ) -> CleanupStats | None:
     if days is None:
         days = settings.PUBLIC_DATA_RETENTION_DAYS
@@ -213,8 +150,7 @@
         "errors": [],
     }
 
-    session_factory = get_session_factory()
-    await cleanup_old_transcripts(session_factory, cutoff_date, stats, session)
+    await cleanup_old_transcripts(session, cutoff_date, stats)
 
     log_cleanup_results(stats)
     return stats
@@ -226,4 +162,7 @@
 )
 @asynctask
 async def cleanup_old_public_data_task(days: int | None = None):
-    await cleanup_old_public_data(days=days)
+    session_factory = get_session_factory()
+    async with session_factory() as session:
+        async with session.begin():
+            await cleanup_old_public_data(session, days=days)
diff --git a/server/reflector/worker/ics_sync.py b/server/reflector/worker/ics_sync.py
index faf62f4a..a03369f2 100644
--- a/server/reflector/worker/ics_sync.py
+++ b/server/reflector/worker/ics_sync.py
@@ -3,8 +3,10 @@ from datetime import datetime, timedelta, timezone
 import structlog
 from celery import shared_task
 from celery.utils.log import get_task_logger
+from sqlalchemy.ext.asyncio import AsyncSession
 
 from reflector.asynctask import asynctask
+from reflector.db import get_session_factory
 from reflector.db.calendar_events import calendar_events_controller
 from reflector.db.meetings import meetings_controller
 from reflector.db.rooms import rooms_controller
@@ -19,7 +21,9 @@ logger = structlog.wrap_logger(get_task_logger(__name__))
 @asynctask
 async def sync_room_ics(room_id: str):
     try:
-        room = await rooms_controller.get_by_id(room_id)
+        session_factory = get_session_factory()
+        async with session_factory() as session:
+            room = await rooms_controller.get_by_id(session, room_id)
         if not room:
             logger.warning("Room not found for ICS sync", room_id=room_id)
             return
@@ -59,7 +63,9 @@
     try:
         logger.info("Starting sync for all ICS-enabled rooms")
 
-        ics_enabled_rooms = await rooms_controller.get_ics_enabled()
+        session_factory = get_session_factory()
+        async with session_factory() as session:
+            ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
         logger.info(f"Found {len(ics_enabled_rooms)} rooms with ICS enabled")
 
         for room in ics_enabled_rooms:
@@ -86,10 +92,14 @@ def _should_sync(room) -> bool:
 MEETING_DEFAULT_DURATION = timedelta(hours=1)
 
 
-async def create_upcoming_meetings_for_event(event, create_window, room_id, room):
+async def create_upcoming_meetings_for_event(
+    session: AsyncSession, event, create_window, room_id, room
+):
     if event.start_time <= create_window:
         return
-    existing_meeting = await meetings_controller.get_by_calendar_event(event.id)
+    existing_meeting = await meetings_controller.get_by_calendar_event(
+        session, event.id
+    )
     if existing_meeting:
         return
 
@@ -112,6 +122,7 @@
         await upload_logo(whereby_meeting["roomName"], "./images/logo.png")
 
     meeting = await meetings_controller.create(
+        session,
         id=whereby_meeting["meetingId"],
         room_name=whereby_meeting["roomName"],
         room_url=whereby_meeting["roomUrl"],
@@ -155,20 +166,24 @@
     try:
         logger.info("Starting creation of upcoming meetings")
 
-        ics_enabled_rooms = await rooms_controller.get_ics_enabled()
-        now = datetime.now(timezone.utc)
-        create_window = now - timedelta(minutes=6)
+        session_factory = get_session_factory()
+        async with session_factory() as session:
+            async with session.begin():
+                ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
+                now = datetime.now(timezone.utc)
+                create_window = now - timedelta(minutes=6)
 
-        for room in ics_enabled_rooms:
-            events = await calendar_events_controller.get_upcoming(
-                room.id,
-                minutes_ahead=7,
-            )
+                for room in ics_enabled_rooms:
+                    events = await calendar_events_controller.get_upcoming(
+                        session,
+                        room.id,
+                        minutes_ahead=7,
+                    )
 
-            for event in events:
-                await create_upcoming_meetings_for_event(
-                    event, create_window, room.id, room
-                )
+                    for event in events:
+                        await create_upcoming_meetings_for_event(
+                            session, event, create_window, room.id, room
+                        )
 
         logger.info("Completed pre-creation check for upcoming meetings")
     except Exception as e:
diff --git a/server/reflector/worker/process.py b/server/reflector/worker/process.py
index e660e840..db988390 100644
--- a/server/reflector/worker/process.py
+++ b/server/reflector/worker/process.py
@@ -11,6 +11,7 @@ from celery.utils.log import get_task_logger
 from pydantic import ValidationError
 from redis.exceptions import LockError
 
+from reflector.db import get_session_factory
 from reflector.db.meetings import meetings_controller
 from reflector.db.recordings import Recording, recordings_controller
 from reflector.db.rooms import rooms_controller
@@ -82,66 +83,78 @@ async def process_recording(bucket_name: str, object_key: str):
     room_name = f"/{object_key[:36]}"
     recorded_at = parse_datetime_with_timezone(object_key[37:57])
 
-    meeting = await meetings_controller.get_by_room_name(room_name)
-    room = await rooms_controller.get_by_id(meeting.room_id)
+    session_factory = get_session_factory()
+    async with session_factory() as session:
+        async with session.begin():
+            meeting = await meetings_controller.get_by_room_name(session, room_name)
+            room = await rooms_controller.get_by_id(session, meeting.room_id)
 
-    recording = await recordings_controller.get_by_object_key(bucket_name, object_key)
-    if not recording:
-        recording = await recordings_controller.create(
-            Recording(
-                bucket_name=bucket_name,
-                object_key=object_key,
-                recorded_at=recorded_at,
-                meeting_id=meeting.id,
+            recording = await recordings_controller.get_by_object_key(
+                session, bucket_name, object_key
             )
-        )
+            if not recording:
+                recording = await recordings_controller.create(
+                    session,
+                    Recording(
+                        bucket_name=bucket_name,
+                        object_key=object_key,
+                        recorded_at=recorded_at,
+                        meeting_id=meeting.id,
+                    ),
+                )
 
-    transcript = await transcripts_controller.get_by_recording_id(recording.id)
-    if transcript:
-        await transcripts_controller.update(
-            transcript,
-            {
-                "topics": [],
-            },
-        )
-    else:
-        transcript = await transcripts_controller.add(
-            "",
-            source_kind=SourceKind.ROOM,
-            source_language="en",
-            target_language="en",
-            user_id=room.user_id,
-            recording_id=recording.id,
-            share_mode="public",
-            meeting_id=meeting.id,
-            room_id=room.id,
-        )
+            transcript = await transcripts_controller.get_by_recording_id(
+                session, recording.id
+            )
+            if transcript:
+                await transcripts_controller.update(
+                    session,
+                    transcript,
+                    {
+                        "topics": [],
+                    },
+                )
+            else:
+                transcript = await transcripts_controller.add(
+                    session,
+                    "",
+                    source_kind=SourceKind.ROOM,
+                    source_language="en",
+                    target_language="en",
+                    user_id=room.user_id,
+                    recording_id=recording.id,
+                    share_mode="public",
+                    meeting_id=meeting.id,
+                    room_id=room.id,
+                )
 
-    _, extension = os.path.splitext(object_key)
-    upload_filename = transcript.data_path / f"upload{extension}"
-    upload_filename.parent.mkdir(parents=True, exist_ok=True)
+            _, extension = os.path.splitext(object_key)
+            upload_filename = transcript.data_path / f"upload{extension}"
+            upload_filename.parent.mkdir(parents=True, exist_ok=True)
 
-    s3 = boto3.client(
-        "s3",
-        region_name=settings.TRANSCRIPT_STORAGE_AWS_REGION,
-        aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
-        aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
-    )
+            s3 = boto3.client(
+                "s3",
+                region_name=settings.TRANSCRIPT_STORAGE_AWS_REGION,
+                aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
+                aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
+            )
 
-    with open(upload_filename, "wb") as f:
-        s3.download_fileobj(bucket_name, object_key, f)
+            with open(upload_filename, "wb") as f:
+                s3.download_fileobj(bucket_name, object_key, f)
 
-    container = av.open(upload_filename.as_posix())
-    try:
-        if not len(container.streams.audio):
-            raise Exception("File has no audio stream")
-    except Exception:
-        upload_filename.unlink()
-        raise
-    finally:
-        container.close()
+            container = av.open(upload_filename.as_posix())
+            try:
+                if not len(container.streams.audio):
+                    raise Exception("File has no audio stream")
+            except Exception:
+                upload_filename.unlink()
+                raise
+            finally:
+                container.close()
 
-    await transcripts_controller.update(transcript, {"status": "uploaded"})
+            await transcripts_controller.update(
+                session, transcript, {"status": "uploaded"}
+            )
 
     task_pipeline_file_process.delay(transcript_id=transcript.id)
 
@@ -165,7 +178,10 @@ async def process_meetings():
     process the same meeting simultaneously.
     """
     logger.info("Processing meetings")
-    meetings = await meetings_controller.get_all_active()
+    session_factory = get_session_factory()
+    async with session_factory() as session:
+        async with session.begin():
+            meetings = await meetings_controller.get_all_active(session)
     current_time = datetime.now(timezone.utc)
     redis_client = get_redis_client()
     processed_count = 0
@@ -218,7 +234,9 @@ async def process_meetings():
             logger_.debug("Meeting not yet started, keep it")
 
         if should_deactivate:
-            await meetings_controller.update_meeting(meeting.id, is_active=False)
+            await meetings_controller.update_meeting(
+                session, meeting.id, is_active=False
+            )
             logger_.info("Meeting is deactivated")
 
         processed_count += 1
@@ -260,40 +278,44 @@
         bucket_name = settings.RECORDING_STORAGE_AWS_BUCKET_NAME
         pages = paginator.paginate(Bucket=bucket_name)
 
-        for page in pages:
-            if "Contents" not in page:
-                continue
-
-            for obj in page["Contents"]:
-                object_key = obj["Key"]
-
-                if not (object_key.endswith(".mp4")):
+        session_factory = get_session_factory()
+        async with session_factory() as session:
+            for page in pages:
+                if "Contents" not in page:
                     continue
 
-                recording = await recordings_controller.get_by_object_key(
-                    bucket_name, object_key
-                )
-                if not recording:
-                    logger.info(f"Queueing recording for processing: {object_key}")
-                    process_recording.delay(bucket_name, object_key)
-                    reprocessed_count += 1
-                    continue
+                for obj in page["Contents"]:
+                    object_key = obj["Key"]
 
-                transcript = None
-                try:
-                    transcript = await transcripts_controller.get_by_recording_id(
-                        recording.id
-                    )
-                except ValidationError:
-                    await transcripts_controller.remove_by_recording_id(recording.id)
-                    logger.warning(
-                        f"Removed invalid transcript for recording: {recording.id}"
-                    )
+                    if not (object_key.endswith(".mp4")):
+                        continue
 
-                if transcript is None or transcript.status == "error":
-                    logger.info(f"Queueing recording for processing: {object_key}")
-                    process_recording.delay(bucket_name, object_key)
-                    reprocessed_count += 1
+                    recording = await recordings_controller.get_by_object_key(
+                        session, bucket_name, object_key
+                    )
+                    if not recording:
+                        logger.info(f"Queueing recording for processing: {object_key}")
+                        process_recording.delay(bucket_name, object_key)
+                        reprocessed_count += 1
+                        continue
+
+                    transcript = None
+                    try:
+                        transcript = await transcripts_controller.get_by_recording_id(
+                            session, recording.id
+                        )
+                    except ValidationError:
+                        await transcripts_controller.remove_by_recording_id(
+                            session, recording.id
+                        )
+                        logger.warning(
+                            f"Removed invalid transcript for recording: {recording.id}"
+                        )
+
+                    if transcript is None or transcript.status == "error":
+                        logger.info(f"Queueing recording for processing: {object_key}")
+                        process_recording.delay(bucket_name, object_key)
+                        reprocessed_count += 1
 
     except Exception as e:
         logger.error(f"Error checking S3 bucket: {str(e)}")
diff --git a/server/tests/test_cleanup.py b/server/tests/test_cleanup.py
index 9ccead68..a3626c32 100644
--- a/server/tests/test_cleanup.py
+++ b/server/tests/test_cleanup.py
@@ -15,12 +15,12 @@ from reflector.worker.cleanup import cleanup_old_public_data
 
 
 @pytest.mark.asyncio
-async def test_cleanup_old_public_data_skips_when_not_public():
+async def test_cleanup_old_public_data_skips_when_not_public(session):
    """Test that cleanup is skipped when PUBLIC_MODE is False."""
    with patch("reflector.worker.cleanup.settings") as mock_settings:
        mock_settings.PUBLIC_MODE = False
 
-        result = await cleanup_old_public_data()
+        result = await cleanup_old_public_data(session)
 
         # Should return early without doing anything
         assert result is None
@@ -81,7 +81,7 @@ async def test_cleanup_old_public_data_deletes_old_anonymous_transcripts(session
         mock_delete.return_value = None
 
         # Run cleanup with test session
-        await cleanup_old_public_data(session=session)
+        await cleanup_old_public_data(session)
 
         # Verify only old anonymous transcript was deleted
         assert mock_delete.call_count == 1
@@ -162,7 +162,7 @@ async def test_cleanup_deletes_associated_meeting_and_recording(session):
         mock_storage.return_value.delete_file = AsyncMock()
 
         # Run cleanup with test session
-        await cleanup_old_public_data(session=session)
+        await cleanup_old_public_data(session)
 
         # Verify transcript was deleted
         result = await session.execute(
@@ -226,7 +226,7 @@ async def test_cleanup_handles_errors_gracefully(session):
         mock_delete.side_effect = [Exception("Delete failed"), None]
 
         # Run cleanup with test session - should not raise exception
-        await cleanup_old_public_data(session=session)
+        await cleanup_old_public_data(session)
 
         # Both transcripts should have been attempted to delete
         assert mock_delete.call_count == 2
diff --git a/server/tests/test_pipeline_main_file.py b/server/tests/test_pipeline_main_file.py
index 32a69f24..49c2d22c 100644
--- a/server/tests/test_pipeline_main_file.py
+++ b/server/tests/test_pipeline_main_file.py
@@ -624,7 +624,10 @@
 
     # Should raise an exception for missing transcript when get_transcript is called
     with pytest.raises(Exception, match="Transcript not found"):
-        await pipeline.get_transcript()
+        from reflector.db import get_session_factory
+
+        async with get_session_factory()() as session:
+            await pipeline.get_transcript(session)
 
 
 @pytest.mark.asyncio