Mirror of https://github.com/Monadical-SAS/reflector.git (synced 2025-12-21 12:49:06 +00:00)
refactor: improve session management across worker tasks and pipelines
- Remove "if session" anti-pattern from all functions - Functions now require explicit AsyncSession parameters instead of optional session_factory - Worker tasks (Celery) create sessions at top level using session_factory - Add proper AsyncSession type annotations to all session parameters - Update cleanup.py: delete_single_transcript, cleanup_old_transcripts, cleanup_old_public_data - Update process.py: process_recording, process_meetings, reprocess_failed_recordings - Update ics_sync.py: sync_room_ics, sync_all_ics_calendars, create_upcoming_meetings - Update pipeline classes: get_transcript methods now require session - Fix tests to pass sessions correctly Benefits: - Better type safety and IDE support with explicit AsyncSession typing - Clear transaction boundaries with sessions created at task level - Consistent session management pattern across codebase - No ambiguity about session vs session_factory usage
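In short: the worker task owns the AsyncSession and the transaction boundary, and every helper receives that session explicitly. A minimal runnable sketch of the pattern (the engine URL, task, and helper names here are illustrative, not Reflector's actual API; the in-memory engine assumes aiosqlite is installed):

import asyncio

from sqlalchemy import text
from sqlalchemy.ext.asyncio import (
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)

engine = create_async_engine("sqlite+aiosqlite:///:memory:")
session_factory = async_sessionmaker(engine, expire_on_commit=False)


async def do_cleanup(session: AsyncSession) -> None:
    # Helpers require an explicit AsyncSession: no optional session
    # parameter and no "if session: ... else: make one" fallback branch.
    await session.execute(text("SELECT 1"))


async def cleanup_task() -> None:
    # The (hypothetical) worker task creates the session at the top level
    # and owns the transaction boundary via session.begin().
    async with session_factory() as session:
        async with session.begin():
            await do_cleanup(session)


if __name__ == "__main__":
    asyncio.run(cleanup_task())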
@@ -87,15 +87,9 @@ class PipelineMainFile(PipelineMainBase):
         self.logger = logger.bind(transcript_id=self.transcript_id)
         self.empty_pipeline = EmptyPipeline(logger=self.logger)

-    async def get_transcript(self, session: AsyncSession = None) -> Transcript:
+    async def get_transcript(self, session: AsyncSession) -> Transcript:
         """Get transcript with session"""
-        if session:
-            result = await transcripts_controller.get_by_id(session, self.transcript_id)
-        else:
-            async with get_session_factory()() as session:
-                result = await transcripts_controller.get_by_id(
-                    session, self.transcript_id
-                )
+        result = await transcripts_controller.get_by_id(session, self.transcript_id)
         if not result:
             raise Exception("Transcript not found")
         return result
@@ -142,15 +142,9 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
             self._ws_manager = get_ws_manager()
         return self._ws_manager

-    async def get_transcript(self, session: AsyncSession = None) -> Transcript:
+    async def get_transcript(self, session: AsyncSession) -> Transcript:
         # fetch the transcript
-        if session:
-            result = await transcripts_controller.get_by_id(session, self.transcript_id)
-        else:
-            async with get_session_factory()() as session:
-                result = await transcripts_controller.get_by_id(
-                    session, self.transcript_id
-                )
+        result = await transcripts_controller.get_by_id(session, self.transcript_id)
         if not result:
             raise Exception("Transcript not found")
         return result
@@ -349,7 +343,8 @@ class PipelineMainLive(PipelineMainBase):
     async def create(self) -> Pipeline:
         # create a context for the whole rtc transaction
         # add a customised logger to the context
-        transcript = await self.get_transcript()
+        async with get_session_factory()() as session:
+            transcript = await self.get_transcript(session)

         processors = [
             AudioFileWriterProcessor(
@@ -397,7 +392,8 @@ class PipelineMainDiarization(PipelineMainBase[AudioDiarizationInput]):
         # now let's start the pipeline by pushing information to the
         # first processor diarization processor
         # XXX translation is lost when converting our data model to the processor model
-        transcript = await self.get_transcript()
+        async with get_session_factory()() as session:
+            transcript = await self.get_transcript(session)

         # diarization works only if the file is uploaded to an external storage
         if transcript.audio_location == "local":
@@ -430,7 +426,8 @@ class PipelineMainFromTopics(PipelineMainBase[TitleSummaryWithIdProcessorType]):

     async def create(self) -> Pipeline:
         # get transcript
-        self._transcript = transcript = await self.get_transcript()
+        async with get_session_factory()() as session:
+            self._transcript = transcript = await self.get_transcript(session)

         # create pipeline
         processors = self.get_processors()
@@ -12,6 +12,7 @@ import structlog
 from celery import shared_task
 from pydantic.types import PositiveInt
 from sqlalchemy import delete, select
+from sqlalchemy.ext.asyncio import AsyncSession

 from reflector.asynctask import asynctask
 from reflector.db import get_session_factory
@@ -33,15 +34,13 @@ class CleanupStats(TypedDict):


 async def delete_single_transcript(
-    session_factory, transcript_data: dict, stats: CleanupStats, session=None
+    session: AsyncSession, transcript_data: dict, stats: CleanupStats
 ):
     transcript_id = transcript_data["id"]
     meeting_id = transcript_data["meeting_id"]
     recording_id = transcript_data["recording_id"]

     try:
-        if session:
-            # Use provided session for testing - don't start new transaction
         if meeting_id:
             await session.execute(
                 delete(MeetingModel).where(MeetingModel.id == meeting_id)
@@ -56,9 +55,7 @@ async def delete_single_transcript(
             recording = result.mappings().first()
             if recording:
                 try:
-                    await get_recordings_storage().delete_file(
-                        recording["object_key"]
-                    )
+                    await get_recordings_storage().delete_file(recording["object_key"])
                 except Exception as storage_error:
                     logger.warning(
                         "Failed to delete recording from storage",
@@ -71,58 +68,7 @@ async def delete_single_transcript(
                 delete(RecordingModel).where(RecordingModel.id == recording_id)
             )
             stats["recordings_deleted"] += 1
-            logger.info(
-                "Deleted associated recording", recording_id=recording_id
-            )
-
-            await transcripts_controller.remove_by_id(session, transcript_id)
-            stats["transcripts_deleted"] += 1
-            logger.info(
-                "Deleted transcript",
-                transcript_id=transcript_id,
-                created_at=transcript_data["created_at"].isoformat(),
-            )
-        else:
-            # Use session factory for production
-            async with session_factory() as session:
-                async with session.begin():
-                    if meeting_id:
-                        await session.execute(
-                            delete(MeetingModel).where(MeetingModel.id == meeting_id)
-                        )
-                        stats["meetings_deleted"] += 1
-                        logger.info("Deleted associated meeting", meeting_id=meeting_id)
-
-                    if recording_id:
-                        result = await session.execute(
-                            select(RecordingModel).where(
-                                RecordingModel.id == recording_id
-                            )
-                        )
-                        recording = result.mappings().first()
-                        if recording:
-                            try:
-                                await get_recordings_storage().delete_file(
-                                    recording["object_key"]
-                                )
-                            except Exception as storage_error:
-                                logger.warning(
-                                    "Failed to delete recording from storage",
-                                    recording_id=recording_id,
-                                    object_key=recording["object_key"],
-                                    error=str(storage_error),
-                                )
-
-                        await session.execute(
-                            delete(RecordingModel).where(
-                                RecordingModel.id == recording_id
-                            )
-                        )
-                        stats["recordings_deleted"] += 1
-                        logger.info(
-                            "Deleted associated recording",
-                            recording_id=recording_id,
-                        )
+            logger.info("Deleted associated recording", recording_id=recording_id)

         await transcripts_controller.remove_by_id(session, transcript_id)
         stats["transcripts_deleted"] += 1
@@ -138,7 +84,7 @@ async def delete_single_transcript(


 async def cleanup_old_transcripts(
-    session_factory, cutoff_date: datetime, stats: CleanupStats, session=None
+    session: AsyncSession, cutoff_date: datetime, stats: CleanupStats
 ):
     """Delete old anonymous transcripts and their associated recordings/meetings."""
     query = select(
@@ -150,13 +96,6 @@ async def cleanup_old_transcripts(
         (TranscriptModel.created_at < cutoff_date) & (TranscriptModel.user_id.is_(None))
     )

-    if session:
-        # Use provided session for testing
-        result = await session.execute(query)
-        old_transcripts = result.mappings().all()
-    else:
-        # Use session factory for production
-        async with session_factory() as session:
     result = await session.execute(query)
     old_transcripts = result.mappings().all()

@@ -164,9 +103,7 @@ async def cleanup_old_transcripts(

     for transcript_data in old_transcripts:
         try:
-            await delete_single_transcript(
-                session_factory, transcript_data, stats, session
-            )
+            await delete_single_transcript(session, transcript_data, stats)
         except Exception as e:
             error_msg = f"Failed to delete transcript {transcript_data['id']}: {str(e)}"
             logger.error(error_msg, exc_info=e)
@@ -190,8 +127,8 @@ def log_cleanup_results(stats: CleanupStats):


 async def cleanup_old_public_data(
+    session: AsyncSession,
     days: PositiveInt | None = None,
-    session=None,
 ) -> CleanupStats | None:
     if days is None:
         days = settings.PUBLIC_DATA_RETENTION_DAYS
@@ -213,8 +150,7 @@ async def cleanup_old_public_data(
         "errors": [],
     }

-    session_factory = get_session_factory()
-    await cleanup_old_transcripts(session_factory, cutoff_date, stats, session)
+    await cleanup_old_transcripts(session, cutoff_date, stats)

     log_cleanup_results(stats)
     return stats
@@ -226,4 +162,7 @@ async def cleanup_old_public_data(
 )
 @asynctask
 async def cleanup_old_public_data_task(days: int | None = None):
-    await cleanup_old_public_data(days=days)
+    session_factory = get_session_factory()
+    async with session_factory() as session:
+        async with session.begin():
+            await cleanup_old_public_data(session, days=days)
@@ -3,8 +3,10 @@ from datetime import datetime, timedelta, timezone
 import structlog
 from celery import shared_task
 from celery.utils.log import get_task_logger
+from sqlalchemy.ext.asyncio import AsyncSession

 from reflector.asynctask import asynctask
+from reflector.db import get_session_factory
 from reflector.db.calendar_events import calendar_events_controller
 from reflector.db.meetings import meetings_controller
 from reflector.db.rooms import rooms_controller
@@ -19,7 +21,9 @@ logger = structlog.wrap_logger(get_task_logger(__name__))
 @asynctask
 async def sync_room_ics(room_id: str):
     try:
-        room = await rooms_controller.get_by_id(room_id)
+        session_factory = get_session_factory()
+        async with session_factory() as session:
+            room = await rooms_controller.get_by_id(session, room_id)
             if not room:
                 logger.warning("Room not found for ICS sync", room_id=room_id)
                 return
@@ -59,7 +63,9 @@ async def sync_all_ics_calendars():
     try:
         logger.info("Starting sync for all ICS-enabled rooms")

-        ics_enabled_rooms = await rooms_controller.get_ics_enabled()
+        session_factory = get_session_factory()
+        async with session_factory() as session:
+            ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
             logger.info(f"Found {len(ics_enabled_rooms)} rooms with ICS enabled")

             for room in ics_enabled_rooms:
@@ -86,10 +92,14 @@ def _should_sync(room) -> bool:
 MEETING_DEFAULT_DURATION = timedelta(hours=1)


-async def create_upcoming_meetings_for_event(event, create_window, room_id, room):
+async def create_upcoming_meetings_for_event(
+    session: AsyncSession, event, create_window, room_id, room
+):
     if event.start_time <= create_window:
         return
-    existing_meeting = await meetings_controller.get_by_calendar_event(event.id)
+    existing_meeting = await meetings_controller.get_by_calendar_event(
+        session, event.id
+    )

     if existing_meeting:
         return
@@ -112,6 +122,7 @@ async def create_upcoming_meetings_for_event(event, create_window, room_id, room
     await upload_logo(whereby_meeting["roomName"], "./images/logo.png")

     meeting = await meetings_controller.create(
+        session,
         id=whereby_meeting["meetingId"],
         room_name=whereby_meeting["roomName"],
         room_url=whereby_meeting["roomUrl"],
@@ -155,19 +166,23 @@ async def create_upcoming_meetings():
     try:
         logger.info("Starting creation of upcoming meetings")

-        ics_enabled_rooms = await rooms_controller.get_ics_enabled()
+        session_factory = get_session_factory()
+        async with session_factory() as session:
+            async with session.begin():
+                ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
                 now = datetime.now(timezone.utc)
                 create_window = now - timedelta(minutes=6)

                 for room in ics_enabled_rooms:
                     events = await calendar_events_controller.get_upcoming(
+                        session,
                         room.id,
                         minutes_ahead=7,
                     )

                     for event in events:
                         await create_upcoming_meetings_for_event(
-                            event, create_window, room.id, room
+                            session, event, create_window, room.id, room
                         )
         logger.info("Completed pre-creation check for upcoming meetings")

@@ -11,6 +11,7 @@ from celery.utils.log import get_task_logger
 from pydantic import ValidationError
 from redis.exceptions import LockError

+from reflector.db import get_session_factory
 from reflector.db.meetings import meetings_controller
 from reflector.db.recordings import Recording, recordings_controller
 from reflector.db.rooms import rooms_controller
@@ -82,23 +83,32 @@ async def process_recording(bucket_name: str, object_key: str):
     room_name = f"/{object_key[:36]}"
     recorded_at = parse_datetime_with_timezone(object_key[37:57])

-    meeting = await meetings_controller.get_by_room_name(room_name)
-    room = await rooms_controller.get_by_id(meeting.room_id)
+    session_factory = get_session_factory()
+    async with session_factory() as session:
+        async with session.begin():
+            meeting = await meetings_controller.get_by_room_name(session, room_name)
+            room = await rooms_controller.get_by_id(session, meeting.room_id)

-    recording = await recordings_controller.get_by_object_key(bucket_name, object_key)
+            recording = await recordings_controller.get_by_object_key(
+                session, bucket_name, object_key
+            )
             if not recording:
                 recording = await recordings_controller.create(
+                    session,
                     Recording(
                         bucket_name=bucket_name,
                         object_key=object_key,
                         recorded_at=recorded_at,
                         meeting_id=meeting.id,
-            )
+                    ),
                 )

-    transcript = await transcripts_controller.get_by_recording_id(recording.id)
+            transcript = await transcripts_controller.get_by_recording_id(
+                session, recording.id
+            )
             if transcript:
                 await transcripts_controller.update(
+                    session,
                     transcript,
                     {
                         "topics": [],
@@ -106,6 +116,7 @@ async def process_recording(bucket_name: str, object_key: str):
                 )
             else:
                 transcript = await transcripts_controller.add(
+                    session,
                     "",
                     source_kind=SourceKind.ROOM,
                     source_language="en",
@@ -141,7 +152,9 @@ async def process_recording(bucket_name: str, object_key: str):
             finally:
                 container.close()

-    await transcripts_controller.update(transcript, {"status": "uploaded"})
+            await transcripts_controller.update(
+                session, transcript, {"status": "uploaded"}
+            )

             task_pipeline_file_process.delay(transcript_id=transcript.id)

@@ -165,7 +178,10 @@ async def process_meetings():
     process the same meeting simultaneously.
     """
     logger.info("Processing meetings")
-    meetings = await meetings_controller.get_all_active()
+    session_factory = get_session_factory()
+    async with session_factory() as session:
+        async with session.begin():
+            meetings = await meetings_controller.get_all_active(session)
             current_time = datetime.now(timezone.utc)
             redis_client = get_redis_client()
             processed_count = 0
@@ -218,7 +234,9 @@ async def process_meetings():
                     logger_.debug("Meeting not yet started, keep it")

                 if should_deactivate:
-                    await meetings_controller.update_meeting(meeting.id, is_active=False)
+                    await meetings_controller.update_meeting(
+                        session, meeting.id, is_active=False
+                    )
                     logger_.info("Meeting is deactivated")

                 processed_count += 1
@@ -260,6 +278,8 @@ async def reprocess_failed_recordings():
     bucket_name = settings.RECORDING_STORAGE_AWS_BUCKET_NAME
     pages = paginator.paginate(Bucket=bucket_name)

+    session_factory = get_session_factory()
+    async with session_factory() as session:
         for page in pages:
             if "Contents" not in page:
                 continue
@@ -271,7 +291,7 @@ async def reprocess_failed_recordings():
                     continue

                 recording = await recordings_controller.get_by_object_key(
-                    bucket_name, object_key
+                    session, bucket_name, object_key
                 )
                 if not recording:
                     logger.info(f"Queueing recording for processing: {object_key}")
@@ -282,10 +302,12 @@ async def reprocess_failed_recordings():
                 transcript = None
                 try:
                     transcript = await transcripts_controller.get_by_recording_id(
-                        recording.id
+                        session, recording.id
                     )
                 except ValidationError:
-                    await transcripts_controller.remove_by_recording_id(recording.id)
+                    await transcripts_controller.remove_by_recording_id(
+                        session, recording.id
+                    )
                     logger.warning(
                         f"Removed invalid transcript for recording: {recording.id}"
                     )
@@ -15,12 +15,12 @@ from reflector.worker.cleanup import cleanup_old_public_data


 @pytest.mark.asyncio
-async def test_cleanup_old_public_data_skips_when_not_public():
+async def test_cleanup_old_public_data_skips_when_not_public(session):
     """Test that cleanup is skipped when PUBLIC_MODE is False."""
     with patch("reflector.worker.cleanup.settings") as mock_settings:
         mock_settings.PUBLIC_MODE = False

-        result = await cleanup_old_public_data()
+        result = await cleanup_old_public_data(session)

         # Should return early without doing anything
         assert result is None
@@ -81,7 +81,7 @@ async def test_cleanup_old_public_data_deletes_old_anonymous_transcripts(session
         mock_delete.return_value = None

         # Run cleanup with test session
-        await cleanup_old_public_data(session=session)
+        await cleanup_old_public_data(session)

         # Verify only old anonymous transcript was deleted
         assert mock_delete.call_count == 1
@@ -162,7 +162,7 @@ async def test_cleanup_deletes_associated_meeting_and_recording(session):
         mock_storage.return_value.delete_file = AsyncMock()

         # Run cleanup with test session
-        await cleanup_old_public_data(session=session)
+        await cleanup_old_public_data(session)

         # Verify transcript was deleted
         result = await session.execute(
@@ -226,7 +226,7 @@ async def test_cleanup_handles_errors_gracefully(session):
         mock_delete.side_effect = [Exception("Delete failed"), None]

         # Run cleanup with test session - should not raise exception
-        await cleanup_old_public_data(session=session)
+        await cleanup_old_public_data(session)

         # Both transcripts should have been attempted to delete
         assert mock_delete.call_count == 2
@@ -624,7 +624,10 @@ async def test_pipeline_file_process_no_transcript():

     # Should raise an exception for missing transcript when get_transcript is called
     with pytest.raises(Exception, match="Transcript not found"):
-        await pipeline.get_transcript()
+        from reflector.db import get_session_factory
+
+        async with get_session_factory()() as session:
+            await pipeline.get_transcript(session)


 @pytest.mark.asyncio