mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
refactor: improve session management across worker tasks and pipelines
- Remove "if session" anti-pattern from all functions - Functions now require explicit AsyncSession parameters instead of optional session_factory - Worker tasks (Celery) create sessions at top level using session_factory - Add proper AsyncSession type annotations to all session parameters - Update cleanup.py: delete_single_transcript, cleanup_old_transcripts, cleanup_old_public_data - Update process.py: process_recording, process_meetings, reprocess_failed_recordings - Update ics_sync.py: sync_room_ics, sync_all_ics_calendars, create_upcoming_meetings - Update pipeline classes: get_transcript methods now require session - Fix tests to pass sessions correctly Benefits: - Better type safety and IDE support with explicit AsyncSession typing - Clear transaction boundaries with sessions created at task level - Consistent session management pattern across codebase - No ambiguity about session vs session_factory usage
This commit is contained in:
@@ -87,15 +87,9 @@ class PipelineMainFile(PipelineMainBase):
|
||||
self.logger = logger.bind(transcript_id=self.transcript_id)
|
||||
self.empty_pipeline = EmptyPipeline(logger=self.logger)
|
||||
|
||||
async def get_transcript(self, session: AsyncSession = None) -> Transcript:
|
||||
async def get_transcript(self, session: AsyncSession) -> Transcript:
|
||||
"""Get transcript with session"""
|
||||
if session:
|
||||
result = await transcripts_controller.get_by_id(session, self.transcript_id)
|
||||
else:
|
||||
async with get_session_factory()() as session:
|
||||
result = await transcripts_controller.get_by_id(
|
||||
session, self.transcript_id
|
||||
)
|
||||
result = await transcripts_controller.get_by_id(session, self.transcript_id)
|
||||
if not result:
|
||||
raise Exception("Transcript not found")
|
||||
return result
|
||||
|
||||
@@ -142,15 +142,9 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
||||
self._ws_manager = get_ws_manager()
|
||||
return self._ws_manager
|
||||
|
||||
async def get_transcript(self, session: AsyncSession = None) -> Transcript:
|
||||
async def get_transcript(self, session: AsyncSession) -> Transcript:
|
||||
# fetch the transcript
|
||||
if session:
|
||||
result = await transcripts_controller.get_by_id(session, self.transcript_id)
|
||||
else:
|
||||
async with get_session_factory()() as session:
|
||||
result = await transcripts_controller.get_by_id(
|
||||
session, self.transcript_id
|
||||
)
|
||||
result = await transcripts_controller.get_by_id(session, self.transcript_id)
|
||||
if not result:
|
||||
raise Exception("Transcript not found")
|
||||
return result
|
||||
@@ -349,7 +343,8 @@ class PipelineMainLive(PipelineMainBase):
|
||||
async def create(self) -> Pipeline:
|
||||
# create a context for the whole rtc transaction
|
||||
# add a customised logger to the context
|
||||
transcript = await self.get_transcript()
|
||||
async with get_session_factory()() as session:
|
||||
transcript = await self.get_transcript(session)
|
||||
|
||||
processors = [
|
||||
AudioFileWriterProcessor(
|
||||
@@ -397,7 +392,8 @@ class PipelineMainDiarization(PipelineMainBase[AudioDiarizationInput]):
|
||||
# now let's start the pipeline by pushing information to the
|
||||
# first processor diarization processor
|
||||
# XXX translation is lost when converting our data model to the processor model
|
||||
transcript = await self.get_transcript()
|
||||
async with get_session_factory()() as session:
|
||||
transcript = await self.get_transcript(session)
|
||||
|
||||
# diarization works only if the file is uploaded to an external storage
|
||||
if transcript.audio_location == "local":
|
||||
@@ -430,7 +426,8 @@ class PipelineMainFromTopics(PipelineMainBase[TitleSummaryWithIdProcessorType]):
|
||||
|
||||
async def create(self) -> Pipeline:
|
||||
# get transcript
|
||||
self._transcript = transcript = await self.get_transcript()
|
||||
async with get_session_factory()() as session:
|
||||
self._transcript = transcript = await self.get_transcript(session)
|
||||
|
||||
# create pipeline
|
||||
processors = self.get_processors()
|
||||
|
||||
@@ -12,6 +12,7 @@ import structlog
|
||||
from celery import shared_task
|
||||
from pydantic.types import PositiveInt
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from reflector.asynctask import asynctask
|
||||
from reflector.db import get_session_factory
|
||||
@@ -33,104 +34,49 @@ class CleanupStats(TypedDict):
|
||||
|
||||
|
||||
async def delete_single_transcript(
|
||||
session_factory, transcript_data: dict, stats: CleanupStats, session=None
|
||||
session: AsyncSession, transcript_data: dict, stats: CleanupStats
|
||||
):
|
||||
transcript_id = transcript_data["id"]
|
||||
meeting_id = transcript_data["meeting_id"]
|
||||
recording_id = transcript_data["recording_id"]
|
||||
|
||||
try:
|
||||
if session:
|
||||
# Use provided session for testing - don't start new transaction
|
||||
if meeting_id:
|
||||
await session.execute(
|
||||
delete(MeetingModel).where(MeetingModel.id == meeting_id)
|
||||
)
|
||||
stats["meetings_deleted"] += 1
|
||||
logger.info("Deleted associated meeting", meeting_id=meeting_id)
|
||||
|
||||
if recording_id:
|
||||
result = await session.execute(
|
||||
select(RecordingModel).where(RecordingModel.id == recording_id)
|
||||
)
|
||||
recording = result.mappings().first()
|
||||
if recording:
|
||||
try:
|
||||
await get_recordings_storage().delete_file(
|
||||
recording["object_key"]
|
||||
)
|
||||
except Exception as storage_error:
|
||||
logger.warning(
|
||||
"Failed to delete recording from storage",
|
||||
recording_id=recording_id,
|
||||
object_key=recording["object_key"],
|
||||
error=str(storage_error),
|
||||
)
|
||||
|
||||
await session.execute(
|
||||
delete(RecordingModel).where(RecordingModel.id == recording_id)
|
||||
)
|
||||
stats["recordings_deleted"] += 1
|
||||
logger.info(
|
||||
"Deleted associated recording", recording_id=recording_id
|
||||
)
|
||||
|
||||
await transcripts_controller.remove_by_id(session, transcript_id)
|
||||
stats["transcripts_deleted"] += 1
|
||||
logger.info(
|
||||
"Deleted transcript",
|
||||
transcript_id=transcript_id,
|
||||
created_at=transcript_data["created_at"].isoformat(),
|
||||
if meeting_id:
|
||||
await session.execute(
|
||||
delete(MeetingModel).where(MeetingModel.id == meeting_id)
|
||||
)
|
||||
else:
|
||||
# Use session factory for production
|
||||
async with session_factory() as session:
|
||||
async with session.begin():
|
||||
if meeting_id:
|
||||
await session.execute(
|
||||
delete(MeetingModel).where(MeetingModel.id == meeting_id)
|
||||
)
|
||||
stats["meetings_deleted"] += 1
|
||||
logger.info("Deleted associated meeting", meeting_id=meeting_id)
|
||||
stats["meetings_deleted"] += 1
|
||||
logger.info("Deleted associated meeting", meeting_id=meeting_id)
|
||||
|
||||
if recording_id:
|
||||
result = await session.execute(
|
||||
select(RecordingModel).where(
|
||||
RecordingModel.id == recording_id
|
||||
)
|
||||
)
|
||||
recording = result.mappings().first()
|
||||
if recording:
|
||||
try:
|
||||
await get_recordings_storage().delete_file(
|
||||
recording["object_key"]
|
||||
)
|
||||
except Exception as storage_error:
|
||||
logger.warning(
|
||||
"Failed to delete recording from storage",
|
||||
recording_id=recording_id,
|
||||
object_key=recording["object_key"],
|
||||
error=str(storage_error),
|
||||
)
|
||||
if recording_id:
|
||||
result = await session.execute(
|
||||
select(RecordingModel).where(RecordingModel.id == recording_id)
|
||||
)
|
||||
recording = result.mappings().first()
|
||||
if recording:
|
||||
try:
|
||||
await get_recordings_storage().delete_file(recording["object_key"])
|
||||
except Exception as storage_error:
|
||||
logger.warning(
|
||||
"Failed to delete recording from storage",
|
||||
recording_id=recording_id,
|
||||
object_key=recording["object_key"],
|
||||
error=str(storage_error),
|
||||
)
|
||||
|
||||
await session.execute(
|
||||
delete(RecordingModel).where(
|
||||
RecordingModel.id == recording_id
|
||||
)
|
||||
)
|
||||
stats["recordings_deleted"] += 1
|
||||
logger.info(
|
||||
"Deleted associated recording",
|
||||
recording_id=recording_id,
|
||||
)
|
||||
|
||||
await transcripts_controller.remove_by_id(session, transcript_id)
|
||||
stats["transcripts_deleted"] += 1
|
||||
logger.info(
|
||||
"Deleted transcript",
|
||||
transcript_id=transcript_id,
|
||||
created_at=transcript_data["created_at"].isoformat(),
|
||||
await session.execute(
|
||||
delete(RecordingModel).where(RecordingModel.id == recording_id)
|
||||
)
|
||||
stats["recordings_deleted"] += 1
|
||||
logger.info("Deleted associated recording", recording_id=recording_id)
|
||||
|
||||
await transcripts_controller.remove_by_id(session, transcript_id)
|
||||
stats["transcripts_deleted"] += 1
|
||||
logger.info(
|
||||
"Deleted transcript",
|
||||
transcript_id=transcript_id,
|
||||
created_at=transcript_data["created_at"].isoformat(),
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to delete transcript {transcript_id}: {str(e)}"
|
||||
logger.error(error_msg, exc_info=e)
|
||||
@@ -138,7 +84,7 @@ async def delete_single_transcript(
|
||||
|
||||
|
||||
async def cleanup_old_transcripts(
|
||||
session_factory, cutoff_date: datetime, stats: CleanupStats, session=None
|
||||
session: AsyncSession, cutoff_date: datetime, stats: CleanupStats
|
||||
):
|
||||
"""Delete old anonymous transcripts and their associated recordings/meetings."""
|
||||
query = select(
|
||||
@@ -150,23 +96,14 @@ async def cleanup_old_transcripts(
|
||||
(TranscriptModel.created_at < cutoff_date) & (TranscriptModel.user_id.is_(None))
|
||||
)
|
||||
|
||||
if session:
|
||||
# Use provided session for testing
|
||||
result = await session.execute(query)
|
||||
old_transcripts = result.mappings().all()
|
||||
else:
|
||||
# Use session factory for production
|
||||
async with session_factory() as session:
|
||||
result = await session.execute(query)
|
||||
old_transcripts = result.mappings().all()
|
||||
result = await session.execute(query)
|
||||
old_transcripts = result.mappings().all()
|
||||
|
||||
logger.info(f"Found {len(old_transcripts)} old transcripts to delete")
|
||||
|
||||
for transcript_data in old_transcripts:
|
||||
try:
|
||||
await delete_single_transcript(
|
||||
session_factory, transcript_data, stats, session
|
||||
)
|
||||
await delete_single_transcript(session, transcript_data, stats)
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to delete transcript {transcript_data['id']}: {str(e)}"
|
||||
logger.error(error_msg, exc_info=e)
|
||||
@@ -190,8 +127,8 @@ def log_cleanup_results(stats: CleanupStats):
|
||||
|
||||
|
||||
async def cleanup_old_public_data(
|
||||
session: AsyncSession,
|
||||
days: PositiveInt | None = None,
|
||||
session=None,
|
||||
) -> CleanupStats | None:
|
||||
if days is None:
|
||||
days = settings.PUBLIC_DATA_RETENTION_DAYS
|
||||
@@ -213,8 +150,7 @@ async def cleanup_old_public_data(
|
||||
"errors": [],
|
||||
}
|
||||
|
||||
session_factory = get_session_factory()
|
||||
await cleanup_old_transcripts(session_factory, cutoff_date, stats, session)
|
||||
await cleanup_old_transcripts(session, cutoff_date, stats)
|
||||
|
||||
log_cleanup_results(stats)
|
||||
return stats
|
||||
@@ -226,4 +162,7 @@ async def cleanup_old_public_data(
|
||||
)
|
||||
@asynctask
|
||||
async def cleanup_old_public_data_task(days: int | None = None):
|
||||
await cleanup_old_public_data(days=days)
|
||||
session_factory = get_session_factory()
|
||||
async with session_factory() as session:
|
||||
async with session.begin():
|
||||
await cleanup_old_public_data(session, days=days)
|
||||
|
||||
@@ -3,8 +3,10 @@ from datetime import datetime, timedelta, timezone
|
||||
import structlog
|
||||
from celery import shared_task
|
||||
from celery.utils.log import get_task_logger
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from reflector.asynctask import asynctask
|
||||
from reflector.db import get_session_factory
|
||||
from reflector.db.calendar_events import calendar_events_controller
|
||||
from reflector.db.meetings import meetings_controller
|
||||
from reflector.db.rooms import rooms_controller
|
||||
@@ -19,7 +21,9 @@ logger = structlog.wrap_logger(get_task_logger(__name__))
|
||||
@asynctask
|
||||
async def sync_room_ics(room_id: str):
|
||||
try:
|
||||
room = await rooms_controller.get_by_id(room_id)
|
||||
session_factory = get_session_factory()
|
||||
async with session_factory() as session:
|
||||
room = await rooms_controller.get_by_id(session, room_id)
|
||||
if not room:
|
||||
logger.warning("Room not found for ICS sync", room_id=room_id)
|
||||
return
|
||||
@@ -59,7 +63,9 @@ async def sync_all_ics_calendars():
|
||||
try:
|
||||
logger.info("Starting sync for all ICS-enabled rooms")
|
||||
|
||||
ics_enabled_rooms = await rooms_controller.get_ics_enabled()
|
||||
session_factory = get_session_factory()
|
||||
async with session_factory() as session:
|
||||
ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
|
||||
logger.info(f"Found {len(ics_enabled_rooms)} rooms with ICS enabled")
|
||||
|
||||
for room in ics_enabled_rooms:
|
||||
@@ -86,10 +92,14 @@ def _should_sync(room) -> bool:
|
||||
MEETING_DEFAULT_DURATION = timedelta(hours=1)
|
||||
|
||||
|
||||
async def create_upcoming_meetings_for_event(event, create_window, room_id, room):
|
||||
async def create_upcoming_meetings_for_event(
|
||||
session: AsyncSession, event, create_window, room_id, room
|
||||
):
|
||||
if event.start_time <= create_window:
|
||||
return
|
||||
existing_meeting = await meetings_controller.get_by_calendar_event(event.id)
|
||||
existing_meeting = await meetings_controller.get_by_calendar_event(
|
||||
session, event.id
|
||||
)
|
||||
|
||||
if existing_meeting:
|
||||
return
|
||||
@@ -112,6 +122,7 @@ async def create_upcoming_meetings_for_event(event, create_window, room_id, room
|
||||
await upload_logo(whereby_meeting["roomName"], "./images/logo.png")
|
||||
|
||||
meeting = await meetings_controller.create(
|
||||
session,
|
||||
id=whereby_meeting["meetingId"],
|
||||
room_name=whereby_meeting["roomName"],
|
||||
room_url=whereby_meeting["roomUrl"],
|
||||
@@ -155,20 +166,24 @@ async def create_upcoming_meetings():
|
||||
try:
|
||||
logger.info("Starting creation of upcoming meetings")
|
||||
|
||||
ics_enabled_rooms = await rooms_controller.get_ics_enabled()
|
||||
now = datetime.now(timezone.utc)
|
||||
create_window = now - timedelta(minutes=6)
|
||||
session_factory = get_session_factory()
|
||||
async with session_factory() as session:
|
||||
async with session.begin():
|
||||
ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
|
||||
now = datetime.now(timezone.utc)
|
||||
create_window = now - timedelta(minutes=6)
|
||||
|
||||
for room in ics_enabled_rooms:
|
||||
events = await calendar_events_controller.get_upcoming(
|
||||
room.id,
|
||||
minutes_ahead=7,
|
||||
)
|
||||
for room in ics_enabled_rooms:
|
||||
events = await calendar_events_controller.get_upcoming(
|
||||
session,
|
||||
room.id,
|
||||
minutes_ahead=7,
|
||||
)
|
||||
|
||||
for event in events:
|
||||
await create_upcoming_meetings_for_event(
|
||||
event, create_window, room.id, room
|
||||
)
|
||||
for event in events:
|
||||
await create_upcoming_meetings_for_event(
|
||||
session, event, create_window, room.id, room
|
||||
)
|
||||
logger.info("Completed pre-creation check for upcoming meetings")
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -11,6 +11,7 @@ from celery.utils.log import get_task_logger
|
||||
from pydantic import ValidationError
|
||||
from redis.exceptions import LockError
|
||||
|
||||
from reflector.db import get_session_factory
|
||||
from reflector.db.meetings import meetings_controller
|
||||
from reflector.db.recordings import Recording, recordings_controller
|
||||
from reflector.db.rooms import rooms_controller
|
||||
@@ -82,66 +83,78 @@ async def process_recording(bucket_name: str, object_key: str):
|
||||
room_name = f"/{object_key[:36]}"
|
||||
recorded_at = parse_datetime_with_timezone(object_key[37:57])
|
||||
|
||||
meeting = await meetings_controller.get_by_room_name(room_name)
|
||||
room = await rooms_controller.get_by_id(meeting.room_id)
|
||||
session_factory = get_session_factory()
|
||||
async with session_factory() as session:
|
||||
async with session.begin():
|
||||
meeting = await meetings_controller.get_by_room_name(session, room_name)
|
||||
room = await rooms_controller.get_by_id(session, meeting.room_id)
|
||||
|
||||
recording = await recordings_controller.get_by_object_key(bucket_name, object_key)
|
||||
if not recording:
|
||||
recording = await recordings_controller.create(
|
||||
Recording(
|
||||
bucket_name=bucket_name,
|
||||
object_key=object_key,
|
||||
recorded_at=recorded_at,
|
||||
meeting_id=meeting.id,
|
||||
recording = await recordings_controller.get_by_object_key(
|
||||
session, bucket_name, object_key
|
||||
)
|
||||
)
|
||||
if not recording:
|
||||
recording = await recordings_controller.create(
|
||||
session,
|
||||
Recording(
|
||||
bucket_name=bucket_name,
|
||||
object_key=object_key,
|
||||
recorded_at=recorded_at,
|
||||
meeting_id=meeting.id,
|
||||
),
|
||||
)
|
||||
|
||||
transcript = await transcripts_controller.get_by_recording_id(recording.id)
|
||||
if transcript:
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"topics": [],
|
||||
},
|
||||
)
|
||||
else:
|
||||
transcript = await transcripts_controller.add(
|
||||
"",
|
||||
source_kind=SourceKind.ROOM,
|
||||
source_language="en",
|
||||
target_language="en",
|
||||
user_id=room.user_id,
|
||||
recording_id=recording.id,
|
||||
share_mode="public",
|
||||
meeting_id=meeting.id,
|
||||
room_id=room.id,
|
||||
)
|
||||
transcript = await transcripts_controller.get_by_recording_id(
|
||||
session, recording.id
|
||||
)
|
||||
if transcript:
|
||||
await transcripts_controller.update(
|
||||
session,
|
||||
transcript,
|
||||
{
|
||||
"topics": [],
|
||||
},
|
||||
)
|
||||
else:
|
||||
transcript = await transcripts_controller.add(
|
||||
session,
|
||||
"",
|
||||
source_kind=SourceKind.ROOM,
|
||||
source_language="en",
|
||||
target_language="en",
|
||||
user_id=room.user_id,
|
||||
recording_id=recording.id,
|
||||
share_mode="public",
|
||||
meeting_id=meeting.id,
|
||||
room_id=room.id,
|
||||
)
|
||||
|
||||
_, extension = os.path.splitext(object_key)
|
||||
upload_filename = transcript.data_path / f"upload{extension}"
|
||||
upload_filename.parent.mkdir(parents=True, exist_ok=True)
|
||||
_, extension = os.path.splitext(object_key)
|
||||
upload_filename = transcript.data_path / f"upload{extension}"
|
||||
upload_filename.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
s3 = boto3.client(
|
||||
"s3",
|
||||
region_name=settings.TRANSCRIPT_STORAGE_AWS_REGION,
|
||||
aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
|
||||
aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
|
||||
)
|
||||
s3 = boto3.client(
|
||||
"s3",
|
||||
region_name=settings.TRANSCRIPT_STORAGE_AWS_REGION,
|
||||
aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
|
||||
aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
|
||||
)
|
||||
|
||||
with open(upload_filename, "wb") as f:
|
||||
s3.download_fileobj(bucket_name, object_key, f)
|
||||
with open(upload_filename, "wb") as f:
|
||||
s3.download_fileobj(bucket_name, object_key, f)
|
||||
|
||||
container = av.open(upload_filename.as_posix())
|
||||
try:
|
||||
if not len(container.streams.audio):
|
||||
raise Exception("File has no audio stream")
|
||||
except Exception:
|
||||
upload_filename.unlink()
|
||||
raise
|
||||
finally:
|
||||
container.close()
|
||||
container = av.open(upload_filename.as_posix())
|
||||
try:
|
||||
if not len(container.streams.audio):
|
||||
raise Exception("File has no audio stream")
|
||||
except Exception:
|
||||
upload_filename.unlink()
|
||||
raise
|
||||
finally:
|
||||
container.close()
|
||||
|
||||
await transcripts_controller.update(transcript, {"status": "uploaded"})
|
||||
await transcripts_controller.update(
|
||||
session, transcript, {"status": "uploaded"}
|
||||
)
|
||||
|
||||
task_pipeline_file_process.delay(transcript_id=transcript.id)
|
||||
|
||||
@@ -165,7 +178,10 @@ async def process_meetings():
|
||||
process the same meeting simultaneously.
|
||||
"""
|
||||
logger.info("Processing meetings")
|
||||
meetings = await meetings_controller.get_all_active()
|
||||
session_factory = get_session_factory()
|
||||
async with session_factory() as session:
|
||||
async with session.begin():
|
||||
meetings = await meetings_controller.get_all_active(session)
|
||||
current_time = datetime.now(timezone.utc)
|
||||
redis_client = get_redis_client()
|
||||
processed_count = 0
|
||||
@@ -218,7 +234,9 @@ async def process_meetings():
|
||||
logger_.debug("Meeting not yet started, keep it")
|
||||
|
||||
if should_deactivate:
|
||||
await meetings_controller.update_meeting(meeting.id, is_active=False)
|
||||
await meetings_controller.update_meeting(
|
||||
session, meeting.id, is_active=False
|
||||
)
|
||||
logger_.info("Meeting is deactivated")
|
||||
|
||||
processed_count += 1
|
||||
@@ -260,40 +278,44 @@ async def reprocess_failed_recordings():
|
||||
bucket_name = settings.RECORDING_STORAGE_AWS_BUCKET_NAME
|
||||
pages = paginator.paginate(Bucket=bucket_name)
|
||||
|
||||
for page in pages:
|
||||
if "Contents" not in page:
|
||||
continue
|
||||
|
||||
for obj in page["Contents"]:
|
||||
object_key = obj["Key"]
|
||||
|
||||
if not (object_key.endswith(".mp4")):
|
||||
session_factory = get_session_factory()
|
||||
async with session_factory() as session:
|
||||
for page in pages:
|
||||
if "Contents" not in page:
|
||||
continue
|
||||
|
||||
recording = await recordings_controller.get_by_object_key(
|
||||
bucket_name, object_key
|
||||
)
|
||||
if not recording:
|
||||
logger.info(f"Queueing recording for processing: {object_key}")
|
||||
process_recording.delay(bucket_name, object_key)
|
||||
reprocessed_count += 1
|
||||
continue
|
||||
for obj in page["Contents"]:
|
||||
object_key = obj["Key"]
|
||||
|
||||
transcript = None
|
||||
try:
|
||||
transcript = await transcripts_controller.get_by_recording_id(
|
||||
recording.id
|
||||
)
|
||||
except ValidationError:
|
||||
await transcripts_controller.remove_by_recording_id(recording.id)
|
||||
logger.warning(
|
||||
f"Removed invalid transcript for recording: {recording.id}"
|
||||
)
|
||||
if not (object_key.endswith(".mp4")):
|
||||
continue
|
||||
|
||||
if transcript is None or transcript.status == "error":
|
||||
logger.info(f"Queueing recording for processing: {object_key}")
|
||||
process_recording.delay(bucket_name, object_key)
|
||||
reprocessed_count += 1
|
||||
recording = await recordings_controller.get_by_object_key(
|
||||
session, bucket_name, object_key
|
||||
)
|
||||
if not recording:
|
||||
logger.info(f"Queueing recording for processing: {object_key}")
|
||||
process_recording.delay(bucket_name, object_key)
|
||||
reprocessed_count += 1
|
||||
continue
|
||||
|
||||
transcript = None
|
||||
try:
|
||||
transcript = await transcripts_controller.get_by_recording_id(
|
||||
session, recording.id
|
||||
)
|
||||
except ValidationError:
|
||||
await transcripts_controller.remove_by_recording_id(
|
||||
session, recording.id
|
||||
)
|
||||
logger.warning(
|
||||
f"Removed invalid transcript for recording: {recording.id}"
|
||||
)
|
||||
|
||||
if transcript is None or transcript.status == "error":
|
||||
logger.info(f"Queueing recording for processing: {object_key}")
|
||||
process_recording.delay(bucket_name, object_key)
|
||||
reprocessed_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking S3 bucket: {str(e)}")
|
||||
|
||||
@@ -15,12 +15,12 @@ from reflector.worker.cleanup import cleanup_old_public_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_old_public_data_skips_when_not_public():
|
||||
async def test_cleanup_old_public_data_skips_when_not_public(session):
|
||||
"""Test that cleanup is skipped when PUBLIC_MODE is False."""
|
||||
with patch("reflector.worker.cleanup.settings") as mock_settings:
|
||||
mock_settings.PUBLIC_MODE = False
|
||||
|
||||
result = await cleanup_old_public_data()
|
||||
result = await cleanup_old_public_data(session)
|
||||
|
||||
# Should return early without doing anything
|
||||
assert result is None
|
||||
@@ -81,7 +81,7 @@ async def test_cleanup_old_public_data_deletes_old_anonymous_transcripts(session
|
||||
mock_delete.return_value = None
|
||||
|
||||
# Run cleanup with test session
|
||||
await cleanup_old_public_data(session=session)
|
||||
await cleanup_old_public_data(session)
|
||||
|
||||
# Verify only old anonymous transcript was deleted
|
||||
assert mock_delete.call_count == 1
|
||||
@@ -162,7 +162,7 @@ async def test_cleanup_deletes_associated_meeting_and_recording(session):
|
||||
mock_storage.return_value.delete_file = AsyncMock()
|
||||
|
||||
# Run cleanup with test session
|
||||
await cleanup_old_public_data(session=session)
|
||||
await cleanup_old_public_data(session)
|
||||
|
||||
# Verify transcript was deleted
|
||||
result = await session.execute(
|
||||
@@ -226,7 +226,7 @@ async def test_cleanup_handles_errors_gracefully(session):
|
||||
mock_delete.side_effect = [Exception("Delete failed"), None]
|
||||
|
||||
# Run cleanup with test session - should not raise exception
|
||||
await cleanup_old_public_data(session=session)
|
||||
await cleanup_old_public_data(session)
|
||||
|
||||
# Both transcripts should have been attempted to delete
|
||||
assert mock_delete.call_count == 2
|
||||
|
||||
@@ -624,7 +624,10 @@ async def test_pipeline_file_process_no_transcript():
|
||||
|
||||
# Should raise an exception for missing transcript when get_transcript is called
|
||||
with pytest.raises(Exception, match="Transcript not found"):
|
||||
await pipeline.get_transcript()
|
||||
from reflector.db import get_session_factory
|
||||
|
||||
async with get_session_factory()() as session:
|
||||
await pipeline.get_transcript(session)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user