refactor: improve session management across worker tasks and pipelines

- Remove "if session" anti-pattern from all functions
- Functions now take an explicit AsyncSession parameter instead of an optional session_factory (see the sketch below)
- Worker tasks (Celery) create sessions at top level using session_factory
- Add proper AsyncSession type annotations to all session parameters
- Update cleanup.py: delete_single_transcript, cleanup_old_transcripts, cleanup_old_public_data
- Update process.py: process_recording, process_meetings, reprocess_failed_recordings
- Update ics_sync.py: sync_room_ics, sync_all_ics_calendars, create_upcoming_meetings
- Update pipeline classes: get_transcript methods now require session
- Fix tests to pass sessions correctly
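
The signature change, in miniature (a before/after sketch distilled from the
pipeline diff below; `get_transcript` and `transcripts_controller` are the
project's own names):

    # Before: optional session with a factory fallback (the "if session" anti-pattern)
    async def get_transcript(self, session: AsyncSession = None) -> Transcript:
        if session:
            result = await transcripts_controller.get_by_id(session, self.transcript_id)
        else:
            async with get_session_factory()() as session:
                result = await transcripts_controller.get_by_id(session, self.transcript_id)
        ...

    # After: the session is required; the caller owns its lifecycle
    async def get_transcript(self, session: AsyncSession) -> Transcript:
        result = await transcripts_controller.get_by_id(session, self.transcript_id)
        ...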

Benefits:
- Better type safety and IDE support with explicit AsyncSession typing
- Clear transaction boundaries, with sessions created at the task level (see the excerpt below)
- Consistent session management pattern across codebase
- No ambiguity about session vs session_factory usage
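
The Celery entry points now open the session and transaction themselves, as in
this task from the cleanup diff below (shown as it reads after the change):

    @asynctask
    async def cleanup_old_public_data_task(days: int | None = None):
        session_factory = get_session_factory()
        async with session_factory() as session:  # session created at the task boundary
            async with session.begin():  # one transaction per task run
                await cleanup_old_public_data(session, days=days)

Inner functions receive this session and never open their own, so each task run
commits or rolls back as a single unit.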
commit 617a1c8b32
parent 60cc2b16ae
2025-09-23 08:39:50 -06:00
7 changed files with 200 additions and 230 deletions


@@ -87,15 +87,9 @@ class PipelineMainFile(PipelineMainBase):
         self.logger = logger.bind(transcript_id=self.transcript_id)
         self.empty_pipeline = EmptyPipeline(logger=self.logger)

-    async def get_transcript(self, session: AsyncSession = None) -> Transcript:
+    async def get_transcript(self, session: AsyncSession) -> Transcript:
         """Get transcript with session"""
-        if session:
-            result = await transcripts_controller.get_by_id(session, self.transcript_id)
-        else:
-            async with get_session_factory()() as session:
-                result = await transcripts_controller.get_by_id(
-                    session, self.transcript_id
-                )
+        result = await transcripts_controller.get_by_id(session, self.transcript_id)
         if not result:
             raise Exception("Transcript not found")
         return result


@@ -142,15 +142,9 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
             self._ws_manager = get_ws_manager()
         return self._ws_manager

-    async def get_transcript(self, session: AsyncSession = None) -> Transcript:
+    async def get_transcript(self, session: AsyncSession) -> Transcript:
         # fetch the transcript
-        if session:
-            result = await transcripts_controller.get_by_id(session, self.transcript_id)
-        else:
-            async with get_session_factory()() as session:
-                result = await transcripts_controller.get_by_id(
-                    session, self.transcript_id
-                )
+        result = await transcripts_controller.get_by_id(session, self.transcript_id)
         if not result:
             raise Exception("Transcript not found")
         return result
@@ -349,7 +343,8 @@ class PipelineMainLive(PipelineMainBase):
     async def create(self) -> Pipeline:
         # create a context for the whole rtc transaction
         # add a customised logger to the context
-        transcript = await self.get_transcript()
+        async with get_session_factory()() as session:
+            transcript = await self.get_transcript(session)

         processors = [
             AudioFileWriterProcessor(
@@ -397,7 +392,8 @@ class PipelineMainDiarization(PipelineMainBase[AudioDiarizationInput]):
         # now let's start the pipeline by pushing information to the
         # first processor diarization processor
         # XXX translation is lost when converting our data model to the processor model
-        transcript = await self.get_transcript()
+        async with get_session_factory()() as session:
+            transcript = await self.get_transcript(session)

         # diarization works only if the file is uploaded to an external storage
         if transcript.audio_location == "local":
@@ -430,7 +426,8 @@ class PipelineMainFromTopics(PipelineMainBase[TitleSummaryWithIdProcessorType]):
     async def create(self) -> Pipeline:
         # get transcript
-        self._transcript = transcript = await self.get_transcript()
+        async with get_session_factory()() as session:
+            self._transcript = transcript = await self.get_transcript(session)

         # create pipeline
         processors = self.get_processors()


@@ -12,6 +12,7 @@ import structlog
 from celery import shared_task
 from pydantic.types import PositiveInt
 from sqlalchemy import delete, select
+from sqlalchemy.ext.asyncio import AsyncSession

 from reflector.asynctask import asynctask
 from reflector.db import get_session_factory
@@ -33,104 +34,49 @@ class CleanupStats(TypedDict):


 async def delete_single_transcript(
-    session_factory, transcript_data: dict, stats: CleanupStats, session=None
+    session: AsyncSession, transcript_data: dict, stats: CleanupStats
 ):
     transcript_id = transcript_data["id"]
     meeting_id = transcript_data["meeting_id"]
     recording_id = transcript_data["recording_id"]

     try:
-        if session:
-            # Use provided session for testing - don't start new transaction
-            if meeting_id:
-                await session.execute(
-                    delete(MeetingModel).where(MeetingModel.id == meeting_id)
-                )
-                stats["meetings_deleted"] += 1
-                logger.info("Deleted associated meeting", meeting_id=meeting_id)
-
-            if recording_id:
-                result = await session.execute(
-                    select(RecordingModel).where(RecordingModel.id == recording_id)
-                )
-                recording = result.mappings().first()
-                if recording:
-                    try:
-                        await get_recordings_storage().delete_file(
-                            recording["object_key"]
-                        )
-                    except Exception as storage_error:
-                        logger.warning(
-                            "Failed to delete recording from storage",
-                            recording_id=recording_id,
-                            object_key=recording["object_key"],
-                            error=str(storage_error),
-                        )
-                await session.execute(
-                    delete(RecordingModel).where(RecordingModel.id == recording_id)
-                )
-                stats["recordings_deleted"] += 1
-                logger.info(
-                    "Deleted associated recording", recording_id=recording_id
-                )
-
-            await transcripts_controller.remove_by_id(session, transcript_id)
-            stats["transcripts_deleted"] += 1
-            logger.info(
-                "Deleted transcript",
-                transcript_id=transcript_id,
-                created_at=transcript_data["created_at"].isoformat(),
-            )
-        else:
-            # Use session factory for production
-            async with session_factory() as session:
-                async with session.begin():
-                    if meeting_id:
-                        await session.execute(
-                            delete(MeetingModel).where(MeetingModel.id == meeting_id)
-                        )
-                        stats["meetings_deleted"] += 1
-                        logger.info("Deleted associated meeting", meeting_id=meeting_id)
-
-                    if recording_id:
-                        result = await session.execute(
-                            select(RecordingModel).where(
-                                RecordingModel.id == recording_id
-                            )
-                        )
-                        recording = result.mappings().first()
-                        if recording:
-                            try:
-                                await get_recordings_storage().delete_file(
-                                    recording["object_key"]
-                                )
-                            except Exception as storage_error:
-                                logger.warning(
-                                    "Failed to delete recording from storage",
-                                    recording_id=recording_id,
-                                    object_key=recording["object_key"],
-                                    error=str(storage_error),
-                                )
-                        await session.execute(
-                            delete(RecordingModel).where(
-                                RecordingModel.id == recording_id
-                            )
-                        )
-                        stats["recordings_deleted"] += 1
-                        logger.info(
-                            "Deleted associated recording",
-                            recording_id=recording_id,
-                        )
-
-                    await transcripts_controller.remove_by_id(session, transcript_id)
-                    stats["transcripts_deleted"] += 1
-                    logger.info(
-                        "Deleted transcript",
-                        transcript_id=transcript_id,
-                        created_at=transcript_data["created_at"].isoformat(),
-                    )
+        if meeting_id:
+            await session.execute(
+                delete(MeetingModel).where(MeetingModel.id == meeting_id)
+            )
+            stats["meetings_deleted"] += 1
+            logger.info("Deleted associated meeting", meeting_id=meeting_id)
+
+        if recording_id:
+            result = await session.execute(
+                select(RecordingModel).where(RecordingModel.id == recording_id)
+            )
+            recording = result.mappings().first()
+            if recording:
+                try:
+                    await get_recordings_storage().delete_file(recording["object_key"])
+                except Exception as storage_error:
+                    logger.warning(
+                        "Failed to delete recording from storage",
+                        recording_id=recording_id,
+                        object_key=recording["object_key"],
+                        error=str(storage_error),
+                    )
+            await session.execute(
+                delete(RecordingModel).where(RecordingModel.id == recording_id)
+            )
+            stats["recordings_deleted"] += 1
+            logger.info("Deleted associated recording", recording_id=recording_id)
+
+        await transcripts_controller.remove_by_id(session, transcript_id)
+        stats["transcripts_deleted"] += 1
+        logger.info(
+            "Deleted transcript",
+            transcript_id=transcript_id,
+            created_at=transcript_data["created_at"].isoformat(),
+        )
     except Exception as e:
         error_msg = f"Failed to delete transcript {transcript_id}: {str(e)}"
         logger.error(error_msg, exc_info=e)
@@ -138,7 +84,7 @@ async def delete_single_transcript(


 async def cleanup_old_transcripts(
-    session_factory, cutoff_date: datetime, stats: CleanupStats, session=None
+    session: AsyncSession, cutoff_date: datetime, stats: CleanupStats
 ):
     """Delete old anonymous transcripts and their associated recordings/meetings."""
     query = select(
@@ -150,23 +96,14 @@ async def cleanup_old_transcripts(
         (TranscriptModel.created_at < cutoff_date) & (TranscriptModel.user_id.is_(None))
     )

-    if session:
-        # Use provided session for testing
-        result = await session.execute(query)
-        old_transcripts = result.mappings().all()
-    else:
-        # Use session factory for production
-        async with session_factory() as session:
-            result = await session.execute(query)
-            old_transcripts = result.mappings().all()
+    result = await session.execute(query)
+    old_transcripts = result.mappings().all()

     logger.info(f"Found {len(old_transcripts)} old transcripts to delete")

     for transcript_data in old_transcripts:
         try:
-            await delete_single_transcript(
-                session_factory, transcript_data, stats, session
-            )
+            await delete_single_transcript(session, transcript_data, stats)
         except Exception as e:
             error_msg = f"Failed to delete transcript {transcript_data['id']}: {str(e)}"
             logger.error(error_msg, exc_info=e)
@@ -190,8 +127,8 @@ def log_cleanup_results(stats: CleanupStats):


 async def cleanup_old_public_data(
+    session: AsyncSession,
     days: PositiveInt | None = None,
-    session=None,
 ) -> CleanupStats | None:
     if days is None:
         days = settings.PUBLIC_DATA_RETENTION_DAYS
@@ -213,8 +150,7 @@ async def cleanup_old_public_data(
         "errors": [],
     }

-    session_factory = get_session_factory()
-    await cleanup_old_transcripts(session_factory, cutoff_date, stats, session)
+    await cleanup_old_transcripts(session, cutoff_date, stats)

     log_cleanup_results(stats)
     return stats
@@ -226,4 +162,7 @@ async def cleanup_old_public_data(
 )
 @asynctask
 async def cleanup_old_public_data_task(days: int | None = None):
-    await cleanup_old_public_data(days=days)
+    session_factory = get_session_factory()
+    async with session_factory() as session:
+        async with session.begin():
+            await cleanup_old_public_data(session, days=days)


@@ -3,8 +3,10 @@ from datetime import datetime, timedelta, timezone
 import structlog
 from celery import shared_task
 from celery.utils.log import get_task_logger
+from sqlalchemy.ext.asyncio import AsyncSession

 from reflector.asynctask import asynctask
+from reflector.db import get_session_factory
 from reflector.db.calendar_events import calendar_events_controller
 from reflector.db.meetings import meetings_controller
 from reflector.db.rooms import rooms_controller
@@ -19,7 +21,9 @@ logger = structlog.wrap_logger(get_task_logger(__name__))
 @asynctask
 async def sync_room_ics(room_id: str):
     try:
-        room = await rooms_controller.get_by_id(room_id)
+        session_factory = get_session_factory()
+        async with session_factory() as session:
+            room = await rooms_controller.get_by_id(session, room_id)
         if not room:
             logger.warning("Room not found for ICS sync", room_id=room_id)
             return
@@ -59,7 +63,9 @@ async def sync_all_ics_calendars():
     try:
         logger.info("Starting sync for all ICS-enabled rooms")

-        ics_enabled_rooms = await rooms_controller.get_ics_enabled()
+        session_factory = get_session_factory()
+        async with session_factory() as session:
+            ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
         logger.info(f"Found {len(ics_enabled_rooms)} rooms with ICS enabled")

         for room in ics_enabled_rooms:
@@ -86,10 +92,14 @@ def _should_sync(room) -> bool:
 MEETING_DEFAULT_DURATION = timedelta(hours=1)


-async def create_upcoming_meetings_for_event(event, create_window, room_id, room):
+async def create_upcoming_meetings_for_event(
+    session: AsyncSession, event, create_window, room_id, room
+):
     if event.start_time <= create_window:
         return

-    existing_meeting = await meetings_controller.get_by_calendar_event(event.id)
+    existing_meeting = await meetings_controller.get_by_calendar_event(
+        session, event.id
+    )
     if existing_meeting:
         return
@@ -112,6 +122,7 @@ async def create_upcoming_meetings_for_event(event, create_window, room_id, room
         await upload_logo(whereby_meeting["roomName"], "./images/logo.png")

     meeting = await meetings_controller.create(
+        session,
         id=whereby_meeting["meetingId"],
         room_name=whereby_meeting["roomName"],
         room_url=whereby_meeting["roomUrl"],
@@ -155,20 +166,24 @@ async def create_upcoming_meetings():
     try:
         logger.info("Starting creation of upcoming meetings")

-        ics_enabled_rooms = await rooms_controller.get_ics_enabled()
-        now = datetime.now(timezone.utc)
-        create_window = now - timedelta(minutes=6)
-
-        for room in ics_enabled_rooms:
-            events = await calendar_events_controller.get_upcoming(
-                room.id,
-                minutes_ahead=7,
-            )
-
-            for event in events:
-                await create_upcoming_meetings_for_event(
-                    event, create_window, room.id, room
-                )
+        session_factory = get_session_factory()
+        async with session_factory() as session:
+            async with session.begin():
+                ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
+                now = datetime.now(timezone.utc)
+                create_window = now - timedelta(minutes=6)
+
+                for room in ics_enabled_rooms:
+                    events = await calendar_events_controller.get_upcoming(
+                        session,
+                        room.id,
+                        minutes_ahead=7,
+                    )
+
+                    for event in events:
+                        await create_upcoming_meetings_for_event(
+                            session, event, create_window, room.id, room
+                        )

         logger.info("Completed pre-creation check for upcoming meetings")
     except Exception as e:


@@ -11,6 +11,7 @@ from celery.utils.log import get_task_logger
 from pydantic import ValidationError
 from redis.exceptions import LockError

+from reflector.db import get_session_factory
 from reflector.db.meetings import meetings_controller
 from reflector.db.recordings import Recording, recordings_controller
 from reflector.db.rooms import rooms_controller
@@ -82,66 +83,78 @@ async def process_recording(bucket_name: str, object_key: str):
     room_name = f"/{object_key[:36]}"
     recorded_at = parse_datetime_with_timezone(object_key[37:57])

-    meeting = await meetings_controller.get_by_room_name(room_name)
-    room = await rooms_controller.get_by_id(meeting.room_id)
-
-    recording = await recordings_controller.get_by_object_key(bucket_name, object_key)
-    if not recording:
-        recording = await recordings_controller.create(
-            Recording(
-                bucket_name=bucket_name,
-                object_key=object_key,
-                recorded_at=recorded_at,
-                meeting_id=meeting.id,
-            )
-        )
-
-    transcript = await transcripts_controller.get_by_recording_id(recording.id)
-    if transcript:
-        await transcripts_controller.update(
-            transcript,
-            {
-                "topics": [],
-            },
-        )
-    else:
-        transcript = await transcripts_controller.add(
-            "",
-            source_kind=SourceKind.ROOM,
-            source_language="en",
-            target_language="en",
-            user_id=room.user_id,
-            recording_id=recording.id,
-            share_mode="public",
-            meeting_id=meeting.id,
-            room_id=room.id,
-        )
-
-    _, extension = os.path.splitext(object_key)
-    upload_filename = transcript.data_path / f"upload{extension}"
-    upload_filename.parent.mkdir(parents=True, exist_ok=True)
-
-    s3 = boto3.client(
-        "s3",
-        region_name=settings.TRANSCRIPT_STORAGE_AWS_REGION,
-        aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
-        aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
-    )
-
-    with open(upload_filename, "wb") as f:
-        s3.download_fileobj(bucket_name, object_key, f)
-
-    container = av.open(upload_filename.as_posix())
-    try:
-        if not len(container.streams.audio):
-            raise Exception("File has no audio stream")
-    except Exception:
-        upload_filename.unlink()
-        raise
-    finally:
-        container.close()
-
-    await transcripts_controller.update(transcript, {"status": "uploaded"})
+    session_factory = get_session_factory()
+    async with session_factory() as session:
+        async with session.begin():
+            meeting = await meetings_controller.get_by_room_name(session, room_name)
+            room = await rooms_controller.get_by_id(session, meeting.room_id)

+            recording = await recordings_controller.get_by_object_key(
+                session, bucket_name, object_key
+            )
+            if not recording:
+                recording = await recordings_controller.create(
+                    session,
+                    Recording(
+                        bucket_name=bucket_name,
+                        object_key=object_key,
+                        recorded_at=recorded_at,
+                        meeting_id=meeting.id,
+                    ),
+                )
+
+            transcript = await transcripts_controller.get_by_recording_id(
+                session, recording.id
+            )
+            if transcript:
+                await transcripts_controller.update(
+                    session,
+                    transcript,
+                    {
+                        "topics": [],
+                    },
+                )
+            else:
+                transcript = await transcripts_controller.add(
+                    session,
+                    "",
+                    source_kind=SourceKind.ROOM,
+                    source_language="en",
+                    target_language="en",
+                    user_id=room.user_id,
+                    recording_id=recording.id,
+                    share_mode="public",
+                    meeting_id=meeting.id,
+                    room_id=room.id,
+                )
+
+            _, extension = os.path.splitext(object_key)
+            upload_filename = transcript.data_path / f"upload{extension}"
+            upload_filename.parent.mkdir(parents=True, exist_ok=True)
+
+            s3 = boto3.client(
+                "s3",
+                region_name=settings.TRANSCRIPT_STORAGE_AWS_REGION,
+                aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
+                aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
+            )
+
+            with open(upload_filename, "wb") as f:
+                s3.download_fileobj(bucket_name, object_key, f)
+
+            container = av.open(upload_filename.as_posix())
+            try:
+                if not len(container.streams.audio):
+                    raise Exception("File has no audio stream")
+            except Exception:
+                upload_filename.unlink()
+                raise
+            finally:
+                container.close()
+
+            await transcripts_controller.update(
+                session, transcript, {"status": "uploaded"}
+            )

     task_pipeline_file_process.delay(transcript_id=transcript.id)
@@ -165,7 +178,10 @@ async def process_meetings():
     process the same meeting simultaneously.
     """
     logger.info("Processing meetings")
-    meetings = await meetings_controller.get_all_active()
+    session_factory = get_session_factory()
+    async with session_factory() as session:
+        async with session.begin():
+            meetings = await meetings_controller.get_all_active(session)
     current_time = datetime.now(timezone.utc)
     redis_client = get_redis_client()
     processed_count = 0
@@ -218,7 +234,9 @@ async def process_meetings():
                 logger_.debug("Meeting not yet started, keep it")

             if should_deactivate:
-                await meetings_controller.update_meeting(meeting.id, is_active=False)
+                await meetings_controller.update_meeting(
+                    session, meeting.id, is_active=False
+                )
                 logger_.info("Meeting is deactivated")

             processed_count += 1
@@ -260,40 +278,44 @@ async def reprocess_failed_recordings():
         bucket_name = settings.RECORDING_STORAGE_AWS_BUCKET_NAME
         pages = paginator.paginate(Bucket=bucket_name)

-        for page in pages:
-            if "Contents" not in page:
-                continue
-
-            for obj in page["Contents"]:
-                object_key = obj["Key"]
-
-                if not (object_key.endswith(".mp4")):
-                    continue
-
-                recording = await recordings_controller.get_by_object_key(
-                    bucket_name, object_key
-                )
-                if not recording:
-                    logger.info(f"Queueing recording for processing: {object_key}")
-                    process_recording.delay(bucket_name, object_key)
-                    reprocessed_count += 1
-                    continue
-
-                transcript = None
-                try:
-                    transcript = await transcripts_controller.get_by_recording_id(
-                        recording.id
-                    )
-                except ValidationError:
-                    await transcripts_controller.remove_by_recording_id(recording.id)
-                    logger.warning(
-                        f"Removed invalid transcript for recording: {recording.id}"
-                    )
-
-                if transcript is None or transcript.status == "error":
-                    logger.info(f"Queueing recording for processing: {object_key}")
-                    process_recording.delay(bucket_name, object_key)
-                    reprocessed_count += 1
+        session_factory = get_session_factory()
+        async with session_factory() as session:
+            for page in pages:
+                if "Contents" not in page:
+                    continue
+
+                for obj in page["Contents"]:
+                    object_key = obj["Key"]
+
+                    if not (object_key.endswith(".mp4")):
+                        continue
+
+                    recording = await recordings_controller.get_by_object_key(
+                        session, bucket_name, object_key
+                    )
+                    if not recording:
+                        logger.info(f"Queueing recording for processing: {object_key}")
+                        process_recording.delay(bucket_name, object_key)
+                        reprocessed_count += 1
+                        continue
+
+                    transcript = None
+                    try:
+                        transcript = await transcripts_controller.get_by_recording_id(
+                            session, recording.id
+                        )
+                    except ValidationError:
+                        await transcripts_controller.remove_by_recording_id(
+                            session, recording.id
+                        )
+                        logger.warning(
+                            f"Removed invalid transcript for recording: {recording.id}"
+                        )
+
+                    if transcript is None or transcript.status == "error":
+                        logger.info(f"Queueing recording for processing: {object_key}")
+                        process_recording.delay(bucket_name, object_key)
+                        reprocessed_count += 1
     except Exception as e:
         logger.error(f"Error checking S3 bucket: {str(e)}")


@@ -15,12 +15,12 @@ from reflector.worker.cleanup import cleanup_old_public_data


 @pytest.mark.asyncio
-async def test_cleanup_old_public_data_skips_when_not_public():
+async def test_cleanup_old_public_data_skips_when_not_public(session):
     """Test that cleanup is skipped when PUBLIC_MODE is False."""
     with patch("reflector.worker.cleanup.settings") as mock_settings:
         mock_settings.PUBLIC_MODE = False

-        result = await cleanup_old_public_data()
+        result = await cleanup_old_public_data(session)

         # Should return early without doing anything
         assert result is None
@@ -81,7 +81,7 @@ async def test_cleanup_old_public_data_deletes_old_anonymous_transcripts(session
         mock_delete.return_value = None

         # Run cleanup with test session
-        await cleanup_old_public_data(session=session)
+        await cleanup_old_public_data(session)

         # Verify only old anonymous transcript was deleted
         assert mock_delete.call_count == 1
@@ -162,7 +162,7 @@ async def test_cleanup_deletes_associated_meeting_and_recording(session):
         mock_storage.return_value.delete_file = AsyncMock()

         # Run cleanup with test session
-        await cleanup_old_public_data(session=session)
+        await cleanup_old_public_data(session)

         # Verify transcript was deleted
         result = await session.execute(
@@ -226,7 +226,7 @@ async def test_cleanup_handles_errors_gracefully(session):
         mock_delete.side_effect = [Exception("Delete failed"), None]

         # Run cleanup with test session - should not raise exception
-        await cleanup_old_public_data(session=session)
+        await cleanup_old_public_data(session)

         # Both transcripts should have been attempted to delete
         assert mock_delete.call_count == 2


@@ -624,7 +624,10 @@ async def test_pipeline_file_process_no_transcript():
     # Should raise an exception for missing transcript when get_transcript is called
     with pytest.raises(Exception, match="Transcript not found"):
-        await pipeline.get_transcript()
+        from reflector.db import get_session_factory
+
+        async with get_session_factory()() as session:
+            await pipeline.get_transcript(session)


 @pytest.mark.asyncio