feat: add @with_session decorator for worker task session management

- Create session_decorator.py with @with_session decorator
- Decorator automatically manages database sessions for worker tasks
- Ensures the session stays open for the entire task execution
- Fixes an issue where sessions were closed before being used (e.g., in process_meetings)

Applied decorator to all worker tasks:
- process.py: process_recording, process_meetings, reprocess_failed_recordings
- cleanup.py: cleanup_old_public_data_task
- ics_sync.py: sync_room_ics, sync_all_ics_calendars, create_upcoming_meetings

Benefits:
- Consistent session management across all worker tasks
- No more manual session_factory context management in tasks
- Proper transaction boundaries with automatic begin/commit
- Cleaner, more maintainable code
- Fixes session lifecycle issues in process_meetings
2025-09-23 08:55:26 -06:00
parent 617a1c8b32
commit 8ad1270229
4 changed files with 169 additions and 138 deletions
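
For orientation, a minimal before/after sketch of what @with_session changes in each task. The task and helper names (my_task, do_work) are placeholders; the real tasks are in the file diffs below.

Before - each task opened and closed its own session:

    @shared_task
    @asynctask
    async def my_task(arg: str):
        session_factory = get_session_factory()
        async with session_factory() as session:
            async with session.begin():
                await do_work(session, arg)

After - the decorator owns the session and the transaction for the whole task:

    @shared_task
    @asynctask
    @with_session
    async def my_task(session: AsyncSession, arg: str):
        await do_work(session, arg)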

File: cleanup.py

@@ -15,11 +15,11 @@ from sqlalchemy import delete, select
 from sqlalchemy.ext.asyncio import AsyncSession
 from reflector.asynctask import asynctask
-from reflector.db import get_session_factory
 from reflector.db.base import MeetingModel, RecordingModel, TranscriptModel
 from reflector.db.transcripts import transcripts_controller
 from reflector.settings import settings
 from reflector.storage import get_recordings_storage
+from reflector.worker.session_decorator import with_session
 logger = structlog.get_logger(__name__)
@@ -161,8 +161,6 @@ async def cleanup_old_public_data(
     retry_kwargs={"max_retries": 3, "countdown": 300},
 )
 @asynctask
-async def cleanup_old_public_data_task(days: int | None = None):
-    session_factory = get_session_factory()
-    async with session_factory() as session:
-        async with session.begin():
-            await cleanup_old_public_data(session, days=days)
+@with_session
+async def cleanup_old_public_data_task(session: AsyncSession, days: int | None = None):
+    await cleanup_old_public_data(session, days=days)

File: ics_sync.py

@@ -6,24 +6,23 @@ from celery.utils.log import get_task_logger
 from sqlalchemy.ext.asyncio import AsyncSession
 from reflector.asynctask import asynctask
-from reflector.db import get_session_factory
 from reflector.db.calendar_events import calendar_events_controller
 from reflector.db.meetings import meetings_controller
 from reflector.db.rooms import rooms_controller
 from reflector.redis_cache import RedisAsyncLock
 from reflector.services.ics_sync import SyncStatus, ics_sync_service
 from reflector.whereby import create_meeting, upload_logo
+from reflector.worker.session_decorator import with_session
 logger = structlog.wrap_logger(get_task_logger(__name__))
 @shared_task
 @asynctask
-async def sync_room_ics(room_id: str):
+@with_session
+async def sync_room_ics(session: AsyncSession, room_id: str):
     try:
-        session_factory = get_session_factory()
-        async with session_factory() as session:
-            room = await rooms_controller.get_by_id(session, room_id)
+        room = await rooms_controller.get_by_id(session, room_id)
         if not room:
             logger.warning("Room not found for ICS sync", room_id=room_id)
             return
@@ -59,13 +58,12 @@ async def sync_room_ics(room_id: str):
 @shared_task
 @asynctask
-async def sync_all_ics_calendars():
+@with_session
+async def sync_all_ics_calendars(session: AsyncSession):
     try:
         logger.info("Starting sync for all ICS-enabled rooms")
-        session_factory = get_session_factory()
-        async with session_factory() as session:
-            ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
+        ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
         logger.info(f"Found {len(ics_enabled_rooms)} rooms with ICS enabled")
         for room in ics_enabled_rooms:
@@ -155,7 +153,8 @@ async def create_upcoming_meetings_for_event(
 @shared_task
 @asynctask
-async def create_upcoming_meetings():
+@with_session
+async def create_upcoming_meetings(session: AsyncSession):
     async with RedisAsyncLock("create_upcoming_meetings", skip_if_locked=True) as lock:
         if not lock.acquired:
             logger.warning(
@@ -166,24 +165,21 @@ async def create_upcoming_meetings():
         try:
             logger.info("Starting creation of upcoming meetings")
-            session_factory = get_session_factory()
-            async with session_factory() as session:
-                async with session.begin():
-                    ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
-                    now = datetime.now(timezone.utc)
-                    create_window = now - timedelta(minutes=6)
-                    for room in ics_enabled_rooms:
-                        events = await calendar_events_controller.get_upcoming(
-                            session,
-                            room.id,
-                            minutes_ahead=7,
-                        )
-                        for event in events:
-                            await create_upcoming_meetings_for_event(
-                                session, event, create_window, room.id, room
-                            )
+            ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
+            now = datetime.now(timezone.utc)
+            create_window = now - timedelta(minutes=6)
+            for room in ics_enabled_rooms:
+                events = await calendar_events_controller.get_upcoming(
+                    session,
+                    room.id,
+                    minutes_ahead=7,
+                )
+                for event in events:
+                    await create_upcoming_meetings_for_event(
+                        session, event, create_window, room.id, room
+                    )
             logger.info("Completed pre-creation check for upcoming meetings")
         except Exception as e:

File: process.py

@@ -10,8 +10,8 @@ from celery import shared_task
 from celery.utils.log import get_task_logger
 from pydantic import ValidationError
 from redis.exceptions import LockError
+from sqlalchemy.ext.asyncio import AsyncSession
-from reflector.db import get_session_factory
 from reflector.db.meetings import meetings_controller
 from reflector.db.recordings import Recording, recordings_controller
 from reflector.db.rooms import rooms_controller
@@ -21,6 +21,7 @@ from reflector.pipelines.main_live_pipeline import asynctask
 from reflector.redis_cache import get_redis_client
 from reflector.settings import settings
 from reflector.whereby import get_room_sessions
+from reflector.worker.session_decorator import with_session
 logger = structlog.wrap_logger(get_task_logger(__name__))
@@ -76,92 +77,91 @@ def process_messages():
 @shared_task
 @asynctask
-async def process_recording(bucket_name: str, object_key: str):
+@with_session
+async def process_recording(session: AsyncSession, bucket_name: str, object_key: str):
     logger.info("Processing recording: %s/%s", bucket_name, object_key)
     # extract a guid and a datetime from the object key
     room_name = f"/{object_key[:36]}"
     recorded_at = parse_datetime_with_timezone(object_key[37:57])
-    session_factory = get_session_factory()
-    async with session_factory() as session:
-        async with session.begin():
-            meeting = await meetings_controller.get_by_room_name(session, room_name)
-            room = await rooms_controller.get_by_id(session, meeting.room_id)
+    meeting = await meetings_controller.get_by_room_name(session, room_name)
     if not meeting:
         logger.warning("Room not found, may be deleted ?", room_name=room_name)
         return
-            recording = await recordings_controller.get_by_object_key(
-                session, bucket_name, object_key
-            )
-            if not recording:
-                recording = await recordings_controller.create(
-                    session,
-                    Recording(
-                        bucket_name=bucket_name,
-                        object_key=object_key,
-                        recorded_at=recorded_at,
-                        meeting_id=meeting.id,
-                    ),
-                )
-            transcript = await transcripts_controller.get_by_recording_id(
-                session, recording.id
-            )
-            if transcript:
-                await transcripts_controller.update(
-                    session,
-                    transcript,
-                    {
-                        "topics": [],
-                    },
-                )
-            else:
-                transcript = await transcripts_controller.add(
-                    session,
-                    "",
-                    source_kind=SourceKind.ROOM,
-                    source_language="en",
-                    target_language="en",
-                    user_id=room.user_id,
-                    recording_id=recording.id,
-                    share_mode="public",
-                    meeting_id=meeting.id,
-                    room_id=room.id,
-                )
-            _, extension = os.path.splitext(object_key)
-            upload_filename = transcript.data_path / f"upload{extension}"
-            upload_filename.parent.mkdir(parents=True, exist_ok=True)
-            s3 = boto3.client(
-                "s3",
-                region_name=settings.TRANSCRIPT_STORAGE_AWS_REGION,
-                aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
-                aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
-            )
-            with open(upload_filename, "wb") as f:
-                s3.download_fileobj(bucket_name, object_key, f)
-            container = av.open(upload_filename.as_posix())
-            try:
-                if not len(container.streams.audio):
-                    raise Exception("File has no audio stream")
-            except Exception:
-                upload_filename.unlink()
-                raise
-            finally:
-                container.close()
-            await transcripts_controller.update(
-                session, transcript, {"status": "uploaded"}
-            )
+    room = await rooms_controller.get_by_id(session, meeting.room_id)
+    recording = await recordings_controller.get_by_object_key(
+        session, bucket_name, object_key
+    )
+    if not recording:
+        recording = await recordings_controller.create(
+            session,
+            Recording(
+                bucket_name=bucket_name,
+                object_key=object_key,
+                recorded_at=recorded_at,
+                meeting_id=meeting.id,
+            ),
+        )
+    transcript = await transcripts_controller.get_by_recording_id(session, recording.id)
+    if transcript:
+        await transcripts_controller.update(
+            session,
+            transcript,
+            {
+                "topics": [],
+            },
+        )
+    else:
+        transcript = await transcripts_controller.add(
+            session,
+            "",
+            source_kind=SourceKind.ROOM,
+            source_language="en",
+            target_language="en",
+            user_id=room.user_id,
+            recording_id=recording.id,
+            share_mode="public",
+            meeting_id=meeting.id,
+            room_id=room.id,
+        )
+    _, extension = os.path.splitext(object_key)
+    upload_filename = transcript.data_path / f"upload{extension}"
+    upload_filename.parent.mkdir(parents=True, exist_ok=True)
+    s3 = boto3.client(
+        "s3",
+        region_name=settings.TRANSCRIPT_STORAGE_AWS_REGION,
+        aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
+        aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
+    )
+    with open(upload_filename, "wb") as f:
+        s3.download_fileobj(bucket_name, object_key, f)
+    container = av.open(upload_filename.as_posix())
+    try:
+        if not len(container.streams.audio):
+            raise Exception("File has no audio stream")
+    except Exception:
+        upload_filename.unlink()
+        raise
+    finally:
+        container.close()
+    await transcripts_controller.update(session, transcript, {"status": "uploaded"})
     task_pipeline_file_process.delay(transcript_id=transcript.id)
 @shared_task
 @asynctask
-async def process_meetings():
+@with_session
+async def process_meetings(session: AsyncSession):
     """
     Checks which meetings are still active and deactivates those that have ended.
@@ -178,10 +178,7 @@ async def process_meetings():
     process the same meeting simultaneously.
     """
     logger.info("Processing meetings")
-    session_factory = get_session_factory()
-    async with session_factory() as session:
-        async with session.begin():
-            meetings = await meetings_controller.get_all_active(session)
+    meetings = await meetings_controller.get_all_active(session)
     current_time = datetime.now(timezone.utc)
     redis_client = get_redis_client()
     processed_count = 0
@@ -258,7 +255,8 @@ async def process_meetings():
 @shared_task
 @asynctask
-async def reprocess_failed_recordings():
+@with_session
+async def reprocess_failed_recordings(session: AsyncSession):
     """
     Find recordings in the S3 bucket and check if they have proper transcriptions.
     If not, requeue them for processing.
@@ -278,44 +276,42 @@ async def reprocess_failed_recordings():
         bucket_name = settings.RECORDING_STORAGE_AWS_BUCKET_NAME
         pages = paginator.paginate(Bucket=bucket_name)
-        session_factory = get_session_factory()
-        async with session_factory() as session:
-            for page in pages:
-                if "Contents" not in page:
-                    continue
-                for obj in page["Contents"]:
-                    object_key = obj["Key"]
-                    if not (object_key.endswith(".mp4")):
-                        continue
-                    recording = await recordings_controller.get_by_object_key(
-                        session, bucket_name, object_key
-                    )
-                    if not recording:
-                        logger.info(f"Queueing recording for processing: {object_key}")
-                        process_recording.delay(bucket_name, object_key)
-                        reprocessed_count += 1
-                        continue
-                    transcript = None
-                    try:
-                        transcript = await transcripts_controller.get_by_recording_id(
-                            session, recording.id
-                        )
-                    except ValidationError:
-                        await transcripts_controller.remove_by_recording_id(
-                            session, recording.id
-                        )
-                        logger.warning(
-                            f"Removed invalid transcript for recording: {recording.id}"
-                        )
-                    if transcript is None or transcript.status == "error":
-                        logger.info(f"Queueing recording for processing: {object_key}")
-                        process_recording.delay(bucket_name, object_key)
-                        reprocessed_count += 1
+        for page in pages:
+            if "Contents" not in page:
+                continue
+            for obj in page["Contents"]:
+                object_key = obj["Key"]
+                if not (object_key.endswith(".mp4")):
+                    continue
+                recording = await recordings_controller.get_by_object_key(
+                    session, bucket_name, object_key
+                )
+                if not recording:
+                    logger.info(f"Queueing recording for processing: {object_key}")
+                    process_recording.delay(bucket_name, object_key)
+                    reprocessed_count += 1
+                    continue
+                transcript = None
+                try:
+                    transcript = await transcripts_controller.get_by_recording_id(
+                        session, recording.id
+                    )
+                except ValidationError:
+                    await transcripts_controller.remove_by_recording_id(
+                        session, recording.id
+                    )
+                    logger.warning(
+                        f"Removed invalid transcript for recording: {recording.id}"
+                    )
+                if transcript is None or transcript.status == "error":
+                    logger.info(f"Queueing recording for processing: {object_key}")
+                    process_recording.delay(bucket_name, object_key)
+                    reprocessed_count += 1
     except Exception as e:
         logger.error(f"Error checking S3 bucket: {str(e)}")

File: session_decorator.py (new file)

@@ -0,0 +1,41 @@
"""
Session management decorator for async worker tasks.
This decorator ensures that all worker tasks have a properly managed database session
that stays open for the entire duration of the task execution.
"""
import functools
from typing import Any, Callable, TypeVar
from reflector.db import get_session_factory
F = TypeVar("F", bound=Callable[..., Any])
def with_session(func: F) -> F:
"""
Decorator that provides an AsyncSession as the first argument to the decorated function.
This should be used AFTER the @asynctask decorator on Celery tasks to ensure
proper session management throughout the task execution.
Example:
@shared_task
@asynctask
@with_session
async def my_task(session: AsyncSession, arg1: str, arg2: int):
# session is automatically provided and managed
result = await some_controller.get_by_id(session, arg1)
...
"""
@functools.wraps(func)
async def wrapper(*args, **kwargs):
session_factory = get_session_factory()
async with session_factory() as session:
async with session.begin():
# Pass session as first argument to the decorated function
return await func(session, *args, **kwargs)
return wrapper
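
A note on how the decorated tasks are called, inferred from the rest of this commit (for example process_recording.delay(bucket_name, object_key)): callers enqueue tasks with their business arguments only and never pass a session; the wrapper injects it. Because the wrapper uses session.begin(), the transaction commits when the task returns and rolls back if the task raises, which is standard SQLAlchemy AsyncSession behavior. A minimal sketch with a hypothetical task name:

    @shared_task
    @asynctask
    @with_session
    async def deactivate_stale_meetings(session: AsyncSession, max_age_hours: int):
        # work against the injected session; commit/rollback is handled by the wrapper
        ...

    # Caller side: no session argument is passed.
    deactivate_stale_meetings.delay(24)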