mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
feat: add @with_session decorator for worker task session management
feat: add @with_session decorator for worker task session management

- Create session_decorator.py with @with_session decorator
- Decorator automatically manages database sessions for worker tasks
- Ensures session stays open for entire task execution
- Fixes issue where sessions were closed before being used (e.g., process_meetings)

Applied decorator to all worker tasks:
- process.py: process_recording, process_meetings, reprocess_failed_recordings
- cleanup.py: cleanup_old_public_data_task
- ics_sync.py: sync_room_ics, sync_all_ics_calendars, create_upcoming_meetings

Benefits:
- Consistent session management across all worker tasks
- No more manual session_factory context management in tasks
- Proper transaction boundaries with automatic begin/commit
- Cleaner, more maintainable code
- Fixes session lifecycle issues in process_meetings
This commit is contained in:
@@ -15,11 +15,11 @@ from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from reflector.asynctask import asynctask
|
||||
from reflector.db import get_session_factory
|
||||
from reflector.db.base import MeetingModel, RecordingModel, TranscriptModel
|
||||
from reflector.db.transcripts import transcripts_controller
|
||||
from reflector.settings import settings
|
||||
from reflector.storage import get_recordings_storage
|
||||
from reflector.worker.session_decorator import with_session
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
@@ -161,8 +161,6 @@ async def cleanup_old_public_data(
|
||||
retry_kwargs={"max_retries": 3, "countdown": 300},
|
||||
)
|
||||
@asynctask
|
||||
async def cleanup_old_public_data_task(days: int | None = None):
|
||||
session_factory = get_session_factory()
|
||||
async with session_factory() as session:
|
||||
async with session.begin():
|
||||
@with_session
|
||||
async def cleanup_old_public_data_task(session: AsyncSession, days: int | None = None):
|
||||
await cleanup_old_public_data(session, days=days)
|
||||
|
||||
@@ -6,23 +6,22 @@ from celery.utils.log import get_task_logger
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from reflector.asynctask import asynctask
|
||||
from reflector.db import get_session_factory
|
||||
from reflector.db.calendar_events import calendar_events_controller
|
||||
from reflector.db.meetings import meetings_controller
|
||||
from reflector.db.rooms import rooms_controller
|
||||
from reflector.redis_cache import RedisAsyncLock
|
||||
from reflector.services.ics_sync import SyncStatus, ics_sync_service
|
||||
from reflector.whereby import create_meeting, upload_logo
|
||||
from reflector.worker.session_decorator import with_session
|
||||
|
||||
logger = structlog.wrap_logger(get_task_logger(__name__))
|
||||
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def sync_room_ics(room_id: str):
|
||||
@with_session
|
||||
async def sync_room_ics(session: AsyncSession, room_id: str):
|
||||
try:
|
||||
session_factory = get_session_factory()
|
||||
async with session_factory() as session:
|
||||
room = await rooms_controller.get_by_id(session, room_id)
|
||||
if not room:
|
||||
logger.warning("Room not found for ICS sync", room_id=room_id)
|
||||
@@ -59,12 +58,11 @@ async def sync_room_ics(room_id: str):
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def sync_all_ics_calendars():
|
||||
@with_session
|
||||
async def sync_all_ics_calendars(session: AsyncSession):
|
||||
try:
|
||||
logger.info("Starting sync for all ICS-enabled rooms")
|
||||
|
||||
session_factory = get_session_factory()
|
||||
async with session_factory() as session:
|
||||
ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
|
||||
logger.info(f"Found {len(ics_enabled_rooms)} rooms with ICS enabled")
|
||||
|
||||
@@ -155,7 +153,8 @@ async def create_upcoming_meetings_for_event(
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def create_upcoming_meetings():
|
||||
@with_session
|
||||
async def create_upcoming_meetings(session: AsyncSession):
|
||||
async with RedisAsyncLock("create_upcoming_meetings", skip_if_locked=True) as lock:
|
||||
if not lock.acquired:
|
||||
logger.warning(
|
||||
@@ -166,9 +165,6 @@ async def create_upcoming_meetings():
|
||||
try:
|
||||
logger.info("Starting creation of upcoming meetings")
|
||||
|
||||
session_factory = get_session_factory()
|
||||
async with session_factory() as session:
|
||||
async with session.begin():
|
||||
ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
|
||||
now = datetime.now(timezone.utc)
|
||||
create_window = now - timedelta(minutes=6)
|
||||
|
||||
@@ -10,8 +10,8 @@ from celery import shared_task
|
||||
from celery.utils.log import get_task_logger
|
||||
from pydantic import ValidationError
|
||||
from redis.exceptions import LockError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from reflector.db import get_session_factory
|
||||
from reflector.db.meetings import meetings_controller
|
||||
from reflector.db.recordings import Recording, recordings_controller
|
||||
from reflector.db.rooms import rooms_controller
|
||||
@@ -21,6 +21,7 @@ from reflector.pipelines.main_live_pipeline import asynctask
|
||||
from reflector.redis_cache import get_redis_client
|
||||
from reflector.settings import settings
|
||||
from reflector.whereby import get_room_sessions
|
||||
from reflector.worker.session_decorator import with_session
|
||||
|
||||
logger = structlog.wrap_logger(get_task_logger(__name__))
|
||||
|
||||
@@ -76,17 +77,19 @@ def process_messages():
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def process_recording(bucket_name: str, object_key: str):
|
||||
@with_session
|
||||
async def process_recording(session: AsyncSession, bucket_name: str, object_key: str):
|
||||
logger.info("Processing recording: %s/%s", bucket_name, object_key)
|
||||
|
||||
# extract a guid and a datetime from the object key
|
||||
room_name = f"/{object_key[:36]}"
|
||||
recorded_at = parse_datetime_with_timezone(object_key[37:57])
|
||||
|
||||
session_factory = get_session_factory()
|
||||
async with session_factory() as session:
|
||||
async with session.begin():
|
||||
meeting = await meetings_controller.get_by_room_name(session, room_name)
|
||||
if not meeting:
|
||||
logger.warning("Room not found, may be deleted ?", room_name=room_name)
|
||||
return
|
||||
|
||||
room = await rooms_controller.get_by_id(session, meeting.room_id)
|
||||
|
||||
recording = await recordings_controller.get_by_object_key(
|
||||
@@ -103,9 +106,7 @@ async def process_recording(bucket_name: str, object_key: str):
|
||||
),
|
||||
)
|
||||
|
||||
transcript = await transcripts_controller.get_by_recording_id(
|
||||
session, recording.id
|
||||
)
|
||||
transcript = await transcripts_controller.get_by_recording_id(session, recording.id)
|
||||
if transcript:
|
||||
await transcripts_controller.update(
|
||||
session,
|
||||
@@ -152,16 +153,15 @@ async def process_recording(bucket_name: str, object_key: str):
|
||||
finally:
|
||||
container.close()
|
||||
|
||||
await transcripts_controller.update(
|
||||
session, transcript, {"status": "uploaded"}
|
||||
)
|
||||
await transcripts_controller.update(session, transcript, {"status": "uploaded"})
|
||||
|
||||
task_pipeline_file_process.delay(transcript_id=transcript.id)
|
||||
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def process_meetings():
|
||||
@with_session
|
||||
async def process_meetings(session: AsyncSession):
|
||||
"""
|
||||
Checks which meetings are still active and deactivates those that have ended.
|
||||
|
||||
@@ -178,9 +178,6 @@ async def process_meetings():
|
||||
process the same meeting simultaneously.
|
||||
"""
|
||||
logger.info("Processing meetings")
|
||||
session_factory = get_session_factory()
|
||||
async with session_factory() as session:
|
||||
async with session.begin():
|
||||
meetings = await meetings_controller.get_all_active(session)
|
||||
current_time = datetime.now(timezone.utc)
|
||||
redis_client = get_redis_client()
|
||||
@@ -258,7 +255,8 @@ async def process_meetings():
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def reprocess_failed_recordings():
|
||||
@with_session
|
||||
async def reprocess_failed_recordings(session: AsyncSession):
|
||||
"""
|
||||
Find recordings in the S3 bucket and check if they have proper transcriptions.
|
||||
If not, requeue them for processing.
|
||||
@@ -278,8 +276,6 @@ async def reprocess_failed_recordings():
|
||||
bucket_name = settings.RECORDING_STORAGE_AWS_BUCKET_NAME
|
||||
pages = paginator.paginate(Bucket=bucket_name)
|
||||
|
||||
session_factory = get_session_factory()
|
||||
async with session_factory() as session:
|
||||
for page in pages:
|
||||
if "Contents" not in page:
|
||||
continue
|
||||
|
||||
41
server/reflector/worker/session_decorator.py
Normal file
41
server/reflector/worker/session_decorator.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
Session management decorator for async worker tasks.
|
||||
|
||||
This decorator ensures that all worker tasks have a properly managed database session
|
||||
that stays open for the entire duration of the task execution.
|
||||
"""
|
||||
|
||||
import functools
|
||||
from typing import Any, Callable, TypeVar
|
||||
|
||||
from reflector.db import get_session_factory
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])


def with_session(func: F) -> F:
    """Inject a managed AsyncSession as the first argument of *func*.

    Intended for async Celery worker tasks; stack it AFTER @asynctask so the
    session is opened when the task body actually runs and stays open for the
    whole execution. A transaction is started via ``session.begin()`` and is
    committed on normal return or rolled back if the task raises.

    Example:
        @shared_task
        @asynctask
        @with_session
        async def my_task(session: AsyncSession, arg1: str, arg2: int):
            # session is automatically provided and managed
            result = await some_controller.get_by_id(session, arg1)
            ...
    """

    @functools.wraps(func)
    async def _run_with_session(*args: Any, **kwargs: Any) -> Any:
        # Open the session and its transaction together; both close when the
        # wrapped coroutine finishes (commit on success, rollback on error).
        factory = get_session_factory()
        async with factory() as session, session.begin():
            # The managed session is always the first positional argument.
            return await func(session, *args, **kwargs)

    return _run_with_session
|
||||
Reference in New Issue
Block a user