refactor: improve session management across worker tasks and pipelines

- Remove "if session" anti-pattern from all functions
- Functions now take an explicit AsyncSession parameter instead of an optional session/session_factory pair
- Worker tasks (Celery) create sessions at top level using session_factory
- Add proper AsyncSession type annotations to all session parameters
- Update cleanup.py: delete_single_transcript, cleanup_old_transcripts, cleanup_old_public_data
- Update process.py: process_recording, process_meetings, reprocess_failed_recordings
- Update ics_sync.py: sync_room_ics, sync_all_ics_calendars, create_upcoming_meetings
- Update pipeline classes: get_transcript methods now require session
- Fix tests to pass sessions correctly

Benefits:
- Better type safety and IDE support with explicit AsyncSession typing
- Clear transaction boundaries with sessions created at task level
- Consistent session management pattern across codebase
- No ambiguity about session vs session_factory usage
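The shape of the change, reduced to a minimal sketch (hypothetical helper names; `AsyncSession` and `get_session_factory` are the real imports used in the diffs below, while the transcripts controller import path is assumed):

from sqlalchemy.ext.asyncio import AsyncSession

from reflector.db import get_session_factory
from reflector.db.transcripts import transcripts_controller  # path assumed


# Before: the "if session" anti-pattern removed here. The helper silently
# opened its own session when none was passed, so callers could not tell
# where a transaction began or ended.
async def get_transcript_old(transcript_id: str, session: AsyncSession = None):
    if session:
        return await transcripts_controller.get_by_id(session, transcript_id)
    async with get_session_factory()() as session:
        return await transcripts_controller.get_by_id(session, transcript_id)


# After: helpers require an explicit AsyncSession and never create one.
async def get_transcript(session: AsyncSession, transcript_id: str):
    return await transcripts_controller.get_by_id(session, transcript_id)


# Only the top-level worker task owns the session and the transaction.
async def example_task(transcript_id: str):
    session_factory = get_session_factory()
    async with session_factory() as session:
        async with session.begin():
            await get_transcript(session, transcript_id)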
commit 617a1c8b32
parent 60cc2b16ae
2025-09-23 08:39:50 -06:00
7 changed files with 200 additions and 230 deletions

View File

@@ -87,15 +87,9 @@ class PipelineMainFile(PipelineMainBase):
         self.logger = logger.bind(transcript_id=self.transcript_id)
         self.empty_pipeline = EmptyPipeline(logger=self.logger)
 
-    async def get_transcript(self, session: AsyncSession = None) -> Transcript:
+    async def get_transcript(self, session: AsyncSession) -> Transcript:
         """Get transcript with session"""
-        if session:
-            result = await transcripts_controller.get_by_id(session, self.transcript_id)
-        else:
-            async with get_session_factory()() as session:
-                result = await transcripts_controller.get_by_id(
-                    session, self.transcript_id
-                )
+        result = await transcripts_controller.get_by_id(session, self.transcript_id)
         if not result:
             raise Exception("Transcript not found")
         return result

View File

@@ -142,15 +142,9 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
             self._ws_manager = get_ws_manager()
         return self._ws_manager
 
-    async def get_transcript(self, session: AsyncSession = None) -> Transcript:
+    async def get_transcript(self, session: AsyncSession) -> Transcript:
         # fetch the transcript
-        if session:
-            result = await transcripts_controller.get_by_id(session, self.transcript_id)
-        else:
-            async with get_session_factory()() as session:
-                result = await transcripts_controller.get_by_id(
-                    session, self.transcript_id
-                )
+        result = await transcripts_controller.get_by_id(session, self.transcript_id)
         if not result:
             raise Exception("Transcript not found")
         return result
@@ -349,7 +343,8 @@ class PipelineMainLive(PipelineMainBase):
     async def create(self) -> Pipeline:
         # create a context for the whole rtc transaction
         # add a customised logger to the context
-        transcript = await self.get_transcript()
+        async with get_session_factory()() as session:
+            transcript = await self.get_transcript(session)
 
         processors = [
             AudioFileWriterProcessor(
@@ -397,7 +392,8 @@ class PipelineMainDiarization(PipelineMainBase[AudioDiarizationInput]):
         # now let's start the pipeline by pushing information to the
         # first processor diarization processor
         # XXX translation is lost when converting our data model to the processor model
-        transcript = await self.get_transcript()
+        async with get_session_factory()() as session:
+            transcript = await self.get_transcript(session)
 
         # diarization works only if the file is uploaded to an external storage
         if transcript.audio_location == "local":
@@ -430,7 +426,8 @@ class PipelineMainFromTopics(PipelineMainBase[TitleSummaryWithIdProcessorType]):
     async def create(self) -> Pipeline:
         # get transcript
-        self._transcript = transcript = await self.get_transcript()
+        async with get_session_factory()() as session:
+            self._transcript = transcript = await self.get_transcript(session)
 
         # create pipeline
         processors = self.get_processors()

View File

@@ -12,6 +12,7 @@ import structlog
 from celery import shared_task
 from pydantic.types import PositiveInt
 from sqlalchemy import delete, select
+from sqlalchemy.ext.asyncio import AsyncSession
 
 from reflector.asynctask import asynctask
 from reflector.db import get_session_factory
@@ -33,104 +34,49 @@ class CleanupStats(TypedDict):
 async def delete_single_transcript(
-    session_factory, transcript_data: dict, stats: CleanupStats, session=None
+    session: AsyncSession, transcript_data: dict, stats: CleanupStats
 ):
     transcript_id = transcript_data["id"]
     meeting_id = transcript_data["meeting_id"]
     recording_id = transcript_data["recording_id"]
 
     try:
-        if session:
-            # Use provided session for testing - don't start new transaction
-            if meeting_id:
-                await session.execute(
-                    delete(MeetingModel).where(MeetingModel.id == meeting_id)
-                )
-                stats["meetings_deleted"] += 1
-                logger.info("Deleted associated meeting", meeting_id=meeting_id)
-            if recording_id:
-                result = await session.execute(
-                    select(RecordingModel).where(RecordingModel.id == recording_id)
-                )
-                recording = result.mappings().first()
-                if recording:
-                    try:
-                        await get_recordings_storage().delete_file(
-                            recording["object_key"]
-                        )
-                    except Exception as storage_error:
-                        logger.warning(
-                            "Failed to delete recording from storage",
-                            recording_id=recording_id,
-                            object_key=recording["object_key"],
-                            error=str(storage_error),
-                        )
-                    await session.execute(
-                        delete(RecordingModel).where(RecordingModel.id == recording_id)
-                    )
-                    stats["recordings_deleted"] += 1
-                    logger.info(
-                        "Deleted associated recording", recording_id=recording_id
-                    )
-            await transcripts_controller.remove_by_id(session, transcript_id)
-            stats["transcripts_deleted"] += 1
-            logger.info(
-                "Deleted transcript",
-                transcript_id=transcript_id,
-                created_at=transcript_data["created_at"].isoformat(),
-            )
-        else:
-            # Use session factory for production
-            async with session_factory() as session:
-                async with session.begin():
-                    if meeting_id:
-                        await session.execute(
-                            delete(MeetingModel).where(MeetingModel.id == meeting_id)
-                        )
-                        stats["meetings_deleted"] += 1
-                        logger.info("Deleted associated meeting", meeting_id=meeting_id)
-                    if recording_id:
-                        result = await session.execute(
-                            select(RecordingModel).where(
-                                RecordingModel.id == recording_id
-                            )
-                        )
-                        recording = result.mappings().first()
-                        if recording:
-                            try:
-                                await get_recordings_storage().delete_file(
-                                    recording["object_key"]
-                                )
-                            except Exception as storage_error:
-                                logger.warning(
-                                    "Failed to delete recording from storage",
-                                    recording_id=recording_id,
-                                    object_key=recording["object_key"],
-                                    error=str(storage_error),
-                                )
-                            await session.execute(
-                                delete(RecordingModel).where(
-                                    RecordingModel.id == recording_id
-                                )
-                            )
-                            stats["recordings_deleted"] += 1
-                            logger.info(
-                                "Deleted associated recording",
-                                recording_id=recording_id,
-                            )
-                    await transcripts_controller.remove_by_id(session, transcript_id)
-                    stats["transcripts_deleted"] += 1
-                    logger.info(
-                        "Deleted transcript",
-                        transcript_id=transcript_id,
-                        created_at=transcript_data["created_at"].isoformat(),
-                    )
+        if meeting_id:
+            await session.execute(
+                delete(MeetingModel).where(MeetingModel.id == meeting_id)
+            )
+            stats["meetings_deleted"] += 1
+            logger.info("Deleted associated meeting", meeting_id=meeting_id)
+
+        if recording_id:
+            result = await session.execute(
+                select(RecordingModel).where(RecordingModel.id == recording_id)
+            )
+            recording = result.mappings().first()
+            if recording:
+                try:
+                    await get_recordings_storage().delete_file(recording["object_key"])
+                except Exception as storage_error:
+                    logger.warning(
+                        "Failed to delete recording from storage",
+                        recording_id=recording_id,
+                        object_key=recording["object_key"],
+                        error=str(storage_error),
+                    )
+                await session.execute(
+                    delete(RecordingModel).where(RecordingModel.id == recording_id)
+                )
+                stats["recordings_deleted"] += 1
+                logger.info("Deleted associated recording", recording_id=recording_id)
+
+        await transcripts_controller.remove_by_id(session, transcript_id)
+        stats["transcripts_deleted"] += 1
+        logger.info(
+            "Deleted transcript",
+            transcript_id=transcript_id,
+            created_at=transcript_data["created_at"].isoformat(),
+        )
     except Exception as e:
         error_msg = f"Failed to delete transcript {transcript_id}: {str(e)}"
         logger.error(error_msg, exc_info=e)
@@ -138,7 +84,7 @@ async def delete_single_transcript(
 
 async def cleanup_old_transcripts(
-    session_factory, cutoff_date: datetime, stats: CleanupStats, session=None
+    session: AsyncSession, cutoff_date: datetime, stats: CleanupStats
 ):
     """Delete old anonymous transcripts and their associated recordings/meetings."""
     query = select(
@@ -150,23 +96,14 @@ async def cleanup_old_transcripts(
         (TranscriptModel.created_at < cutoff_date) & (TranscriptModel.user_id.is_(None))
     )
 
-    if session:
-        # Use provided session for testing
-        result = await session.execute(query)
-        old_transcripts = result.mappings().all()
-    else:
-        # Use session factory for production
-        async with session_factory() as session:
-            result = await session.execute(query)
-            old_transcripts = result.mappings().all()
+    result = await session.execute(query)
+    old_transcripts = result.mappings().all()
 
     logger.info(f"Found {len(old_transcripts)} old transcripts to delete")
 
     for transcript_data in old_transcripts:
         try:
-            await delete_single_transcript(
-                session_factory, transcript_data, stats, session
-            )
+            await delete_single_transcript(session, transcript_data, stats)
         except Exception as e:
             error_msg = f"Failed to delete transcript {transcript_data['id']}: {str(e)}"
             logger.error(error_msg, exc_info=e)
@@ -190,8 +127,8 @@ def log_cleanup_results(stats: CleanupStats):
 
 async def cleanup_old_public_data(
+    session: AsyncSession,
     days: PositiveInt | None = None,
-    session=None,
 ) -> CleanupStats | None:
     if days is None:
         days = settings.PUBLIC_DATA_RETENTION_DAYS
@@ -213,8 +150,7 @@ async def cleanup_old_public_data(
         "errors": [],
     }
 
-    session_factory = get_session_factory()
-    await cleanup_old_transcripts(session_factory, cutoff_date, stats, session)
+    await cleanup_old_transcripts(session, cutoff_date, stats)
 
     log_cleanup_results(stats)
     return stats
@@ -226,4 +162,7 @@ async def cleanup_old_public_data(
 )
 @asynctask
 async def cleanup_old_public_data_task(days: int | None = None):
-    await cleanup_old_public_data(days=days)
+    session_factory = get_session_factory()
+    async with session_factory() as session:
+        async with session.begin():
+            await cleanup_old_public_data(session, days=days)

View File

@@ -3,8 +3,10 @@ from datetime import datetime, timedelta, timezone
 import structlog
 from celery import shared_task
 from celery.utils.log import get_task_logger
+from sqlalchemy.ext.asyncio import AsyncSession
 
 from reflector.asynctask import asynctask
+from reflector.db import get_session_factory
 from reflector.db.calendar_events import calendar_events_controller
 from reflector.db.meetings import meetings_controller
 from reflector.db.rooms import rooms_controller
@@ -19,7 +21,9 @@ logger = structlog.wrap_logger(get_task_logger(__name__))
 @asynctask
 async def sync_room_ics(room_id: str):
     try:
-        room = await rooms_controller.get_by_id(room_id)
+        session_factory = get_session_factory()
+        async with session_factory() as session:
+            room = await rooms_controller.get_by_id(session, room_id)
         if not room:
             logger.warning("Room not found for ICS sync", room_id=room_id)
             return
@@ -59,7 +63,9 @@ async def sync_all_ics_calendars():
     try:
         logger.info("Starting sync for all ICS-enabled rooms")
 
-        ics_enabled_rooms = await rooms_controller.get_ics_enabled()
+        session_factory = get_session_factory()
+        async with session_factory() as session:
+            ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
         logger.info(f"Found {len(ics_enabled_rooms)} rooms with ICS enabled")
 
         for room in ics_enabled_rooms:
@@ -86,10 +92,14 @@ def _should_sync(room) -> bool:
 MEETING_DEFAULT_DURATION = timedelta(hours=1)
 
 
-async def create_upcoming_meetings_for_event(event, create_window, room_id, room):
+async def create_upcoming_meetings_for_event(
+    session: AsyncSession, event, create_window, room_id, room
+):
     if event.start_time <= create_window:
         return
 
-    existing_meeting = await meetings_controller.get_by_calendar_event(event.id)
+    existing_meeting = await meetings_controller.get_by_calendar_event(
+        session, event.id
+    )
     if existing_meeting:
         return
@@ -112,6 +122,7 @@ async def create_upcoming_meetings_for_event(event, create_window, room_id, room
         await upload_logo(whereby_meeting["roomName"], "./images/logo.png")
 
     meeting = await meetings_controller.create(
+        session,
         id=whereby_meeting["meetingId"],
         room_name=whereby_meeting["roomName"],
         room_url=whereby_meeting["roomUrl"],
@@ -155,20 +166,24 @@ async def create_upcoming_meetings():
     try:
         logger.info("Starting creation of upcoming meetings")
 
-        ics_enabled_rooms = await rooms_controller.get_ics_enabled()
-        now = datetime.now(timezone.utc)
-        create_window = now - timedelta(minutes=6)
+        session_factory = get_session_factory()
+        async with session_factory() as session:
+            async with session.begin():
+                ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
+                now = datetime.now(timezone.utc)
+                create_window = now - timedelta(minutes=6)
 
-        for room in ics_enabled_rooms:
-            events = await calendar_events_controller.get_upcoming(
-                room.id,
-                minutes_ahead=7,
-            )
+                for room in ics_enabled_rooms:
+                    events = await calendar_events_controller.get_upcoming(
+                        session,
+                        room.id,
+                        minutes_ahead=7,
+                    )
 
-            for event in events:
-                await create_upcoming_meetings_for_event(
-                    event, create_window, room.id, room
-                )
+                    for event in events:
+                        await create_upcoming_meetings_for_event(
+                            session, event, create_window, room.id, room
+                        )
 
         logger.info("Completed pre-creation check for upcoming meetings")
     except Exception as e:

View File

@@ -11,6 +11,7 @@ from celery.utils.log import get_task_logger
 from pydantic import ValidationError
 from redis.exceptions import LockError
 
+from reflector.db import get_session_factory
 from reflector.db.meetings import meetings_controller
 from reflector.db.recordings import Recording, recordings_controller
 from reflector.db.rooms import rooms_controller
@@ -82,66 +83,78 @@ async def process_recording(bucket_name: str, object_key: str):
     room_name = f"/{object_key[:36]}"
     recorded_at = parse_datetime_with_timezone(object_key[37:57])
 
-    meeting = await meetings_controller.get_by_room_name(room_name)
-    room = await rooms_controller.get_by_id(meeting.room_id)
-
-    recording = await recordings_controller.get_by_object_key(bucket_name, object_key)
-    if not recording:
-        recording = await recordings_controller.create(
-            Recording(
-                bucket_name=bucket_name,
-                object_key=object_key,
-                recorded_at=recorded_at,
-                meeting_id=meeting.id,
-            )
-        )
-
-    transcript = await transcripts_controller.get_by_recording_id(recording.id)
-    if transcript:
-        await transcripts_controller.update(
-            transcript,
-            {
-                "topics": [],
-            },
-        )
-    else:
-        transcript = await transcripts_controller.add(
-            "",
-            source_kind=SourceKind.ROOM,
-            source_language="en",
-            target_language="en",
-            user_id=room.user_id,
-            recording_id=recording.id,
-            share_mode="public",
-            meeting_id=meeting.id,
-            room_id=room.id,
-        )
-
-    _, extension = os.path.splitext(object_key)
-    upload_filename = transcript.data_path / f"upload{extension}"
-    upload_filename.parent.mkdir(parents=True, exist_ok=True)
-
-    s3 = boto3.client(
-        "s3",
-        region_name=settings.TRANSCRIPT_STORAGE_AWS_REGION,
-        aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
-        aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
-    )
-    with open(upload_filename, "wb") as f:
-        s3.download_fileobj(bucket_name, object_key, f)
-
-    container = av.open(upload_filename.as_posix())
-    try:
-        if not len(container.streams.audio):
-            raise Exception("File has no audio stream")
-    except Exception:
-        upload_filename.unlink()
-        raise
-    finally:
-        container.close()
-
-    await transcripts_controller.update(transcript, {"status": "uploaded"})
+    session_factory = get_session_factory()
+    async with session_factory() as session:
+        async with session.begin():
+            meeting = await meetings_controller.get_by_room_name(session, room_name)
+            room = await rooms_controller.get_by_id(session, meeting.room_id)
+
+            recording = await recordings_controller.get_by_object_key(
+                session, bucket_name, object_key
+            )
+            if not recording:
+                recording = await recordings_controller.create(
+                    session,
+                    Recording(
+                        bucket_name=bucket_name,
+                        object_key=object_key,
+                        recorded_at=recorded_at,
+                        meeting_id=meeting.id,
+                    ),
+                )
+
+            transcript = await transcripts_controller.get_by_recording_id(
+                session, recording.id
+            )
+            if transcript:
+                await transcripts_controller.update(
+                    session,
+                    transcript,
+                    {
+                        "topics": [],
+                    },
+                )
+            else:
+                transcript = await transcripts_controller.add(
+                    session,
+                    "",
+                    source_kind=SourceKind.ROOM,
+                    source_language="en",
+                    target_language="en",
+                    user_id=room.user_id,
+                    recording_id=recording.id,
+                    share_mode="public",
+                    meeting_id=meeting.id,
+                    room_id=room.id,
+                )
+
+            _, extension = os.path.splitext(object_key)
+            upload_filename = transcript.data_path / f"upload{extension}"
+            upload_filename.parent.mkdir(parents=True, exist_ok=True)
+
+            s3 = boto3.client(
+                "s3",
+                region_name=settings.TRANSCRIPT_STORAGE_AWS_REGION,
+                aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
+                aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
+            )
+            with open(upload_filename, "wb") as f:
+                s3.download_fileobj(bucket_name, object_key, f)
+
+            container = av.open(upload_filename.as_posix())
+            try:
+                if not len(container.streams.audio):
+                    raise Exception("File has no audio stream")
+            except Exception:
+                upload_filename.unlink()
+                raise
+            finally:
+                container.close()
+
+            await transcripts_controller.update(
+                session, transcript, {"status": "uploaded"}
+            )
 
     task_pipeline_file_process.delay(transcript_id=transcript.id)
@@ -165,7 +178,10 @@ async def process_meetings():
     process the same meeting simultaneously.
     """
     logger.info("Processing meetings")
-    meetings = await meetings_controller.get_all_active()
+    session_factory = get_session_factory()
+    async with session_factory() as session:
+        async with session.begin():
+            meetings = await meetings_controller.get_all_active(session)
     current_time = datetime.now(timezone.utc)
     redis_client = get_redis_client()
     processed_count = 0
@@ -218,7 +234,9 @@ async def process_meetings():
                 logger_.debug("Meeting not yet started, keep it")
 
         if should_deactivate:
-            await meetings_controller.update_meeting(meeting.id, is_active=False)
+            await meetings_controller.update_meeting(
+                session, meeting.id, is_active=False
+            )
             logger_.info("Meeting is deactivated")
             processed_count += 1
@@ -260,40 +278,44 @@ async def reprocess_failed_recordings():
         bucket_name = settings.RECORDING_STORAGE_AWS_BUCKET_NAME
         pages = paginator.paginate(Bucket=bucket_name)
 
-        for page in pages:
-            if "Contents" not in page:
-                continue
-
-            for obj in page["Contents"]:
-                object_key = obj["Key"]
-
-                if not (object_key.endswith(".mp4")):
-                    continue
-
-                recording = await recordings_controller.get_by_object_key(
-                    bucket_name, object_key
-                )
-                if not recording:
-                    logger.info(f"Queueing recording for processing: {object_key}")
-                    process_recording.delay(bucket_name, object_key)
-                    reprocessed_count += 1
-                    continue
-
-                transcript = None
-                try:
-                    transcript = await transcripts_controller.get_by_recording_id(
-                        recording.id
-                    )
-                except ValidationError:
-                    await transcripts_controller.remove_by_recording_id(recording.id)
-                    logger.warning(
-                        f"Removed invalid transcript for recording: {recording.id}"
-                    )
-
-                if transcript is None or transcript.status == "error":
-                    logger.info(f"Queueing recording for processing: {object_key}")
-                    process_recording.delay(bucket_name, object_key)
-                    reprocessed_count += 1
+        session_factory = get_session_factory()
+        async with session_factory() as session:
+            for page in pages:
+                if "Contents" not in page:
+                    continue
+
+                for obj in page["Contents"]:
+                    object_key = obj["Key"]
+
+                    if not (object_key.endswith(".mp4")):
+                        continue
+
+                    recording = await recordings_controller.get_by_object_key(
+                        session, bucket_name, object_key
+                    )
+                    if not recording:
+                        logger.info(f"Queueing recording for processing: {object_key}")
+                        process_recording.delay(bucket_name, object_key)
+                        reprocessed_count += 1
+                        continue
+
+                    transcript = None
+                    try:
+                        transcript = await transcripts_controller.get_by_recording_id(
+                            session, recording.id
+                        )
+                    except ValidationError:
+                        await transcripts_controller.remove_by_recording_id(
+                            session, recording.id
+                        )
+                        logger.warning(
+                            f"Removed invalid transcript for recording: {recording.id}"
+                        )
+
+                    if transcript is None or transcript.status == "error":
+                        logger.info(f"Queueing recording for processing: {object_key}")
+                        process_recording.delay(bucket_name, object_key)
+                        reprocessed_count += 1
     except Exception as e:
         logger.error(f"Error checking S3 bucket: {str(e)}")

View File

@@ -15,12 +15,12 @@ from reflector.worker.cleanup import cleanup_old_public_data
 
 @pytest.mark.asyncio
-async def test_cleanup_old_public_data_skips_when_not_public():
+async def test_cleanup_old_public_data_skips_when_not_public(session):
     """Test that cleanup is skipped when PUBLIC_MODE is False."""
     with patch("reflector.worker.cleanup.settings") as mock_settings:
         mock_settings.PUBLIC_MODE = False
 
-        result = await cleanup_old_public_data()
+        result = await cleanup_old_public_data(session)
 
         # Should return early without doing anything
         assert result is None
@@ -81,7 +81,7 @@ async def test_cleanup_old_public_data_deletes_old_anonymous_transcripts(session
         mock_delete.return_value = None
 
         # Run cleanup with test session
-        await cleanup_old_public_data(session=session)
+        await cleanup_old_public_data(session)
 
         # Verify only old anonymous transcript was deleted
         assert mock_delete.call_count == 1
@@ -162,7 +162,7 @@ async def test_cleanup_deletes_associated_meeting_and_recording(session):
         mock_storage.return_value.delete_file = AsyncMock()
 
         # Run cleanup with test session
-        await cleanup_old_public_data(session=session)
+        await cleanup_old_public_data(session)
 
         # Verify transcript was deleted
         result = await session.execute(
@@ -226,7 +226,7 @@ async def test_cleanup_handles_errors_gracefully(session):
         mock_delete.side_effect = [Exception("Delete failed"), None]
 
         # Run cleanup with test session - should not raise exception
-        await cleanup_old_public_data(session=session)
+        await cleanup_old_public_data(session)
 
         # Both transcripts should have been attempted to delete
        assert mock_delete.call_count == 2

View File

@@ -624,7 +624,10 @@ async def test_pipeline_file_process_no_transcript():
 
     # Should raise an exception for missing transcript when get_transcript is called
     with pytest.raises(Exception, match="Transcript not found"):
-        await pipeline.get_transcript()
+        from reflector.db import get_session_factory
+
+        async with get_session_factory()() as session:
+            await pipeline.get_transcript(session)
 
 
 @pytest.mark.asyncio