mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
feat: add @with_session decorator for worker task session management
- Create session_decorator.py with @with_session decorator - Decorator automatically manages database sessions for worker tasks - Ensures session stays open for entire task execution - Fixes issue where sessions were closed before being used (e.g., process_meetings) Applied decorator to all worker tasks: - process.py: process_recording, process_meetings, reprocess_failed_recordings - cleanup.py: cleanup_old_public_data_task - ics_sync.py: sync_room_ics, sync_all_ics_calendars, create_upcoming_meetings Benefits: - Consistent session management across all worker tasks - No more manual session_factory context management in tasks - Proper transaction boundaries with automatic begin/commit - Cleaner, more maintainable code - Fixes session lifecycle issues in process_meetings
This commit is contained in:
@@ -15,11 +15,11 @@ from sqlalchemy import delete, select
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from reflector.asynctask import asynctask
|
from reflector.asynctask import asynctask
|
||||||
from reflector.db import get_session_factory
|
|
||||||
from reflector.db.base import MeetingModel, RecordingModel, TranscriptModel
|
from reflector.db.base import MeetingModel, RecordingModel, TranscriptModel
|
||||||
from reflector.db.transcripts import transcripts_controller
|
from reflector.db.transcripts import transcripts_controller
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
from reflector.storage import get_recordings_storage
|
from reflector.storage import get_recordings_storage
|
||||||
|
from reflector.worker.session_decorator import with_session
|
||||||
|
|
||||||
logger = structlog.get_logger(__name__)
|
logger = structlog.get_logger(__name__)
|
||||||
|
|
||||||
@@ -161,8 +161,6 @@ async def cleanup_old_public_data(
|
|||||||
retry_kwargs={"max_retries": 3, "countdown": 300},
|
retry_kwargs={"max_retries": 3, "countdown": 300},
|
||||||
)
|
)
|
||||||
@asynctask
|
@asynctask
|
||||||
async def cleanup_old_public_data_task(days: int | None = None):
|
@with_session
|
||||||
session_factory = get_session_factory()
|
async def cleanup_old_public_data_task(session: AsyncSession, days: int | None = None):
|
||||||
async with session_factory() as session:
|
await cleanup_old_public_data(session, days=days)
|
||||||
async with session.begin():
|
|
||||||
await cleanup_old_public_data(session, days=days)
|
|
||||||
|
|||||||
@@ -6,24 +6,23 @@ from celery.utils.log import get_task_logger
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from reflector.asynctask import asynctask
|
from reflector.asynctask import asynctask
|
||||||
from reflector.db import get_session_factory
|
|
||||||
from reflector.db.calendar_events import calendar_events_controller
|
from reflector.db.calendar_events import calendar_events_controller
|
||||||
from reflector.db.meetings import meetings_controller
|
from reflector.db.meetings import meetings_controller
|
||||||
from reflector.db.rooms import rooms_controller
|
from reflector.db.rooms import rooms_controller
|
||||||
from reflector.redis_cache import RedisAsyncLock
|
from reflector.redis_cache import RedisAsyncLock
|
||||||
from reflector.services.ics_sync import SyncStatus, ics_sync_service
|
from reflector.services.ics_sync import SyncStatus, ics_sync_service
|
||||||
from reflector.whereby import create_meeting, upload_logo
|
from reflector.whereby import create_meeting, upload_logo
|
||||||
|
from reflector.worker.session_decorator import with_session
|
||||||
|
|
||||||
logger = structlog.wrap_logger(get_task_logger(__name__))
|
logger = structlog.wrap_logger(get_task_logger(__name__))
|
||||||
|
|
||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
@asynctask
|
@asynctask
|
||||||
async def sync_room_ics(room_id: str):
|
@with_session
|
||||||
|
async def sync_room_ics(session: AsyncSession, room_id: str):
|
||||||
try:
|
try:
|
||||||
session_factory = get_session_factory()
|
room = await rooms_controller.get_by_id(session, room_id)
|
||||||
async with session_factory() as session:
|
|
||||||
room = await rooms_controller.get_by_id(session, room_id)
|
|
||||||
if not room:
|
if not room:
|
||||||
logger.warning("Room not found for ICS sync", room_id=room_id)
|
logger.warning("Room not found for ICS sync", room_id=room_id)
|
||||||
return
|
return
|
||||||
@@ -59,13 +58,12 @@ async def sync_room_ics(room_id: str):
|
|||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
@asynctask
|
@asynctask
|
||||||
async def sync_all_ics_calendars():
|
@with_session
|
||||||
|
async def sync_all_ics_calendars(session: AsyncSession):
|
||||||
try:
|
try:
|
||||||
logger.info("Starting sync for all ICS-enabled rooms")
|
logger.info("Starting sync for all ICS-enabled rooms")
|
||||||
|
|
||||||
session_factory = get_session_factory()
|
ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
|
||||||
async with session_factory() as session:
|
|
||||||
ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
|
|
||||||
logger.info(f"Found {len(ics_enabled_rooms)} rooms with ICS enabled")
|
logger.info(f"Found {len(ics_enabled_rooms)} rooms with ICS enabled")
|
||||||
|
|
||||||
for room in ics_enabled_rooms:
|
for room in ics_enabled_rooms:
|
||||||
@@ -155,7 +153,8 @@ async def create_upcoming_meetings_for_event(
|
|||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
@asynctask
|
@asynctask
|
||||||
async def create_upcoming_meetings():
|
@with_session
|
||||||
|
async def create_upcoming_meetings(session: AsyncSession):
|
||||||
async with RedisAsyncLock("create_upcoming_meetings", skip_if_locked=True) as lock:
|
async with RedisAsyncLock("create_upcoming_meetings", skip_if_locked=True) as lock:
|
||||||
if not lock.acquired:
|
if not lock.acquired:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -166,24 +165,21 @@ async def create_upcoming_meetings():
|
|||||||
try:
|
try:
|
||||||
logger.info("Starting creation of upcoming meetings")
|
logger.info("Starting creation of upcoming meetings")
|
||||||
|
|
||||||
session_factory = get_session_factory()
|
ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
|
||||||
async with session_factory() as session:
|
now = datetime.now(timezone.utc)
|
||||||
async with session.begin():
|
create_window = now - timedelta(minutes=6)
|
||||||
ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
|
|
||||||
now = datetime.now(timezone.utc)
|
|
||||||
create_window = now - timedelta(minutes=6)
|
|
||||||
|
|
||||||
for room in ics_enabled_rooms:
|
for room in ics_enabled_rooms:
|
||||||
events = await calendar_events_controller.get_upcoming(
|
events = await calendar_events_controller.get_upcoming(
|
||||||
session,
|
session,
|
||||||
room.id,
|
room.id,
|
||||||
minutes_ahead=7,
|
minutes_ahead=7,
|
||||||
)
|
)
|
||||||
|
|
||||||
for event in events:
|
for event in events:
|
||||||
await create_upcoming_meetings_for_event(
|
await create_upcoming_meetings_for_event(
|
||||||
session, event, create_window, room.id, room
|
session, event, create_window, room.id, room
|
||||||
)
|
)
|
||||||
logger.info("Completed pre-creation check for upcoming meetings")
|
logger.info("Completed pre-creation check for upcoming meetings")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ from celery import shared_task
|
|||||||
from celery.utils.log import get_task_logger
|
from celery.utils.log import get_task_logger
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from redis.exceptions import LockError
|
from redis.exceptions import LockError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from reflector.db import get_session_factory
|
|
||||||
from reflector.db.meetings import meetings_controller
|
from reflector.db.meetings import meetings_controller
|
||||||
from reflector.db.recordings import Recording, recordings_controller
|
from reflector.db.recordings import Recording, recordings_controller
|
||||||
from reflector.db.rooms import rooms_controller
|
from reflector.db.rooms import rooms_controller
|
||||||
@@ -21,6 +21,7 @@ from reflector.pipelines.main_live_pipeline import asynctask
|
|||||||
from reflector.redis_cache import get_redis_client
|
from reflector.redis_cache import get_redis_client
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
from reflector.whereby import get_room_sessions
|
from reflector.whereby import get_room_sessions
|
||||||
|
from reflector.worker.session_decorator import with_session
|
||||||
|
|
||||||
logger = structlog.wrap_logger(get_task_logger(__name__))
|
logger = structlog.wrap_logger(get_task_logger(__name__))
|
||||||
|
|
||||||
@@ -76,92 +77,91 @@ def process_messages():
|
|||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
@asynctask
|
@asynctask
|
||||||
async def process_recording(bucket_name: str, object_key: str):
|
@with_session
|
||||||
|
async def process_recording(session: AsyncSession, bucket_name: str, object_key: str):
|
||||||
logger.info("Processing recording: %s/%s", bucket_name, object_key)
|
logger.info("Processing recording: %s/%s", bucket_name, object_key)
|
||||||
|
|
||||||
# extract a guid and a datetime from the object key
|
# extract a guid and a datetime from the object key
|
||||||
room_name = f"/{object_key[:36]}"
|
room_name = f"/{object_key[:36]}"
|
||||||
recorded_at = parse_datetime_with_timezone(object_key[37:57])
|
recorded_at = parse_datetime_with_timezone(object_key[37:57])
|
||||||
|
|
||||||
session_factory = get_session_factory()
|
meeting = await meetings_controller.get_by_room_name(session, room_name)
|
||||||
async with session_factory() as session:
|
if not meeting:
|
||||||
async with session.begin():
|
logger.warning("Room not found, may be deleted ?", room_name=room_name)
|
||||||
meeting = await meetings_controller.get_by_room_name(session, room_name)
|
return
|
||||||
room = await rooms_controller.get_by_id(session, meeting.room_id)
|
|
||||||
|
|
||||||
recording = await recordings_controller.get_by_object_key(
|
room = await rooms_controller.get_by_id(session, meeting.room_id)
|
||||||
session, bucket_name, object_key
|
|
||||||
)
|
|
||||||
if not recording:
|
|
||||||
recording = await recordings_controller.create(
|
|
||||||
session,
|
|
||||||
Recording(
|
|
||||||
bucket_name=bucket_name,
|
|
||||||
object_key=object_key,
|
|
||||||
recorded_at=recorded_at,
|
|
||||||
meeting_id=meeting.id,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
transcript = await transcripts_controller.get_by_recording_id(
|
recording = await recordings_controller.get_by_object_key(
|
||||||
session, recording.id
|
session, bucket_name, object_key
|
||||||
)
|
)
|
||||||
if transcript:
|
if not recording:
|
||||||
await transcripts_controller.update(
|
recording = await recordings_controller.create(
|
||||||
session,
|
session,
|
||||||
transcript,
|
Recording(
|
||||||
{
|
bucket_name=bucket_name,
|
||||||
"topics": [],
|
object_key=object_key,
|
||||||
},
|
recorded_at=recorded_at,
|
||||||
)
|
meeting_id=meeting.id,
|
||||||
else:
|
),
|
||||||
transcript = await transcripts_controller.add(
|
)
|
||||||
session,
|
|
||||||
"",
|
|
||||||
source_kind=SourceKind.ROOM,
|
|
||||||
source_language="en",
|
|
||||||
target_language="en",
|
|
||||||
user_id=room.user_id,
|
|
||||||
recording_id=recording.id,
|
|
||||||
share_mode="public",
|
|
||||||
meeting_id=meeting.id,
|
|
||||||
room_id=room.id,
|
|
||||||
)
|
|
||||||
|
|
||||||
_, extension = os.path.splitext(object_key)
|
transcript = await transcripts_controller.get_by_recording_id(session, recording.id)
|
||||||
upload_filename = transcript.data_path / f"upload{extension}"
|
if transcript:
|
||||||
upload_filename.parent.mkdir(parents=True, exist_ok=True)
|
await transcripts_controller.update(
|
||||||
|
session,
|
||||||
|
transcript,
|
||||||
|
{
|
||||||
|
"topics": [],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
transcript = await transcripts_controller.add(
|
||||||
|
session,
|
||||||
|
"",
|
||||||
|
source_kind=SourceKind.ROOM,
|
||||||
|
source_language="en",
|
||||||
|
target_language="en",
|
||||||
|
user_id=room.user_id,
|
||||||
|
recording_id=recording.id,
|
||||||
|
share_mode="public",
|
||||||
|
meeting_id=meeting.id,
|
||||||
|
room_id=room.id,
|
||||||
|
)
|
||||||
|
|
||||||
s3 = boto3.client(
|
_, extension = os.path.splitext(object_key)
|
||||||
"s3",
|
upload_filename = transcript.data_path / f"upload{extension}"
|
||||||
region_name=settings.TRANSCRIPT_STORAGE_AWS_REGION,
|
upload_filename.parent.mkdir(parents=True, exist_ok=True)
|
||||||
aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
|
|
||||||
aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
|
|
||||||
)
|
|
||||||
|
|
||||||
with open(upload_filename, "wb") as f:
|
s3 = boto3.client(
|
||||||
s3.download_fileobj(bucket_name, object_key, f)
|
"s3",
|
||||||
|
region_name=settings.TRANSCRIPT_STORAGE_AWS_REGION,
|
||||||
|
aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
|
||||||
|
aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
|
||||||
|
)
|
||||||
|
|
||||||
container = av.open(upload_filename.as_posix())
|
with open(upload_filename, "wb") as f:
|
||||||
try:
|
s3.download_fileobj(bucket_name, object_key, f)
|
||||||
if not len(container.streams.audio):
|
|
||||||
raise Exception("File has no audio stream")
|
|
||||||
except Exception:
|
|
||||||
upload_filename.unlink()
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
container.close()
|
|
||||||
|
|
||||||
await transcripts_controller.update(
|
container = av.open(upload_filename.as_posix())
|
||||||
session, transcript, {"status": "uploaded"}
|
try:
|
||||||
)
|
if not len(container.streams.audio):
|
||||||
|
raise Exception("File has no audio stream")
|
||||||
|
except Exception:
|
||||||
|
upload_filename.unlink()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
container.close()
|
||||||
|
|
||||||
|
await transcripts_controller.update(session, transcript, {"status": "uploaded"})
|
||||||
|
|
||||||
task_pipeline_file_process.delay(transcript_id=transcript.id)
|
task_pipeline_file_process.delay(transcript_id=transcript.id)
|
||||||
|
|
||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
@asynctask
|
@asynctask
|
||||||
async def process_meetings():
|
@with_session
|
||||||
|
async def process_meetings(session: AsyncSession):
|
||||||
"""
|
"""
|
||||||
Checks which meetings are still active and deactivates those that have ended.
|
Checks which meetings are still active and deactivates those that have ended.
|
||||||
|
|
||||||
@@ -178,10 +178,7 @@ async def process_meetings():
|
|||||||
process the same meeting simultaneously.
|
process the same meeting simultaneously.
|
||||||
"""
|
"""
|
||||||
logger.info("Processing meetings")
|
logger.info("Processing meetings")
|
||||||
session_factory = get_session_factory()
|
meetings = await meetings_controller.get_all_active(session)
|
||||||
async with session_factory() as session:
|
|
||||||
async with session.begin():
|
|
||||||
meetings = await meetings_controller.get_all_active(session)
|
|
||||||
current_time = datetime.now(timezone.utc)
|
current_time = datetime.now(timezone.utc)
|
||||||
redis_client = get_redis_client()
|
redis_client = get_redis_client()
|
||||||
processed_count = 0
|
processed_count = 0
|
||||||
@@ -258,7 +255,8 @@ async def process_meetings():
|
|||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
@asynctask
|
@asynctask
|
||||||
async def reprocess_failed_recordings():
|
@with_session
|
||||||
|
async def reprocess_failed_recordings(session: AsyncSession):
|
||||||
"""
|
"""
|
||||||
Find recordings in the S3 bucket and check if they have proper transcriptions.
|
Find recordings in the S3 bucket and check if they have proper transcriptions.
|
||||||
If not, requeue them for processing.
|
If not, requeue them for processing.
|
||||||
@@ -278,44 +276,42 @@ async def reprocess_failed_recordings():
|
|||||||
bucket_name = settings.RECORDING_STORAGE_AWS_BUCKET_NAME
|
bucket_name = settings.RECORDING_STORAGE_AWS_BUCKET_NAME
|
||||||
pages = paginator.paginate(Bucket=bucket_name)
|
pages = paginator.paginate(Bucket=bucket_name)
|
||||||
|
|
||||||
session_factory = get_session_factory()
|
for page in pages:
|
||||||
async with session_factory() as session:
|
if "Contents" not in page:
|
||||||
for page in pages:
|
continue
|
||||||
if "Contents" not in page:
|
|
||||||
|
for obj in page["Contents"]:
|
||||||
|
object_key = obj["Key"]
|
||||||
|
|
||||||
|
if not (object_key.endswith(".mp4")):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for obj in page["Contents"]:
|
recording = await recordings_controller.get_by_object_key(
|
||||||
object_key = obj["Key"]
|
session, bucket_name, object_key
|
||||||
|
)
|
||||||
|
if not recording:
|
||||||
|
logger.info(f"Queueing recording for processing: {object_key}")
|
||||||
|
process_recording.delay(bucket_name, object_key)
|
||||||
|
reprocessed_count += 1
|
||||||
|
continue
|
||||||
|
|
||||||
if not (object_key.endswith(".mp4")):
|
transcript = None
|
||||||
continue
|
try:
|
||||||
|
transcript = await transcripts_controller.get_by_recording_id(
|
||||||
recording = await recordings_controller.get_by_object_key(
|
session, recording.id
|
||||||
session, bucket_name, object_key
|
)
|
||||||
|
except ValidationError:
|
||||||
|
await transcripts_controller.remove_by_recording_id(
|
||||||
|
session, recording.id
|
||||||
|
)
|
||||||
|
logger.warning(
|
||||||
|
f"Removed invalid transcript for recording: {recording.id}"
|
||||||
)
|
)
|
||||||
if not recording:
|
|
||||||
logger.info(f"Queueing recording for processing: {object_key}")
|
|
||||||
process_recording.delay(bucket_name, object_key)
|
|
||||||
reprocessed_count += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
transcript = None
|
if transcript is None or transcript.status == "error":
|
||||||
try:
|
logger.info(f"Queueing recording for processing: {object_key}")
|
||||||
transcript = await transcripts_controller.get_by_recording_id(
|
process_recording.delay(bucket_name, object_key)
|
||||||
session, recording.id
|
reprocessed_count += 1
|
||||||
)
|
|
||||||
except ValidationError:
|
|
||||||
await transcripts_controller.remove_by_recording_id(
|
|
||||||
session, recording.id
|
|
||||||
)
|
|
||||||
logger.warning(
|
|
||||||
f"Removed invalid transcript for recording: {recording.id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if transcript is None or transcript.status == "error":
|
|
||||||
logger.info(f"Queueing recording for processing: {object_key}")
|
|
||||||
process_recording.delay(bucket_name, object_key)
|
|
||||||
reprocessed_count += 1
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error checking S3 bucket: {str(e)}")
|
logger.error(f"Error checking S3 bucket: {str(e)}")
|
||||||
|
|||||||
41
server/reflector/worker/session_decorator.py
Normal file
41
server/reflector/worker/session_decorator.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
"""
|
||||||
|
Session management decorator for async worker tasks.
|
||||||
|
|
||||||
|
This decorator ensures that all worker tasks have a properly managed database session
|
||||||
|
that stays open for the entire duration of the task execution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import functools
|
||||||
|
from typing import Any, Callable, TypeVar
|
||||||
|
|
||||||
|
from reflector.db import get_session_factory
|
||||||
|
|
||||||
|
F = TypeVar("F", bound=Callable[..., Any])
|
||||||
|
|
||||||
|
|
||||||
|
def with_session(func: F) -> F:
|
||||||
|
"""
|
||||||
|
Decorator that provides an AsyncSession as the first argument to the decorated function.
|
||||||
|
|
||||||
|
This should be used AFTER the @asynctask decorator on Celery tasks to ensure
|
||||||
|
proper session management throughout the task execution.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
@shared_task
|
||||||
|
@asynctask
|
||||||
|
@with_session
|
||||||
|
async def my_task(session: AsyncSession, arg1: str, arg2: int):
|
||||||
|
# session is automatically provided and managed
|
||||||
|
result = await some_controller.get_by_id(session, arg1)
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
|
||||||
|
@functools.wraps(func)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
session_factory = get_session_factory()
|
||||||
|
async with session_factory() as session:
|
||||||
|
async with session.begin():
|
||||||
|
# Pass session as first argument to the decorated function
|
||||||
|
return await func(session, *args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
Reference in New Issue
Block a user