Files
reflector/server/reflector/worker/process.py
Mathieu Virbel 8ad1270229 feat: add @with_session decorator for worker task session management
- Create session_decorator.py with @with_session decorator
- Decorator automatically manages database sessions for worker tasks
- Ensures session stays open for entire task execution
- Fixes an issue where sessions were closed before they were used (e.g., process_meetings)

Applied decorator to all worker tasks:
- process.py: process_recording, process_meetings, reprocess_failed_recordings
- cleanup.py: cleanup_old_public_data_task
- ics_sync.py: sync_room_ics, sync_all_ics_calendars, create_upcoming_meetings

Benefits:
- Consistent session management across all worker tasks
- No more manual session_factory context management in tasks
- Proper transaction boundaries with automatic begin/commit
- Cleaner, more maintainable code
- Fixes session lifecycle issues in process_meetings
2025-09-23 08:55:26 -06:00

321 lines
11 KiB
Python

import json
import os
from datetime import datetime, timezone
from urllib.parse import unquote, unquote_plus

import av
import boto3
import structlog
from celery import shared_task
from celery.utils.log import get_task_logger
from pydantic import ValidationError
from redis.exceptions import LockError
from sqlalchemy.ext.asyncio import AsyncSession

from reflector.db.meetings import meetings_controller
from reflector.db.recordings import Recording, recordings_controller
from reflector.db.rooms import rooms_controller
from reflector.db.transcripts import SourceKind, transcripts_controller
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
from reflector.pipelines.main_live_pipeline import asynctask
from reflector.redis_cache import get_redis_client
from reflector.settings import settings
from reflector.whereby import get_room_sessions
from reflector.worker.session_decorator import with_session
# Celery task logger wrapped with structlog, so calls in this module can pass
# structured key/value context (e.g. `logger.warning("...", room_name=...)`)
# and derive per-item loggers with `.bind(...)`.
logger = structlog.wrap_logger(get_task_logger(__name__))
def parse_datetime_with_timezone(iso_string: str) -> datetime:
    """Return the datetime encoded in *iso_string*, always timezone-aware.

    A string that carries an explicit UTC offset keeps it unchanged; a naive
    string is interpreted as UTC.
    """
    parsed = datetime.fromisoformat(iso_string)
    if parsed.tzinfo is not None:
        return parsed
    return parsed.replace(tzinfo=timezone.utc)
@shared_task
def process_messages():
    """Poll the recording SQS queue once and dispatch newly created S3 objects.

    Receives at most one message from ``settings.AWS_PROCESS_RECORDING_QUEUE_URL``.
    Every ``ObjectCreated*`` record in the message body is handed to
    ``process_recording`` as ``(bucket, key)``. The message is then deleted
    from the queue. When no queue URL is configured, only a warning is
    logged. Any error is logged (with traceback) and not re-raised.
    """
    queue_url = settings.AWS_PROCESS_RECORDING_QUEUE_URL
    if not queue_url:
        logger.warning("No process recording queue url")
        return

    try:
        logger.info("Receiving messages from: %s", queue_url)
        sqs = boto3.client(
            "sqs",
            region_name=settings.TRANSCRIPT_STORAGE_AWS_REGION,
            aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
            aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
        )

        # NOTE(review): VisibilityTimeout=0 makes the message visible to other
        # consumers again immediately, so concurrent pollers may both receive
        # it before it is deleted below - confirm duplicates are acceptable.
        response = sqs.receive_message(
            QueueUrl=queue_url,
            AttributeNames=["SentTimestamp"],
            MaxNumberOfMessages=1,
            MessageAttributeNames=["All"],
            VisibilityTimeout=0,
            WaitTimeSeconds=0,
        )

        for message in response.get("Messages", []):
            receipt_handle = message["ReceiptHandle"]
            body = json.loads(message["Body"])

            for record in body.get("Records", []):
                if record["eventName"].startswith("ObjectCreated"):
                    bucket = record["s3"]["bucket"]["name"]
                    # S3 event notifications form-encode object keys (a space
                    # arrives as "+"); plain `unquote` would keep the "+" and
                    # point at a key that does not exist.
                    key = unquote_plus(record["s3"]["object"]["key"])
                    process_recording.delay(bucket, key)

            sqs.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_handle)
            logger.info("Processed and deleted message: %s", message)

    except Exception as e:
        # Best-effort polling task, but keep the traceback for debugging.
        logger.error("process_messages", error=str(e), exc_info=True)
@shared_task
@asynctask
@with_session
async def process_recording(session: AsyncSession, bucket_name: str, object_key: str):
    """Register an S3 recording, download it and queue it for transcription.

    The object key is parsed positionally:
    - ``object_key[:36]`` is the room identifier; the meeting is looked up by
      ``"/" + object_key[:36]``.
    - ``object_key[37:57]`` is the ISO recording timestamp.

    If no meeting matches, the recording is skipped with a warning.

    A ``Recording`` row is created if missing. An existing transcript has its
    topics cleared; otherwise a new public transcript is created for the
    meeting's room. The object is then downloaded to ``upload<ext>`` under the
    transcript's data path and checked for an audio stream. Finally the
    transcript is marked ``"uploaded"`` and the file pipeline task is queued.

    Args:
        session: Database session injected by ``@with_session``.
        bucket_name: S3 bucket holding the recording.
        object_key: S3 key of the recording.

    Raises:
        Exception: ``"File has no audio stream"`` when the download contains
            no audio. The local copy is removed before this (or any error
            from opening/probing the file) propagates.
    """
    logger.info("Processing recording: %s/%s", bucket_name, object_key)

    # extract a guid and a datetime from the object key
    room_name = f"/{object_key[:36]}"
    recorded_at = parse_datetime_with_timezone(object_key[37:57])

    meeting = await meetings_controller.get_by_room_name(session, room_name)
    if not meeting:
        logger.warning("Room not found, may be deleted ?", room_name=room_name)
        return

    room = await rooms_controller.get_by_id(session, meeting.room_id)

    recording = await recordings_controller.get_by_object_key(
        session, bucket_name, object_key
    )
    if not recording:
        recording = await recordings_controller.create(
            session,
            Recording(
                bucket_name=bucket_name,
                object_key=object_key,
                recorded_at=recorded_at,
                meeting_id=meeting.id,
            ),
        )

    transcript = await transcripts_controller.get_by_recording_id(session, recording.id)
    if transcript:
        # Reprocessing an existing transcript: drop topics from the previous run.
        await transcripts_controller.update(
            session,
            transcript,
            {
                "topics": [],
            },
        )
    else:
        if not room:
            # A new transcript needs the room's owner and id; without the room
            # the `room.user_id` access below would fail with AttributeError.
            logger.warning(
                "Room for meeting not found, may be deleted ?",
                room_name=room_name,
                room_id=meeting.room_id,
            )
            return
        transcript = await transcripts_controller.add(
            session,
            "",
            source_kind=SourceKind.ROOM,
            source_language="en",
            target_language="en",
            user_id=room.user_id,
            recording_id=recording.id,
            share_mode="public",
            meeting_id=meeting.id,
            room_id=room.id,
        )

    _, extension = os.path.splitext(object_key)
    upload_filename = transcript.data_path / f"upload{extension}"
    upload_filename.parent.mkdir(parents=True, exist_ok=True)

    s3 = boto3.client(
        "s3",
        region_name=settings.TRANSCRIPT_STORAGE_AWS_REGION,
        aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
        aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
    )

    with open(upload_filename, "wb") as f:
        s3.download_fileobj(bucket_name, object_key, f)

    # Opening the container is inside the `try` so an unreadable download is
    # cleaned up too; the `with` closes the container before the unlink.
    try:
        with av.open(upload_filename.as_posix()) as container:
            if not container.streams.audio:
                raise Exception("File has no audio stream")
    except Exception:
        upload_filename.unlink(missing_ok=True)
        raise

    await transcripts_controller.update(session, transcript, {"status": "uploaded"})

    # NOTE(review): this dispatch happens before @with_session commits the
    # session (assuming it commits on return). The pipeline worker may not yet
    # see the transcript/status written above - confirm, or dispatch after commit.
    task_pipeline_file_process.delay(transcript_id=transcript.id)
@shared_task
@asynctask
@with_session
async def process_meetings(session: AsyncSession):
    """
    Check active meetings and deactivate those that have ended.

    For each active meeting, the Whereby room sessions are fetched and:
    - at least one session still running (``endedAt`` is None): keep it active;
    - sessions exist but all have ended: deactivate (everyone left);
    - no sessions at all: keep it active until its ``end_date`` (a naive date
      is treated as UTC), then deactivate.

    Calendar and on-the-fly meetings go through the same rules; there is no
    per-kind branch here.

    Each meeting is guarded by a non-blocking Redis lock
    (``meeting_process_lock:<id>``, timeout=120, extended after the slow
    Whereby call), so concurrent workers do not process the same meeting.
    Meetings whose lock is held elsewhere, or lost mid-way, are counted as
    skipped. Per-meeting errors are logged and do not stop the loop.
    """
    logger.info("Processing meetings")
    meetings = await meetings_controller.get_all_active(session)
    current_time = datetime.now(timezone.utc)
    redis_client = get_redis_client()
    processed_count = 0
    skipped_count = 0

    for meeting in meetings:
        logger_ = logger.bind(meeting_id=meeting.id, room_name=meeting.room_name)
        lock_key = f"meeting_process_lock:{meeting.id}"
        lock = redis_client.lock(lock_key, timeout=120)
        acquired = False

        try:
            acquired = lock.acquire(blocking=False)
            if not acquired:
                logger_.debug("Meeting is being processed by another worker, skipping")
                skipped_count += 1
                continue

            should_deactivate = False
            end_date = meeting.end_date
            if end_date.tzinfo is None:
                end_date = end_date.replace(tzinfo=timezone.utc)

            # This API call could be slow, extend lock if needed
            response = await get_room_sessions(meeting.room_name)

            try:
                # Extend lock after slow operation to ensure we still hold it
                lock.extend(120, replace_ttl=True)
            except LockError:
                logger_.warning("Lost lock for meeting, skipping")
                skipped_count += 1
                continue

            # `or []` also covers an explicit null "results" value.
            room_sessions = response.get("results") or []
            # any() over an empty list is False, so no separate emptiness check.
            has_active_sessions = any(rs["endedAt"] is None for rs in room_sessions)
            has_had_sessions = bool(room_sessions)

            if has_active_sessions:
                logger_.debug("Meeting still has active sessions, keep it")
            elif has_had_sessions:
                should_deactivate = True
                logger_.info("Meeting ended - all participants left")
            elif current_time > end_date:
                should_deactivate = True
                logger_.info(
                    "Meeting deactivated - scheduled time ended with no participants",
                )
            else:
                logger_.debug("Meeting not yet started, keep it")

            if should_deactivate:
                # NOTE(review): the session is shared by the whole loop; if this
                # update fails, later updates may hit a failed transaction.
                # Consider a savepoint per meeting - confirm with the decorator.
                await meetings_controller.update_meeting(
                    session, meeting.id, is_active=False
                )
                logger_.info("Meeting is deactivated")

            processed_count += 1

        except Exception:
            logger_.error("Error processing meeting", exc_info=True)
        finally:
            # Only release a lock this worker actually acquired.
            if acquired:
                try:
                    lock.release()
                except LockError:
                    pass  # Lock already released or expired

    logger.info(
        "Processed meetings finished",
        processed_count=processed_count,
        skipped_count=skipped_count,
    )
@shared_task
@asynctask
@with_session
async def reprocess_failed_recordings(session: AsyncSession):
    """
    Find recordings in the S3 bucket and check if they have proper transcriptions.
    If not, requeue them for processing.

    Scans every ``.mp4`` object in ``settings.RECORDING_STORAGE_AWS_BUCKET_NAME``.
    ``process_recording`` is queued when the object has no ``Recording`` row,
    no transcript, or a transcript whose status is ``"error"``. A transcript
    that fails validation is removed, and its recording is requeued.

    Returns:
        The number of recordings queued. If no bucket is configured, 0 is
        returned after a warning. An error during the scan is logged (with
        traceback) and ends the scan; the count so far is returned.
    """
    logger.info("Checking for recordings that need processing or reprocessing")

    bucket_name = settings.RECORDING_STORAGE_AWS_BUCKET_NAME
    if not bucket_name:
        # Same guard style as process_messages for a missing queue URL.
        logger.warning("No recording storage bucket configured")
        return 0

    s3 = boto3.client(
        "s3",
        region_name=settings.TRANSCRIPT_STORAGE_AWS_REGION,
        aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
        aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
    )

    reprocessed_count = 0
    try:
        paginator = s3.get_paginator("list_objects_v2")
        pages = paginator.paginate(Bucket=bucket_name)

        for page in pages:
            # Empty pages carry no "Contents" key.
            for obj in page.get("Contents", []):
                object_key = obj["Key"]
                if not object_key.endswith(".mp4"):
                    continue

                recording = await recordings_controller.get_by_object_key(
                    session, bucket_name, object_key
                )
                if not recording:
                    logger.info(f"Queueing recording for processing: {object_key}")
                    process_recording.delay(bucket_name, object_key)
                    reprocessed_count += 1
                    continue

                transcript = None
                try:
                    transcript = await transcripts_controller.get_by_recording_id(
                        session, recording.id
                    )
                except ValidationError:
                    # NOTE(review): this removal is only committed when
                    # @with_session finishes, i.e. after process_recording may
                    # already have been queued below - confirm the ordering.
                    await transcripts_controller.remove_by_recording_id(
                        session, recording.id
                    )
                    logger.warning(
                        f"Removed invalid transcript for recording: {recording.id}"
                    )

                if transcript is None or transcript.status == "error":
                    logger.info(f"Queueing recording for processing: {object_key}")
                    process_recording.delay(bucket_name, object_key)
                    reprocessed_count += 1

    except Exception as e:
        logger.error(f"Error checking S3 bucket: {str(e)}", exc_info=True)

    logger.info(f"Reprocessing complete. Requeued {reprocessed_count} recordings")
    return reprocessed_count