feat: add @with_session decorator for worker task session management

- Create session_decorator.py with @with_session decorator
- Decorator automatically manages database sessions for worker tasks
- Ensures session stays open for entire task execution
- Fixes issue where sessions were closed before being used (e.g., process_meetings)

Applied decorator to all worker tasks:
- process.py: process_recording, process_meetings, reprocess_failed_recordings
- cleanup.py: cleanup_old_public_data_task
- ics_sync.py: sync_room_ics, sync_all_ics_calendars, create_upcoming_meetings

Benefits:
- Consistent session management across all worker tasks
- No more manual session_factory context management in tasks
- Proper transaction boundaries with automatic begin/commit
- Cleaner, more maintainable code
- Fixes session lifecycle issues in process_meetings
This commit is contained in:
2025-09-23 08:55:26 -06:00
parent 617a1c8b32
commit 8ad1270229
4 changed files with 169 additions and 138 deletions

View File

@@ -15,11 +15,11 @@ from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from reflector.asynctask import asynctask from reflector.asynctask import asynctask
from reflector.db import get_session_factory
from reflector.db.base import MeetingModel, RecordingModel, TranscriptModel from reflector.db.base import MeetingModel, RecordingModel, TranscriptModel
from reflector.db.transcripts import transcripts_controller from reflector.db.transcripts import transcripts_controller
from reflector.settings import settings from reflector.settings import settings
from reflector.storage import get_recordings_storage from reflector.storage import get_recordings_storage
from reflector.worker.session_decorator import with_session
logger = structlog.get_logger(__name__) logger = structlog.get_logger(__name__)
@@ -161,8 +161,6 @@ async def cleanup_old_public_data(
retry_kwargs={"max_retries": 3, "countdown": 300}, retry_kwargs={"max_retries": 3, "countdown": 300},
) )
@asynctask @asynctask
async def cleanup_old_public_data_task(days: int | None = None): @with_session
session_factory = get_session_factory() async def cleanup_old_public_data_task(session: AsyncSession, days: int | None = None):
async with session_factory() as session:
async with session.begin():
await cleanup_old_public_data(session, days=days) await cleanup_old_public_data(session, days=days)

View File

@@ -6,23 +6,22 @@ from celery.utils.log import get_task_logger
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from reflector.asynctask import asynctask from reflector.asynctask import asynctask
from reflector.db import get_session_factory
from reflector.db.calendar_events import calendar_events_controller from reflector.db.calendar_events import calendar_events_controller
from reflector.db.meetings import meetings_controller from reflector.db.meetings import meetings_controller
from reflector.db.rooms import rooms_controller from reflector.db.rooms import rooms_controller
from reflector.redis_cache import RedisAsyncLock from reflector.redis_cache import RedisAsyncLock
from reflector.services.ics_sync import SyncStatus, ics_sync_service from reflector.services.ics_sync import SyncStatus, ics_sync_service
from reflector.whereby import create_meeting, upload_logo from reflector.whereby import create_meeting, upload_logo
from reflector.worker.session_decorator import with_session
logger = structlog.wrap_logger(get_task_logger(__name__)) logger = structlog.wrap_logger(get_task_logger(__name__))
@shared_task @shared_task
@asynctask @asynctask
async def sync_room_ics(room_id: str): @with_session
async def sync_room_ics(session: AsyncSession, room_id: str):
try: try:
session_factory = get_session_factory()
async with session_factory() as session:
room = await rooms_controller.get_by_id(session, room_id) room = await rooms_controller.get_by_id(session, room_id)
if not room: if not room:
logger.warning("Room not found for ICS sync", room_id=room_id) logger.warning("Room not found for ICS sync", room_id=room_id)
@@ -59,12 +58,11 @@ async def sync_room_ics(room_id: str):
@shared_task @shared_task
@asynctask @asynctask
async def sync_all_ics_calendars(): @with_session
async def sync_all_ics_calendars(session: AsyncSession):
try: try:
logger.info("Starting sync for all ICS-enabled rooms") logger.info("Starting sync for all ICS-enabled rooms")
session_factory = get_session_factory()
async with session_factory() as session:
ics_enabled_rooms = await rooms_controller.get_ics_enabled(session) ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
logger.info(f"Found {len(ics_enabled_rooms)} rooms with ICS enabled") logger.info(f"Found {len(ics_enabled_rooms)} rooms with ICS enabled")
@@ -155,7 +153,8 @@ async def create_upcoming_meetings_for_event(
@shared_task @shared_task
@asynctask @asynctask
async def create_upcoming_meetings(): @with_session
async def create_upcoming_meetings(session: AsyncSession):
async with RedisAsyncLock("create_upcoming_meetings", skip_if_locked=True) as lock: async with RedisAsyncLock("create_upcoming_meetings", skip_if_locked=True) as lock:
if not lock.acquired: if not lock.acquired:
logger.warning( logger.warning(
@@ -166,9 +165,6 @@ async def create_upcoming_meetings():
try: try:
logger.info("Starting creation of upcoming meetings") logger.info("Starting creation of upcoming meetings")
session_factory = get_session_factory()
async with session_factory() as session:
async with session.begin():
ics_enabled_rooms = await rooms_controller.get_ics_enabled(session) ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
create_window = now - timedelta(minutes=6) create_window = now - timedelta(minutes=6)

View File

@@ -10,8 +10,8 @@ from celery import shared_task
from celery.utils.log import get_task_logger from celery.utils.log import get_task_logger
from pydantic import ValidationError from pydantic import ValidationError
from redis.exceptions import LockError from redis.exceptions import LockError
from sqlalchemy.ext.asyncio import AsyncSession
from reflector.db import get_session_factory
from reflector.db.meetings import meetings_controller from reflector.db.meetings import meetings_controller
from reflector.db.recordings import Recording, recordings_controller from reflector.db.recordings import Recording, recordings_controller
from reflector.db.rooms import rooms_controller from reflector.db.rooms import rooms_controller
@@ -21,6 +21,7 @@ from reflector.pipelines.main_live_pipeline import asynctask
from reflector.redis_cache import get_redis_client from reflector.redis_cache import get_redis_client
from reflector.settings import settings from reflector.settings import settings
from reflector.whereby import get_room_sessions from reflector.whereby import get_room_sessions
from reflector.worker.session_decorator import with_session
logger = structlog.wrap_logger(get_task_logger(__name__)) logger = structlog.wrap_logger(get_task_logger(__name__))
@@ -76,17 +77,19 @@ def process_messages():
@shared_task @shared_task
@asynctask @asynctask
async def process_recording(bucket_name: str, object_key: str): @with_session
async def process_recording(session: AsyncSession, bucket_name: str, object_key: str):
logger.info("Processing recording: %s/%s", bucket_name, object_key) logger.info("Processing recording: %s/%s", bucket_name, object_key)
# extract a guid and a datetime from the object key # extract a guid and a datetime from the object key
room_name = f"/{object_key[:36]}" room_name = f"/{object_key[:36]}"
recorded_at = parse_datetime_with_timezone(object_key[37:57]) recorded_at = parse_datetime_with_timezone(object_key[37:57])
session_factory = get_session_factory()
async with session_factory() as session:
async with session.begin():
meeting = await meetings_controller.get_by_room_name(session, room_name) meeting = await meetings_controller.get_by_room_name(session, room_name)
if not meeting:
logger.warning("Room not found, may be deleted ?", room_name=room_name)
return
room = await rooms_controller.get_by_id(session, meeting.room_id) room = await rooms_controller.get_by_id(session, meeting.room_id)
recording = await recordings_controller.get_by_object_key( recording = await recordings_controller.get_by_object_key(
@@ -103,9 +106,7 @@ async def process_recording(bucket_name: str, object_key: str):
), ),
) )
transcript = await transcripts_controller.get_by_recording_id( transcript = await transcripts_controller.get_by_recording_id(session, recording.id)
session, recording.id
)
if transcript: if transcript:
await transcripts_controller.update( await transcripts_controller.update(
session, session,
@@ -152,16 +153,15 @@ async def process_recording(bucket_name: str, object_key: str):
finally: finally:
container.close() container.close()
await transcripts_controller.update( await transcripts_controller.update(session, transcript, {"status": "uploaded"})
session, transcript, {"status": "uploaded"}
)
task_pipeline_file_process.delay(transcript_id=transcript.id) task_pipeline_file_process.delay(transcript_id=transcript.id)
@shared_task @shared_task
@asynctask @asynctask
async def process_meetings(): @with_session
async def process_meetings(session: AsyncSession):
""" """
Checks which meetings are still active and deactivates those that have ended. Checks which meetings are still active and deactivates those that have ended.
@@ -178,9 +178,6 @@ async def process_meetings():
process the same meeting simultaneously. process the same meeting simultaneously.
""" """
logger.info("Processing meetings") logger.info("Processing meetings")
session_factory = get_session_factory()
async with session_factory() as session:
async with session.begin():
meetings = await meetings_controller.get_all_active(session) meetings = await meetings_controller.get_all_active(session)
current_time = datetime.now(timezone.utc) current_time = datetime.now(timezone.utc)
redis_client = get_redis_client() redis_client = get_redis_client()
@@ -258,7 +255,8 @@ async def process_meetings():
@shared_task @shared_task
@asynctask @asynctask
async def reprocess_failed_recordings(): @with_session
async def reprocess_failed_recordings(session: AsyncSession):
""" """
Find recordings in the S3 bucket and check if they have proper transcriptions. Find recordings in the S3 bucket and check if they have proper transcriptions.
If not, requeue them for processing. If not, requeue them for processing.
@@ -278,8 +276,6 @@ async def reprocess_failed_recordings():
bucket_name = settings.RECORDING_STORAGE_AWS_BUCKET_NAME bucket_name = settings.RECORDING_STORAGE_AWS_BUCKET_NAME
pages = paginator.paginate(Bucket=bucket_name) pages = paginator.paginate(Bucket=bucket_name)
session_factory = get_session_factory()
async with session_factory() as session:
for page in pages: for page in pages:
if "Contents" not in page: if "Contents" not in page:
continue continue

View File

@@ -0,0 +1,41 @@
"""
Session management decorator for async worker tasks.
This decorator ensures that all worker tasks have a properly managed database session
that stays open for the entire duration of the task execution.
"""
import functools
from typing import Any, Callable, TypeVar

from reflector.db import get_session_factory

F = TypeVar("F", bound=Callable[..., Any])


def with_session(func: F) -> F:
    """
    Inject a managed AsyncSession as the decorated coroutine's first argument.

    Apply below @asynctask on Celery tasks so that a single session — and one
    enclosing transaction — spans the entire task execution.

    Example:
        @shared_task
        @asynctask
        @with_session
        async def my_task(session: AsyncSession, arg1: str, arg2: int):
            # session is automatically provided and managed
            result = await some_controller.get_by_id(session, arg1)
            ...
    """

    @functools.wraps(func)
    async def _run_in_session(*args: Any, **kwargs: Any) -> Any:
        factory = get_session_factory()
        # One session and one transaction for the whole task: session.begin()
        # commits on normal exit and rolls back if the task raises.
        async with factory() as session, session.begin():
            return await func(session, *args, **kwargs)

    return _run_in_session