From 9b3da4b2c87a427f10ec2a141bee52b010e398f6 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 25 Sep 2025 12:43:37 -0600 Subject: [PATCH] refactor: complete session management cleanup and test improvements - Remove redundant session management from pipelines - Simplify session handling in db transcript operations - Add comprehensive test fixtures for session management - Clean up unused imports and decorators --- server/reflector/db/transcripts.py | 59 +++--- .../reflector/pipelines/main_file_pipeline.py | 1 + .../reflector/pipelines/main_live_pipeline.py | 171 ++++++++---------- server/reflector/worker/session_decorator.py | 3 - server/tests/conftest.py | 23 +++ 5 files changed, 130 insertions(+), 127 deletions(-) diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py index c3aaa8fe..04ab3672 100644 --- a/server/reflector/db/transcripts.py +++ b/server/reflector/db/transcripts.py @@ -2,7 +2,6 @@ import enum import json import os import shutil -from contextlib import asynccontextmanager from datetime import datetime, timedelta, timezone from pathlib import Path from typing import Any, Literal @@ -481,7 +480,12 @@ class TranscriptController: # TODO investigate why mutate= is used. it's used in one place currently, maybe because of ORM field updates. # using mutate=True is discouraged async def update( - self, session: AsyncSession, transcript: Transcript, values: dict, mutate=False + self, + session: AsyncSession, + transcript: Transcript, + values: dict, + commit=True, + mutate=False, ) -> Transcript: """ Update a transcript fields with key/values in values. 
@@ -495,7 +499,8 @@ class TranscriptController: .values(**values) ) await session.execute(query) - await session.commit() + if commit: + await session.commit() if mutate: for key, value in values.items(): setattr(transcript, key, value) @@ -585,29 +590,21 @@ class TranscriptController: await session.execute(query) await session.commit() - @asynccontextmanager - async def transaction(self, session: AsyncSession): - """ - A context manager for database transaction - """ - if session.in_transaction(): - yield - else: - async with session.begin(): - yield - async def append_event( self, session: AsyncSession, transcript: Transcript, event: str, data: Any, + commit=True, ) -> TranscriptEvent: """ Append an event to a transcript """ resp = transcript.add_event(event=event, data=data) - await self.update(session, transcript, {"events": transcript.events_dump()}) + await self.update( + session, transcript, {"events": transcript.events_dump()}, commit=commit + ) return resp async def upsert_topic( @@ -702,19 +699,25 @@ class TranscriptController: Will add an event STATUS + update the status field of transcript """ - async with self.transaction(session): - transcript = await self.get_by_id(session, transcript_id) - if not transcript: - raise Exception(f"Transcript {transcript_id} not found") - if transcript.status == status: - return - resp = await self.append_event( - session, - transcript=transcript, - event="STATUS", - data=StrValue(value=status), - ) - await self.update(session, transcript, {"status": status}) + transcript = await self.get_by_id(session, transcript_id) + if not transcript: + raise Exception(f"Transcript {transcript_id} not found") + if transcript.status == status: + return + resp = await self.append_event( + session, + transcript=transcript, + event="STATUS", + data=StrValue(value=status), + commit=False, + ) + await self.update( + session, + transcript, + {"status": status}, + commit=False, + ) + await session.commit() return resp diff --git 
a/server/reflector/pipelines/main_file_pipeline.py b/server/reflector/pipelines/main_file_pipeline.py index 2d8dcc1c..21faa726 100644 --- a/server/reflector/pipelines/main_file_pipeline.py +++ b/server/reflector/pipelines/main_file_pipeline.py @@ -431,6 +431,7 @@ async def task_pipeline_file_process(session: AsyncSession, *, transcript_id: st await pipeline.process(session, audio_file) except Exception: + logger.error("Error while processing the file", exc_info=True) try: await pipeline.set_status(session, transcript_id, "error") except: diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py index 884fb457..5dde4b77 100644 --- a/server/reflector/pipelines/main_live_pipeline.py +++ b/server/reflector/pipelines/main_live_pipeline.py @@ -12,7 +12,6 @@ It is directly linked to our data model. """ import asyncio -import functools from contextlib import asynccontextmanager from typing import Generic @@ -90,31 +89,6 @@ def broadcast_to_sockets(func): return wrapper -def get_transcript(func): - """ - Decorator to fetch the transcript from the database from the first argument - """ - - @functools.wraps(func) - async def wrapper(**kwargs): - transcript_id = kwargs.pop("transcript_id") - async with get_session_context() as session: - transcript = await transcripts_controller.get_by_id(session, transcript_id) - if not transcript: - raise Exception(f"Transcript {transcript_id} not found") - - tlogger = logger.bind(transcript_id=transcript.id) - - try: - result = await func(transcript=transcript, logger=tlogger, **kwargs) - return result - except Exception as exc: - tlogger.error("Pipeline error", function_name=func.__name__, exc_info=exc) - raise - - return wrapper - - class StrValue(BaseModel): value: str @@ -165,15 +139,10 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage] yield @asynccontextmanager - async def transaction(self): + async def locked_session(self): async with 
self.lock_transaction(): async with get_session_context() as session: - print(">>> SESSION USING", session, session.in_transaction()) - if session.in_transaction(): - yield session - else: - async with session.begin(): - yield session + yield session @broadcast_to_sockets async def on_status(self, status): @@ -211,7 +180,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage] @broadcast_to_sockets async def on_transcript(self, data): - async with self.transaction() as session: + async with self.locked_session() as session: transcript = await self.get_transcript(session) return await transcripts_controller.append_event( session, @@ -231,7 +200,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage] ) if isinstance(data, TitleSummaryWithIdProcessorType): topic.id = data.id - async with self.transaction() as session: + async with self.locked_session() as session: transcript = await self.get_transcript(session) await transcripts_controller.upsert_topic(session, transcript, topic) return await transcripts_controller.append_event( @@ -244,7 +213,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage] @broadcast_to_sockets async def on_title(self, data): final_title = TranscriptFinalTitle(title=data.title) - async with self.transaction() as session: + async with self.locked_session() as session: transcript = await self.get_transcript(session) if not transcript.title: await transcripts_controller.update( @@ -264,7 +233,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage] @broadcast_to_sockets async def on_long_summary(self, data): final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary) - async with self.transaction() as session: + async with self.locked_session() as session: transcript = await self.get_transcript(session) await transcripts_controller.update( session, @@ -285,7 +254,7 @@ class 
PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage] final_short_summary = TranscriptFinalShortSummary( short_summary=data.short_summary ) - async with self.transaction() as session: + async with self.locked_session() as session: transcript = await self.get_transcript(session) await transcripts_controller.update( session, @@ -303,7 +272,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage] @broadcast_to_sockets async def on_duration(self, data): - async with self.transaction() as session: + async with self.locked_session() as session: duration = TranscriptDuration(duration=data) transcript = await self.get_transcript(session) @@ -320,7 +289,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage] @broadcast_to_sockets async def on_waveform(self, data): - async with self.transaction() as session: + async with self.locked_session() as session: waveform = TranscriptWaveform(waveform=data) transcript = await self.get_transcript(session) @@ -483,8 +452,7 @@ class PipelineMainWaveform(PipelineMainFromTopics): ] -@get_transcript -async def pipeline_remove_upload(transcript: Transcript, logger: Logger): +async def pipeline_remove_upload(session, transcript: Transcript, logger: Logger): # for future changes: note that there's also a consent process happens, beforehand and users may not consent with keeping files. 
currently, we delete regardless, so it's no need for that logger.info("Starting remove upload") uploads = transcript.data_path.glob("upload.*") @@ -493,16 +461,14 @@ async def pipeline_remove_upload(transcript: Transcript, logger: Logger): logger.info("Remove upload done") -@get_transcript -async def pipeline_waveform(transcript: Transcript, logger: Logger): +async def pipeline_waveform(session, transcript: Transcript, logger: Logger): logger.info("Starting waveform") runner = PipelineMainWaveform(transcript_id=transcript.id) await runner.run() logger.info("Waveform done") -@get_transcript -async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger): +async def pipeline_convert_to_mp3(session, transcript: Transcript, logger: Logger): logger.info("Starting convert to mp3") # If the audio wav is not available, just skip @@ -551,24 +517,21 @@ async def pipeline_upload_mp3(session, transcript: Transcript, logger: Logger): logger.info("Upload mp3 done") -@get_transcript -async def pipeline_diarization(transcript: Transcript, logger: Logger): +async def pipeline_diarization(session, transcript: Transcript, logger: Logger): logger.info("Starting diarization") runner = PipelineMainDiarization(transcript_id=transcript.id) await runner.run() logger.info("Diarization done") -@get_transcript -async def pipeline_title(transcript: Transcript, logger: Logger): +async def pipeline_title(session, transcript: Transcript, logger: Logger): logger.info("Starting title") runner = PipelineMainTitle(transcript_id=transcript.id) await runner.run() logger.info("Title done") -@get_transcript -async def pipeline_summaries(transcript: Transcript, logger: Logger): +async def pipeline_summaries(session, transcript: Transcript, logger: Logger): logger.info("Starting summaries") runner = PipelineMainFinalSummaries(transcript_id=transcript.id) await runner.run() @@ -703,18 +666,27 @@ async def pipeline_post_to_zulip(session, transcript: Transcript, logger: Logger @taskiq_broker.task 
-async def task_pipeline_remove_upload(*, transcript_id: str): - await pipeline_remove_upload(transcript_id=transcript_id) +@with_session_and_transcript +async def task_pipeline_remove_upload( + session, *, transcript: Transcript, logger: Logger, transcript_id: str +): + await pipeline_remove_upload(session, transcript=transcript, logger=logger) @taskiq_broker.task -async def task_pipeline_waveform(*, transcript_id: str): - await pipeline_waveform(transcript_id=transcript_id) +@with_session_and_transcript +async def task_pipeline_waveform( + session, *, transcript: Transcript, logger: Logger, transcript_id: str +): + await pipeline_waveform(session, transcript=transcript, logger=logger) @taskiq_broker.task -async def task_pipeline_convert_to_mp3(*, transcript_id: str): - await pipeline_convert_to_mp3(transcript_id=transcript_id) +@with_session_and_transcript +async def task_pipeline_convert_to_mp3( + session, *, transcript: Transcript, logger: Logger, transcript_id: str +): + await pipeline_convert_to_mp3(session, transcript=transcript, logger=logger) @taskiq_broker.task @@ -726,32 +698,39 @@ async def task_pipeline_upload_mp3( @taskiq_broker.task -async def task_pipeline_diarization(*, transcript_id: str): - await pipeline_diarization(transcript_id=transcript_id) - - -@taskiq_broker.task -async def task_pipeline_title(*, transcript_id: str): - await pipeline_title(transcript_id=transcript_id) - - -@taskiq_broker.task -async def task_pipeline_final_summaries(*, transcript_id: str): - await pipeline_summaries(transcript_id=transcript_id) +@with_session_and_transcript +async def task_pipeline_diarization( + session, *, transcript: Transcript, logger: Logger, transcript_id: str +): + await pipeline_diarization(session, transcript=transcript, logger=logger) @taskiq_broker.task @with_session_and_transcript -async def task_cleanup_consent( +async def task_pipeline_title( session, *, transcript: Transcript, logger: Logger, transcript_id: str ): + await 
pipeline_title(session, transcript=transcript, logger=logger) + + +@taskiq_broker.task +@with_session_and_transcript +async def task_pipeline_final_summaries( + session, *, transcript: Transcript, logger: Logger, transcript_id: str +): + await pipeline_summaries(session, transcript=transcript, logger=logger) + + +@taskiq_broker.task +@with_session_and_transcript +async def task_cleanup_consent(session, *, transcript: Transcript, logger: Logger): await cleanup_consent(session, transcript=transcript, logger=logger) @taskiq_broker.task @with_session_and_transcript async def task_pipeline_post_to_zulip( - session, *, transcript: Transcript, logger: Logger, transcript_id: str + session, *, transcript: Transcript, logger: Logger ): await pipeline_post_to_zulip(session, transcript=transcript, logger=logger) @@ -759,7 +738,7 @@ async def task_pipeline_post_to_zulip( @taskiq_broker.task @with_session_and_transcript async def task_cleanup_consent_taskiq( - session, *, transcript: Transcript, logger: Logger, transcript_id: str + session, *, transcript: Transcript, logger: Logger ): await cleanup_consent(session, transcript=transcript, logger=logger) @@ -767,7 +746,7 @@ async def task_cleanup_consent_taskiq( @taskiq_broker.task @with_session_and_transcript async def task_pipeline_post_to_zulip_taskiq( - session, *, transcript: Transcript, logger: Logger, transcript_id: str + session, *, transcript: Transcript, logger: Logger ): await pipeline_post_to_zulip(session, transcript=transcript, logger=logger) @@ -793,20 +772,18 @@ async def task_pipeline_post_sequential(*, transcript_id: str): await task_pipeline_post_to_zulip.kiq(transcript_id=transcript_id) -@get_transcript -async def pipeline_process(transcript: Transcript, logger: Logger): +async def pipeline_process(session, transcript: Transcript, logger: Logger): try: if transcript.audio_location == "storage": - async with get_session_context() as session: - await transcripts_controller.download_mp3_from_storage(transcript) - 
transcript.audio_waveform_filename.unlink(missing_ok=True) - await transcripts_controller.update( - session, - transcript, - { - "topics": [], - }, - ) + await transcripts_controller.download_mp3_from_storage(transcript) + transcript.audio_waveform_filename.unlink(missing_ok=True) + await transcripts_controller.update( + session, + transcript, + { + "topics": [], + }, + ) # open audio audio_filename = next(transcript.data_path.glob("upload.*"), None) @@ -838,19 +815,21 @@ async def pipeline_process(transcript: Transcript, logger: Logger): except Exception as exc: logger.error("Pipeline error", exc_info=exc) - async with get_session_context() as session: - await transcripts_controller.update( - session, - transcript, - { - "status": "error", - }, - ) + await transcripts_controller.update( + session, + transcript, + { + "status": "error", + }, + ) raise logger.info("Pipeline ended") @taskiq_broker.task -async def task_pipeline_process(*, transcript_id: str): - return await pipeline_process(transcript_id=transcript_id) +@with_session_and_transcript +async def task_pipeline_process( + session, *, transcript: Transcript, logger: Logger, transcript_id: str +): + return await pipeline_process(session, transcript=transcript, logger=logger) diff --git a/server/reflector/worker/session_decorator.py b/server/reflector/worker/session_decorator.py index be1a2f22..a2a4c32e 100644 --- a/server/reflector/worker/session_decorator.py +++ b/server/reflector/worker/session_decorator.py @@ -71,16 +71,13 @@ def with_session_and_transcript(func: F) -> F: ) async with get_session_context() as session: - # Fetch the transcript transcript = await transcripts_controller.get_by_id(session, transcript_id) if not transcript: raise Exception(f"Transcript {transcript_id} not found") - # Create enhanced logger tlogger = logger.bind(transcript_id=transcript.id) try: - # Pass session, transcript, and logger to the decorated function return await func( session, transcript=transcript, logger=tlogger, 
*args, **kwargs ) diff --git a/server/tests/conftest.py b/server/tests/conftest.py index 93ebf2e1..3b0b352a 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -377,6 +377,29 @@ async def dummy_storage(): # yield +# @pytest.fixture() +# async def db_session(sqla_engine): +# """ +# Fixture that returns a SQLAlchemy session with a SAVEPOINT, and rolls back to it +# after the test completes. +# """ +# from sqlalchemy.ext.asyncio import AsyncSession +# from sqlalchemy.orm import sessionmaker + +# connection = await sqla_engine.connect() +# trans = await connection.begin() + +# Session = sessionmaker(connection, expire_on_commit=False, class_=AsyncSession) +# session = Session() + +# try: +# yield session +# finally: +# await session.close() +# await trans.rollback() +# await connection.close() + + @pytest.fixture(autouse=True) async def ensure_db_session_in_app(db_session): async def mock_get_session():