feat: add @with_session decorator for worker task session management

- Create session_decorator.py with @with_session decorator
- Decorator automatically manages database sessions for worker tasks
- Ensures session stays open for entire task execution
- Fixes issue where sessions were closed before being used (e.g., process_meetings)

Applied decorator to all worker tasks:
- process.py: process_recording, process_meetings, reprocess_failed_recordings
- cleanup.py: cleanup_old_public_data_task
- ics_sync.py: sync_room_ics, sync_all_ics_calendars, create_upcoming_meetings

Benefits:
- Consistent session management across all worker tasks
- No more manual session_factory context management in tasks
- Proper transaction boundaries with automatic begin/commit
- Cleaner, more maintainable code
- Fixes session lifecycle issues in process_meetings
This commit is contained in:
2025-09-23 08:55:26 -06:00
parent 617a1c8b32
commit 8ad1270229
4 changed files with 169 additions and 138 deletions

View File

@@ -15,11 +15,11 @@ from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from reflector.asynctask import asynctask
from reflector.db import get_session_factory
from reflector.db.base import MeetingModel, RecordingModel, TranscriptModel
from reflector.db.transcripts import transcripts_controller
from reflector.settings import settings
from reflector.storage import get_recordings_storage
from reflector.worker.session_decorator import with_session
logger = structlog.get_logger(__name__)
@@ -161,8 +161,6 @@ async def cleanup_old_public_data(
retry_kwargs={"max_retries": 3, "countdown": 300},
)
@asynctask
async def cleanup_old_public_data_task(days: int | None = None):
session_factory = get_session_factory()
async with session_factory() as session:
async with session.begin():
@with_session
async def cleanup_old_public_data_task(session: AsyncSession, days: int | None = None):
await cleanup_old_public_data(session, days=days)

View File

@@ -6,23 +6,22 @@ from celery.utils.log import get_task_logger
from sqlalchemy.ext.asyncio import AsyncSession
from reflector.asynctask import asynctask
from reflector.db import get_session_factory
from reflector.db.calendar_events import calendar_events_controller
from reflector.db.meetings import meetings_controller
from reflector.db.rooms import rooms_controller
from reflector.redis_cache import RedisAsyncLock
from reflector.services.ics_sync import SyncStatus, ics_sync_service
from reflector.whereby import create_meeting, upload_logo
from reflector.worker.session_decorator import with_session
logger = structlog.wrap_logger(get_task_logger(__name__))
@shared_task
@asynctask
async def sync_room_ics(room_id: str):
@with_session
async def sync_room_ics(session: AsyncSession, room_id: str):
try:
session_factory = get_session_factory()
async with session_factory() as session:
room = await rooms_controller.get_by_id(session, room_id)
if not room:
logger.warning("Room not found for ICS sync", room_id=room_id)
@@ -59,12 +58,11 @@ async def sync_room_ics(room_id: str):
@shared_task
@asynctask
async def sync_all_ics_calendars():
@with_session
async def sync_all_ics_calendars(session: AsyncSession):
try:
logger.info("Starting sync for all ICS-enabled rooms")
session_factory = get_session_factory()
async with session_factory() as session:
ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
logger.info(f"Found {len(ics_enabled_rooms)} rooms with ICS enabled")
@@ -155,7 +153,8 @@ async def create_upcoming_meetings_for_event(
@shared_task
@asynctask
async def create_upcoming_meetings():
@with_session
async def create_upcoming_meetings(session: AsyncSession):
async with RedisAsyncLock("create_upcoming_meetings", skip_if_locked=True) as lock:
if not lock.acquired:
logger.warning(
@@ -166,9 +165,6 @@ async def create_upcoming_meetings():
try:
logger.info("Starting creation of upcoming meetings")
session_factory = get_session_factory()
async with session_factory() as session:
async with session.begin():
ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
now = datetime.now(timezone.utc)
create_window = now - timedelta(minutes=6)

View File

@@ -10,8 +10,8 @@ from celery import shared_task
from celery.utils.log import get_task_logger
from pydantic import ValidationError
from redis.exceptions import LockError
from sqlalchemy.ext.asyncio import AsyncSession
from reflector.db import get_session_factory
from reflector.db.meetings import meetings_controller
from reflector.db.recordings import Recording, recordings_controller
from reflector.db.rooms import rooms_controller
@@ -21,6 +21,7 @@ from reflector.pipelines.main_live_pipeline import asynctask
from reflector.redis_cache import get_redis_client
from reflector.settings import settings
from reflector.whereby import get_room_sessions
from reflector.worker.session_decorator import with_session
logger = structlog.wrap_logger(get_task_logger(__name__))
@@ -76,17 +77,19 @@ def process_messages():
@shared_task
@asynctask
async def process_recording(bucket_name: str, object_key: str):
@with_session
async def process_recording(session: AsyncSession, bucket_name: str, object_key: str):
logger.info("Processing recording: %s/%s", bucket_name, object_key)
# extract a guid and a datetime from the object key
room_name = f"/{object_key[:36]}"
recorded_at = parse_datetime_with_timezone(object_key[37:57])
session_factory = get_session_factory()
async with session_factory() as session:
async with session.begin():
meeting = await meetings_controller.get_by_room_name(session, room_name)
if not meeting:
logger.warning("Room not found, may be deleted ?", room_name=room_name)
return
room = await rooms_controller.get_by_id(session, meeting.room_id)
recording = await recordings_controller.get_by_object_key(
@@ -103,9 +106,7 @@ async def process_recording(bucket_name: str, object_key: str):
),
)
transcript = await transcripts_controller.get_by_recording_id(
session, recording.id
)
transcript = await transcripts_controller.get_by_recording_id(session, recording.id)
if transcript:
await transcripts_controller.update(
session,
@@ -152,16 +153,15 @@ async def process_recording(bucket_name: str, object_key: str):
finally:
container.close()
await transcripts_controller.update(
session, transcript, {"status": "uploaded"}
)
await transcripts_controller.update(session, transcript, {"status": "uploaded"})
task_pipeline_file_process.delay(transcript_id=transcript.id)
@shared_task
@asynctask
async def process_meetings():
@with_session
async def process_meetings(session: AsyncSession):
"""
Checks which meetings are still active and deactivates those that have ended.
@@ -178,9 +178,6 @@ async def process_meetings():
process the same meeting simultaneously.
"""
logger.info("Processing meetings")
session_factory = get_session_factory()
async with session_factory() as session:
async with session.begin():
meetings = await meetings_controller.get_all_active(session)
current_time = datetime.now(timezone.utc)
redis_client = get_redis_client()
@@ -258,7 +255,8 @@ async def process_meetings():
@shared_task
@asynctask
async def reprocess_failed_recordings():
@with_session
async def reprocess_failed_recordings(session: AsyncSession):
"""
Find recordings in the S3 bucket and check if they have proper transcriptions.
If not, requeue them for processing.
@@ -278,8 +276,6 @@ async def reprocess_failed_recordings():
bucket_name = settings.RECORDING_STORAGE_AWS_BUCKET_NAME
pages = paginator.paginate(Bucket=bucket_name)
session_factory = get_session_factory()
async with session_factory() as session:
for page in pages:
if "Contents" not in page:
continue

View File

@@ -0,0 +1,41 @@
"""
Session management decorator for async worker tasks.
This decorator ensures that all worker tasks have a properly managed database session
that stays open for the entire duration of the task execution.
"""
import functools
from typing import (
    TYPE_CHECKING,
    Any,
    Awaitable,
    Callable,
    Concatenate,
    Coroutine,
    ParamSpec,
    TypeVar,
)

from reflector.db import get_session_factory

if TYPE_CHECKING:
    from sqlalchemy.ext.asyncio import AsyncSession
F = TypeVar("F", bound=Callable[..., Any])
def with_session(func: F) -> F:
"""
Decorator that provides an AsyncSession as the first argument to the decorated function.
This should be used AFTER the @asynctask decorator on Celery tasks to ensure
proper session management throughout the task execution.
Example:
@shared_task
@asynctask
@with_session
async def my_task(session: AsyncSession, arg1: str, arg2: int):
# session is automatically provided and managed
result = await some_controller.get_by_id(session, arg1)
...
"""
@functools.wraps(func)
async def wrapper(*args, **kwargs):
session_factory = get_session_factory()
async with session_factory() as session:
async with session.begin():
# Pass session as first argument to the decorated function
return await func(session, *args, **kwargs)
return wrapper