feat: add @with_session decorator for worker task session management

- Create session_decorator.py with @with_session decorator
- Decorator automatically manages database sessions for worker tasks
- Ensures session stays open for entire task execution
- Fixes an issue where sessions were closed before being used (e.g., in process_meetings)

Applied decorator to all worker tasks:
- process.py: process_recording, process_meetings, reprocess_failed_recordings
- cleanup.py: cleanup_old_public_data_task
- ics_sync.py: sync_room_ics, sync_all_ics_calendars, create_upcoming_meetings

Benefits:
- Consistent session management across all worker tasks
- No more manual session_factory context management in tasks
- Proper transaction boundaries with automatic begin/commit
- Cleaner, more maintainable code
- Fixes session lifecycle issues in process_meetings (see the before/after sketch below)
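
To make the pattern concrete, here is a minimal before/after sketch of the migration (my_task, do_work, and item_id are hypothetical names, not identifiers from this commit; the imports are the ones the diffs below actually use):

from celery import shared_task
from sqlalchemy.ext.asyncio import AsyncSession

from reflector.asynctask import asynctask
from reflector.db import get_session_factory
from reflector.worker.session_decorator import with_session

# Before: every task hand-rolled its own session lifecycle
@shared_task
@asynctask
async def my_task(item_id: str):
    session_factory = get_session_factory()
    async with session_factory() as session:
        async with session.begin():
            await do_work(session, item_id)

# After: the decorator owns the session and the transaction boundary
@shared_task
@asynctask
@with_session
async def my_task(session: AsyncSession, item_id: str):
    await do_work(session, item_id)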
Commit: 8ad1270229
Parent: 617a1c8b32
Date:   2025-09-23 08:55:26 -06:00

4 changed files with 169 additions and 138 deletions

cleanup.py

@@ -15,11 +15,11 @@ from sqlalchemy import delete, select
 from sqlalchemy.ext.asyncio import AsyncSession
 from reflector.asynctask import asynctask
-from reflector.db import get_session_factory
 from reflector.db.base import MeetingModel, RecordingModel, TranscriptModel
 from reflector.db.transcripts import transcripts_controller
 from reflector.settings import settings
 from reflector.storage import get_recordings_storage
+from reflector.worker.session_decorator import with_session

 logger = structlog.get_logger(__name__)
@@ -161,8 +161,6 @@ async def cleanup_old_public_data(
     retry_kwargs={"max_retries": 3, "countdown": 300},
 )
 @asynctask
-async def cleanup_old_public_data_task(days: int | None = None):
-    session_factory = get_session_factory()
-    async with session_factory() as session:
-        async with session.begin():
-            await cleanup_old_public_data(session, days=days)
+@with_session
+async def cleanup_old_public_data_task(session: AsyncSession, days: int | None = None):
+    await cleanup_old_public_data(session, days=days)
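
The one-line body that remains makes the transaction boundary explicit. For reference, the commit/rollback behavior comes from SQLAlchemy itself, not from this commit; do_work below is a hypothetical placeholder:

async with session.begin():   # emits BEGIN
    await do_work(session)    # any exception raised here aborts the block
# normal exit: COMMIT; exception: ROLLBACK, then the error propagates to Celery,
# whose retry_kwargs (max_retries=3, countdown=300) decide whether to re-run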

ics_sync.py

@@ -6,24 +6,23 @@ from celery.utils.log import get_task_logger
 from sqlalchemy.ext.asyncio import AsyncSession
 from reflector.asynctask import asynctask
-from reflector.db import get_session_factory
 from reflector.db.calendar_events import calendar_events_controller
 from reflector.db.meetings import meetings_controller
 from reflector.db.rooms import rooms_controller
 from reflector.redis_cache import RedisAsyncLock
 from reflector.services.ics_sync import SyncStatus, ics_sync_service
 from reflector.whereby import create_meeting, upload_logo
+from reflector.worker.session_decorator import with_session

 logger = structlog.wrap_logger(get_task_logger(__name__))

 @shared_task
 @asynctask
-async def sync_room_ics(room_id: str):
+@with_session
+async def sync_room_ics(session: AsyncSession, room_id: str):
     try:
-        session_factory = get_session_factory()
-        async with session_factory() as session:
         room = await rooms_controller.get_by_id(session, room_id)
         if not room:
             logger.warning("Room not found for ICS sync", room_id=room_id)
             return
@@ -59,13 +58,12 @@ async def sync_room_ics(room_id: str):
 @shared_task
 @asynctask
-async def sync_all_ics_calendars():
+@with_session
+async def sync_all_ics_calendars(session: AsyncSession):
     try:
         logger.info("Starting sync for all ICS-enabled rooms")
-        session_factory = get_session_factory()
-        async with session_factory() as session:
         ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)

         logger.info(f"Found {len(ics_enabled_rooms)} rooms with ICS enabled")

         for room in ics_enabled_rooms:
@@ -155,7 +153,8 @@ async def create_upcoming_meetings_for_event(
 @shared_task
 @asynctask
-async def create_upcoming_meetings():
+@with_session
+async def create_upcoming_meetings(session: AsyncSession):
     async with RedisAsyncLock("create_upcoming_meetings", skip_if_locked=True) as lock:
         if not lock.acquired:
             logger.warning(
@@ -166,24 +165,21 @@ async def create_upcoming_meetings():
         try:
             logger.info("Starting creation of upcoming meetings")
-            session_factory = get_session_factory()
-            async with session_factory() as session:
-                async with session.begin():
             ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
             now = datetime.now(timezone.utc)
             create_window = now - timedelta(minutes=6)

             for room in ics_enabled_rooms:
                 events = await calendar_events_controller.get_upcoming(
                     session,
                     room.id,
                     minutes_ahead=7,
                 )
                 for event in events:
                     await create_upcoming_meetings_for_event(
                         session, event, create_window, room.id, room
                     )

             logger.info("Completed pre-creation check for upcoming meetings")
         except Exception as e:
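
One behavioral consequence to note in create_upcoming_meetings: the decorator now opens the session and transaction before the Redis lock is acquired, whereas the old code created the session inside the locked region. A sketch of the new nesting (do_precreation is a hypothetical stand-in for the room/event loop shown above):

@shared_task
@asynctask
@with_session                          # session + transaction open first
async def create_upcoming_meetings(session: AsyncSession):
    async with RedisAsyncLock("create_upcoming_meetings", skip_if_locked=True) as lock:
        if not lock.acquired:
            return                     # empty transaction simply commits
        await do_precreation(session)  # runs inside both lock and transaction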

process.py

@@ -10,8 +10,8 @@ from celery import shared_task
 from celery.utils.log import get_task_logger
 from pydantic import ValidationError
 from redis.exceptions import LockError
+from sqlalchemy.ext.asyncio import AsyncSession

-from reflector.db import get_session_factory
 from reflector.db.meetings import meetings_controller
 from reflector.db.recordings import Recording, recordings_controller
 from reflector.db.rooms import rooms_controller
@@ -21,6 +21,7 @@ from reflector.pipelines.main_live_pipeline import asynctask
 from reflector.redis_cache import get_redis_client
 from reflector.settings import settings
 from reflector.whereby import get_room_sessions
+from reflector.worker.session_decorator import with_session

 logger = structlog.wrap_logger(get_task_logger(__name__))
@@ -76,92 +77,91 @@ def process_messages():
 @shared_task
 @asynctask
-async def process_recording(bucket_name: str, object_key: str):
+@with_session
+async def process_recording(session: AsyncSession, bucket_name: str, object_key: str):
     logger.info("Processing recording: %s/%s", bucket_name, object_key)

     # extract a guid and a datetime from the object key
     room_name = f"/{object_key[:36]}"
     recorded_at = parse_datetime_with_timezone(object_key[37:57])

-    session_factory = get_session_factory()
-    async with session_factory() as session:
-        async with session.begin():
     meeting = await meetings_controller.get_by_room_name(session, room_name)
+    if not meeting:
+        logger.warning("Room not found, may be deleted ?", room_name=room_name)
+        return

     room = await rooms_controller.get_by_id(session, meeting.room_id)

     recording = await recordings_controller.get_by_object_key(
         session, bucket_name, object_key
     )
     if not recording:
         recording = await recordings_controller.create(
             session,
             Recording(
                 bucket_name=bucket_name,
                 object_key=object_key,
                 recorded_at=recorded_at,
                 meeting_id=meeting.id,
             ),
         )

-    transcript = await transcripts_controller.get_by_recording_id(
-        session, recording.id
-    )
+    transcript = await transcripts_controller.get_by_recording_id(session, recording.id)
     if transcript:
         await transcripts_controller.update(
             session,
             transcript,
             {
                 "topics": [],
             },
         )
     else:
         transcript = await transcripts_controller.add(
             session,
             "",
             source_kind=SourceKind.ROOM,
             source_language="en",
             target_language="en",
             user_id=room.user_id,
             recording_id=recording.id,
             share_mode="public",
             meeting_id=meeting.id,
             room_id=room.id,
         )

     _, extension = os.path.splitext(object_key)
     upload_filename = transcript.data_path / f"upload{extension}"
     upload_filename.parent.mkdir(parents=True, exist_ok=True)

     s3 = boto3.client(
         "s3",
         region_name=settings.TRANSCRIPT_STORAGE_AWS_REGION,
         aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
         aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
     )
     with open(upload_filename, "wb") as f:
         s3.download_fileobj(bucket_name, object_key, f)

     container = av.open(upload_filename.as_posix())
     try:
         if not len(container.streams.audio):
             raise Exception("File has no audio stream")
     except Exception:
         upload_filename.unlink()
         raise
     finally:
         container.close()

-    await transcripts_controller.update(
-        session, transcript, {"status": "uploaded"}
-    )
+    await transcripts_controller.update(session, transcript, {"status": "uploaded"})

     task_pipeline_file_process.delay(transcript_id=transcript.id)

 @shared_task
 @asynctask
-async def process_meetings():
+@with_session
+async def process_meetings(session: AsyncSession):
     """
     Checks which meetings are still active and deactivates those that have ended.
@@ -178,10 +178,7 @@ async def process_meetings():
     process the same meeting simultaneously.
     """
     logger.info("Processing meetings")
-    session_factory = get_session_factory()
-    async with session_factory() as session:
-        async with session.begin():
     meetings = await meetings_controller.get_all_active(session)

     current_time = datetime.now(timezone.utc)
     redis_client = get_redis_client()
     processed_count = 0
@@ -258,7 +255,8 @@ async def process_meetings():
 @shared_task
 @asynctask
-async def reprocess_failed_recordings():
+@with_session
+async def reprocess_failed_recordings(session: AsyncSession):
     """
     Find recordings in the S3 bucket and check if they have proper transcriptions.
     If not, requeue them for processing.
@@ -278,44 +276,42 @@ async def reprocess_failed_recordings():
         bucket_name = settings.RECORDING_STORAGE_AWS_BUCKET_NAME
         pages = paginator.paginate(Bucket=bucket_name)

-        session_factory = get_session_factory()
-        async with session_factory() as session:
         for page in pages:
             if "Contents" not in page:
                 continue

             for obj in page["Contents"]:
                 object_key = obj["Key"]

                 if not (object_key.endswith(".mp4")):
                     continue

                 recording = await recordings_controller.get_by_object_key(
                     session, bucket_name, object_key
                 )
                 if not recording:
                     logger.info(f"Queueing recording for processing: {object_key}")
                     process_recording.delay(bucket_name, object_key)
                     reprocessed_count += 1
                     continue

                 transcript = None
                 try:
                     transcript = await transcripts_controller.get_by_recording_id(
                         session, recording.id
                     )
                 except ValidationError:
                     await transcripts_controller.remove_by_recording_id(
                         session, recording.id
                     )
                     logger.warning(
                         f"Removed invalid transcript for recording: {recording.id}"
                     )

                 if transcript is None or transcript.status == "error":
                     logger.info(f"Queueing recording for processing: {object_key}")
                     process_recording.delay(bucket_name, object_key)
                     reprocessed_count += 1
     except Exception as e:
         logger.error(f"Error checking S3 bucket: {str(e)}")
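
Call sites are unaffected, as the hunk above shows: reprocess_failed_recordings still enqueues with process_recording.delay(bucket_name, object_key). The session parameter never crosses the broker; with_session injects a fresh AsyncSession when a worker executes the task:

# Producers enqueue only the business arguments (verbatim from the diff above).
process_recording.delay(bucket_name, object_key)
# At execution time the wrapper opens a new session per run, so concurrent
# workers never share a session or a transaction.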

reflector/worker/session_decorator.py (new file)

@@ -0,0 +1,41 @@
+"""
+Session management decorator for async worker tasks.
+
+This decorator ensures that all worker tasks have a properly managed database session
+that stays open for the entire duration of the task execution.
+"""
+
+import functools
+from typing import Any, Callable, TypeVar
+
+from reflector.db import get_session_factory
+
+F = TypeVar("F", bound=Callable[..., Any])
+
+
+def with_session(func: F) -> F:
+    """
+    Decorator that provides an AsyncSession as the first argument to the decorated function.
+
+    This should be used AFTER the @asynctask decorator on Celery tasks to ensure
+    proper session management throughout the task execution.
+
+    Example:
+        @shared_task
+        @asynctask
+        @with_session
+        async def my_task(session: AsyncSession, arg1: str, arg2: int):
+            # session is automatically provided and managed
+            result = await some_controller.get_by_id(session, arg1)
+            ...
+    """
+
+    @functools.wraps(func)
+    async def wrapper(*args, **kwargs):
+        session_factory = get_session_factory()
+        async with session_factory() as session:
+            async with session.begin():
+                # Pass session as first argument to the decorated function
+                return await func(session, *args, **kwargs)
+
+    return wrapper
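
One typing nit for a future pass: with_session is annotated as returning F, but the wrapper actually removes session from the caller-facing signature. ParamSpec with Concatenate (Python 3.10+) can express that transformation precisely; a sketch, not part of this commit:

from typing import Awaitable, Callable, Concatenate, ParamSpec, TypeVar

from sqlalchemy.ext.asyncio import AsyncSession

P = ParamSpec("P")
R = TypeVar("R")

def with_session(
    func: Callable[Concatenate[AsyncSession, P], Awaitable[R]],
) -> Callable[P, Awaitable[R]]:
    ...  # same wrapper body as above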