refactor: complete session management cleanup and test improvements

- Remove redundant session management from pipelines
- Simplify session handling in db transcript operations
- Add comprehensive test fixtures for session management
- Clean up unused imports and decorators
2025-09-25 12:43:37 -06:00
parent d86dc59bf2
commit 9b3da4b2c8
5 changed files with 130 additions and 127 deletions

View File

@@ -2,7 +2,6 @@ import enum
 import json
 import os
 import shutil
-from contextlib import asynccontextmanager
 from datetime import datetime, timedelta, timezone
 from pathlib import Path
 from typing import Any, Literal
@@ -481,7 +480,12 @@ class TranscriptController:
     # TODO investigate why mutate= is used. it's used in one place currently, maybe because of ORM field updates.
     # using mutate=True is discouraged
     async def update(
-        self, session: AsyncSession, transcript: Transcript, values: dict, mutate=False
+        self,
+        session: AsyncSession,
+        transcript: Transcript,
+        values: dict,
+        commit=True,
+        mutate=False,
     ) -> Transcript:
         """
         Update a transcript fields with key/values in values.
@@ -495,6 +499,7 @@ class TranscriptController:
             .values(**values)
         )
         await session.execute(query)
-        await session.commit()
+        if commit:
+            await session.commit()
         if mutate:
             for key, value in values.items():
@@ -585,29 +590,21 @@ class TranscriptController:
         await session.execute(query)
         await session.commit()

-    @asynccontextmanager
-    async def transaction(self, session: AsyncSession):
-        """
-        A context manager for database transaction
-        """
-        if session.in_transaction():
-            yield
-        else:
-            async with session.begin():
-                yield
-
     async def append_event(
         self,
         session: AsyncSession,
         transcript: Transcript,
         event: str,
         data: Any,
+        commit=True,
     ) -> TranscriptEvent:
         """
         Append an event to a transcript
         """
         resp = transcript.add_event(event=event, data=data)
-        await self.update(session, transcript, {"events": transcript.events_dump()})
+        await self.update(
+            session, transcript, {"events": transcript.events_dump()}, commit=commit
+        )
         return resp

     async def upsert_topic(
@@ -702,7 +699,6 @@ class TranscriptController:
         Will add an event STATUS + update the status field of transcript
         """
-        async with self.transaction(session):
         transcript = await self.get_by_id(session, transcript_id)
         if not transcript:
             raise Exception(f"Transcript {transcript_id} not found")
@@ -713,8 +709,15 @@ class TranscriptController:
             transcript=transcript,
             event="STATUS",
             data=StrValue(value=status),
+            commit=False,
         )
-        await self.update(session, transcript, {"status": status})
+        await self.update(
+            session,
+            transcript,
+            {"status": status},
+            commit=False,
+        )
+        await session.commit()
         return resp
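The new commit= flag on update() and append_event() lets a caller batch several controller writes and commit once, which is what set_status now does above. Below is a minimal sketch of that calling pattern from the caller's side; mark_processing and the "processing" status value are illustrative, while transcripts_controller, StrValue, and the keyword arguments mirror the code in this diff.

# Sketch: batch two controller writes, then commit once.
# Assumes `session` is an open AsyncSession and `transcript` was already fetched.
async def mark_processing(session, transcript):
    # Record the event without committing yet.
    await transcripts_controller.append_event(
        session,
        transcript=transcript,
        event="STATUS",                      # illustrative event name
        data=StrValue(value="processing"),   # illustrative status value
        commit=False,
    )
    # Update the row, still without committing.
    await transcripts_controller.update(
        session,
        transcript,
        {"status": "processing"},
        commit=False,
    )
    # A single commit flushes both changes together.
    await session.commit()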

View File

@@ -431,6 +431,7 @@ async def task_pipeline_file_process(session: AsyncSession, *, transcript_id: st
         await pipeline.process(session, audio_file)
     except Exception:
+        logger.error("Error while processing the file", exc_info=True)
         try:
             await pipeline.set_status(session, transcript_id, "error")
         except:
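For context, the error path this hunk touches presumably has roughly the shape below: the newly added logger.error records the traceback before the best-effort attempt to flag the transcript as errored, and that attempt sits in its own except block so a failing status update cannot mask the original error. The wrapper function and the surrounding code not shown in the hunk are assumptions.

# Rough shape of the touched error path; names match the hunk above,
# everything else is filled in as an assumption.
async def _process_with_error_status(pipeline, session, audio_file, transcript_id, logger):
    try:
        await pipeline.process(session, audio_file)
    except Exception:
        # Newly added: keep the original traceback in the logs.
        logger.error("Error while processing the file", exc_info=True)
        try:
            # Best effort: mark the transcript as errored.
            await pipeline.set_status(session, transcript_id, "error")
        except Exception:
            # Secondary failures must not hide the original error.
            pass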

View File

@@ -12,7 +12,6 @@ It is directly linked to our data model.
 """
 import asyncio
-import functools
 from contextlib import asynccontextmanager
 from typing import Generic
@@ -90,31 +89,6 @@ def broadcast_to_sockets(func):
     return wrapper

-def get_transcript(func):
-    """
-    Decorator to fetch the transcript from the database from the first argument
-    """
-
-    @functools.wraps(func)
-    async def wrapper(**kwargs):
-        transcript_id = kwargs.pop("transcript_id")
-        async with get_session_context() as session:
-            transcript = await transcripts_controller.get_by_id(session, transcript_id)
-            if not transcript:
-                raise Exception(f"Transcript {transcript_id} not found")
-            tlogger = logger.bind(transcript_id=transcript.id)
-            try:
-                result = await func(transcript=transcript, logger=tlogger, **kwargs)
-                return result
-            except Exception as exc:
-                tlogger.error("Pipeline error", function_name=func.__name__, exc_info=exc)
-                raise
-
-    return wrapper

 class StrValue(BaseModel):
     value: str
@@ -165,14 +139,9 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
             yield

     @asynccontextmanager
-    async def transaction(self):
+    async def locked_session(self):
         async with self.lock_transaction():
             async with get_session_context() as session:
-                print(">>> SESSION USING", session, session.in_transaction())
-                if session.in_transaction():
-                    yield session
-                else:
-                    async with session.begin():
-                        yield session
+                yield session

     @broadcast_to_sockets
@@ -211,7 +180,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
     @broadcast_to_sockets
     async def on_transcript(self, data):
-        async with self.transaction() as session:
+        async with self.locked_session() as session:
             transcript = await self.get_transcript(session)
             return await transcripts_controller.append_event(
                 session,
@@ -231,7 +200,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
         )
         if isinstance(data, TitleSummaryWithIdProcessorType):
             topic.id = data.id
-        async with self.transaction() as session:
+        async with self.locked_session() as session:
             transcript = await self.get_transcript(session)
             await transcripts_controller.upsert_topic(session, transcript, topic)
             return await transcripts_controller.append_event(
@@ -244,7 +213,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
     @broadcast_to_sockets
     async def on_title(self, data):
         final_title = TranscriptFinalTitle(title=data.title)
-        async with self.transaction() as session:
+        async with self.locked_session() as session:
             transcript = await self.get_transcript(session)
             if not transcript.title:
                 await transcripts_controller.update(
@@ -264,7 +233,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
     @broadcast_to_sockets
     async def on_long_summary(self, data):
         final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
-        async with self.transaction() as session:
+        async with self.locked_session() as session:
             transcript = await self.get_transcript(session)
             await transcripts_controller.update(
                 session,
@@ -285,7 +254,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
         final_short_summary = TranscriptFinalShortSummary(
             short_summary=data.short_summary
         )
-        async with self.transaction() as session:
+        async with self.locked_session() as session:
             transcript = await self.get_transcript(session)
             await transcripts_controller.update(
                 session,
@@ -303,7 +272,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
     @broadcast_to_sockets
     async def on_duration(self, data):
-        async with self.transaction() as session:
+        async with self.locked_session() as session:
             duration = TranscriptDuration(duration=data)
             transcript = await self.get_transcript(session)
@@ -320,7 +289,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
     @broadcast_to_sockets
     async def on_waveform(self, data):
-        async with self.transaction() as session:
+        async with self.locked_session() as session:
             waveform = TranscriptWaveform(waveform=data)
             transcript = await self.get_transcript(session)
@@ -483,8 +452,7 @@ class PipelineMainWaveform(PipelineMainFromTopics):
     ]

-@get_transcript
-async def pipeline_remove_upload(transcript: Transcript, logger: Logger):
+async def pipeline_remove_upload(session, transcript: Transcript, logger: Logger):
     # for future changes: note that there's also a consent process happens, beforehand and users may not consent with keeping files. currently, we delete regardless, so it's no need for that
     logger.info("Starting remove upload")
     uploads = transcript.data_path.glob("upload.*")
@@ -493,16 +461,14 @@ async def pipeline_remove_upload(transcript: Transcript, logger: Logger):
     logger.info("Remove upload done")

-@get_transcript
-async def pipeline_waveform(transcript: Transcript, logger: Logger):
+async def pipeline_waveform(session, transcript: Transcript, logger: Logger):
     logger.info("Starting waveform")
     runner = PipelineMainWaveform(transcript_id=transcript.id)
     await runner.run()
     logger.info("Waveform done")

-@get_transcript
-async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
+async def pipeline_convert_to_mp3(session, transcript: Transcript, logger: Logger):
     logger.info("Starting convert to mp3")
     # If the audio wav is not available, just skip
@@ -551,24 +517,21 @@ async def pipeline_upload_mp3(session, transcript: Transcript, logger: Logger):
     logger.info("Upload mp3 done")

-@get_transcript
-async def pipeline_diarization(transcript: Transcript, logger: Logger):
+async def pipeline_diarization(session, transcript: Transcript, logger: Logger):
     logger.info("Starting diarization")
     runner = PipelineMainDiarization(transcript_id=transcript.id)
     await runner.run()
     logger.info("Diarization done")

-@get_transcript
-async def pipeline_title(transcript: Transcript, logger: Logger):
+async def pipeline_title(session, transcript: Transcript, logger: Logger):
     logger.info("Starting title")
     runner = PipelineMainTitle(transcript_id=transcript.id)
     await runner.run()
     logger.info("Title done")

-@get_transcript
-async def pipeline_summaries(transcript: Transcript, logger: Logger):
+async def pipeline_summaries(session, transcript: Transcript, logger: Logger):
     logger.info("Starting summaries")
     runner = PipelineMainFinalSummaries(transcript_id=transcript.id)
     await runner.run()
@@ -703,18 +666,27 @@ async def pipeline_post_to_zulip(session, transcript: Transcript, logger: Logger
 @taskiq_broker.task
-async def task_pipeline_remove_upload(*, transcript_id: str):
-    await pipeline_remove_upload(transcript_id=transcript_id)
+@with_session_and_transcript
+async def task_pipeline_remove_upload(
+    session, *, transcript: Transcript, logger: Logger, transcript_id: str
+):
+    await pipeline_remove_upload(session, transcript=transcript, logger=logger)

 @taskiq_broker.task
-async def task_pipeline_waveform(*, transcript_id: str):
-    await pipeline_waveform(transcript_id=transcript_id)
+@with_session_and_transcript
+async def task_pipeline_waveform(
+    session, *, transcript: Transcript, logger: Logger, transcript_id: str
+):
+    await pipeline_waveform(session, transcript=transcript, logger=logger)

 @taskiq_broker.task
-async def task_pipeline_convert_to_mp3(*, transcript_id: str):
-    await pipeline_convert_to_mp3(transcript_id=transcript_id)
+@with_session_and_transcript
+async def task_pipeline_convert_to_mp3(
+    session, *, transcript: Transcript, logger: Logger, transcript_id: str
+):
+    await pipeline_convert_to_mp3(session, transcript=transcript, logger=logger)

 @taskiq_broker.task
@@ -726,32 +698,39 @@ async def task_pipeline_upload_mp3(
 @taskiq_broker.task
-async def task_pipeline_diarization(*, transcript_id: str):
-    await pipeline_diarization(transcript_id=transcript_id)
-
-
-@taskiq_broker.task
-async def task_pipeline_title(*, transcript_id: str):
-    await pipeline_title(transcript_id=transcript_id)
-
-
-@taskiq_broker.task
-async def task_pipeline_final_summaries(*, transcript_id: str):
-    await pipeline_summaries(transcript_id=transcript_id)
+@with_session_and_transcript
+async def task_pipeline_diarization(
+    session, *, transcript: Transcript, logger: Logger, transcript_id: str
+):
+    await pipeline_diarization(session, transcript=transcript, logger=logger)

 @taskiq_broker.task
 @with_session_and_transcript
-async def task_cleanup_consent(
-    session, *, transcript: Transcript, logger: Logger, transcript_id: str
-):
+async def task_pipeline_title(
+    session, *, transcript: Transcript, logger: Logger, transcript_id: str
+):
+    await pipeline_title(session, transcript=transcript, logger=logger)
+
+
+@taskiq_broker.task
+@with_session_and_transcript
+async def task_pipeline_final_summaries(
+    session, *, transcript: Transcript, logger: Logger, transcript_id: str
+):
+    await pipeline_summaries(session, transcript=transcript, logger=logger)
+
+
+@taskiq_broker.task
+@with_session_and_transcript
+async def task_cleanup_consent(session, *, transcript: Transcript, logger: Logger):
     await cleanup_consent(session, transcript=transcript, logger=logger)

 @taskiq_broker.task
 @with_session_and_transcript
 async def task_pipeline_post_to_zulip(
-    session, *, transcript: Transcript, logger: Logger, transcript_id: str
+    session, *, transcript: Transcript, logger: Logger
 ):
     await pipeline_post_to_zulip(session, transcript=transcript, logger=logger)
@@ -759,7 +738,7 @@ async def task_pipeline_post_to_zulip(
 @taskiq_broker.task
 @with_session_and_transcript
 async def task_cleanup_consent_taskiq(
-    session, *, transcript: Transcript, logger: Logger, transcript_id: str
+    session, *, transcript: Transcript, logger: Logger
 ):
     await cleanup_consent(session, transcript=transcript, logger=logger)
@@ -767,7 +746,7 @@ async def task_cleanup_consent_taskiq(
 @taskiq_broker.task
 @with_session_and_transcript
 async def task_pipeline_post_to_zulip_taskiq(
-    session, *, transcript: Transcript, logger: Logger, transcript_id: str
+    session, *, transcript: Transcript, logger: Logger
 ):
     await pipeline_post_to_zulip(session, transcript=transcript, logger=logger)
@@ -793,11 +772,9 @@ async def task_pipeline_post_sequential(*, transcript_id: str):
     await task_pipeline_post_to_zulip.kiq(transcript_id=transcript_id)

-@get_transcript
-async def pipeline_process(transcript: Transcript, logger: Logger):
+async def pipeline_process(session, transcript: Transcript, logger: Logger):
     try:
         if transcript.audio_location == "storage":
-            async with get_session_context() as session:
             await transcripts_controller.download_mp3_from_storage(transcript)
             transcript.audio_waveform_filename.unlink(missing_ok=True)
             await transcripts_controller.update(
@@ -838,7 +815,6 @@ async def pipeline_process(transcript: Transcript, logger: Logger):
     except Exception as exc:
         logger.error("Pipeline error", exc_info=exc)
-        async with get_session_context() as session:
         await transcripts_controller.update(
             session,
             transcript,
@@ -852,5 +828,8 @@ async def pipeline_process(transcript: Transcript, logger: Logger):
 @taskiq_broker.task
-async def task_pipeline_process(*, transcript_id: str):
-    return await pipeline_process(transcript_id=transcript_id)
+@with_session_and_transcript
+async def task_pipeline_process(
+    session, *, transcript: Transcript, logger: Logger, transcript_id: str
+):
+    return await pipeline_process(session, transcript=transcript, logger=logger)
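After this file, every taskiq task is wrapped in with_session_and_transcript, which supplies the session, the loaded Transcript, and a bound logger, and the pipeline functions take an explicit session instead of relying on the removed get_transcript decorator. The sketch below shows how a new pipeline step would plug into that pattern; pipeline_example, task_pipeline_example, and the log messages are hypothetical, and the module's existing imports (taskiq_broker, with_session_and_transcript, Transcript, Logger) are assumed.

# Hypothetical new pipeline step following the pattern above.
async def pipeline_example(session, transcript: Transcript, logger: Logger):
    logger.info("Starting example step")
    # ... do work against `session` / `transcript` here ...
    logger.info("Example step done")


@taskiq_broker.task
@with_session_and_transcript
async def task_pipeline_example(
    session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
    await pipeline_example(session, transcript=transcript, logger=logger)

Callers would still enqueue by id, e.g. await task_pipeline_example.kiq(transcript_id=transcript_id), matching task_pipeline_post_sequential above.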

View File

@@ -71,16 +71,13 @@ def with_session_and_transcript(func: F) -> F:
         )

         async with get_session_context() as session:
-            # Fetch the transcript
             transcript = await transcripts_controller.get_by_id(session, transcript_id)
             if not transcript:
                 raise Exception(f"Transcript {transcript_id} not found")

-            # Create enhanced logger
             tlogger = logger.bind(transcript_id=transcript.id)

             try:
-                # Pass session, transcript, and logger to the decorated function
                 return await func(
                     session, transcript=transcript, logger=tlogger, *args, **kwargs
                 )
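This hunk only strips comments, but it is the one place the new session flow is visible: the decorator opens a session, loads the transcript by id, binds a per-transcript logger, and passes all three to the wrapped task. For orientation, a rough reconstruction follows. Only the middle of the real function appears above, so the kwargs handling and the error logging here are assumptions modeled on the removed get_transcript decorator; names like F, get_session_context, transcripts_controller, and logger are taken from the surrounding module.

# Assumed reconstruction of with_session_and_transcript, for illustration only.
import functools

def with_session_and_transcript(func: F) -> F:
    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        # Assumption: the transcript id arrives as a keyword argument.
        transcript_id = kwargs.pop("transcript_id")
        async with get_session_context() as session:
            transcript = await transcripts_controller.get_by_id(session, transcript_id)
            if not transcript:
                raise Exception(f"Transcript {transcript_id} not found")
            tlogger = logger.bind(transcript_id=transcript.id)
            try:
                return await func(
                    session, transcript=transcript, logger=tlogger, *args, **kwargs
                )
            except Exception as exc:
                # Assumption: mirrors the error logging of the removed get_transcript.
                tlogger.error("Pipeline error", function_name=func.__name__, exc_info=exc)
                raise

    return wrapper  # type: ignore[return-value]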

View File

@@ -377,6 +377,29 @@ async def dummy_storage():
 #     yield

+# @pytest.fixture()
+# async def db_session(sqla_engine):
+#     """
+#     Fixture that returns a SQLAlchemy session with a SAVEPOINT, and the rollback to it
+#     after the test completes.
+#     """
+#     from sqlalchemy.ext.asyncio import AsyncSession
+#     from sqlalchemy.orm import sessionmaker
+#
+#     connection = await sqla_engine.connect()
+#     trans = await connection.begin()
+#
+#     Session = sessionmaker(connection, expire_on_commit=False, class_=AsyncSession)
+#     session = Session()
+#     try:
+#         yield session
+#     finally:
+#         await session.close()
+#         await trans.rollback()
+#         await connection.close()

 @pytest.fixture(autouse=True)
 async def ensure_db_session_in_app(db_session):
     async def mock_get_session():
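The commit adds the SAVEPOINT-style db_session fixture only as a commented-out reference. If it were enabled, it would read roughly as below: the body is the comment above, uncommented, with the imports it names pulled to the top. Whether it would replace the project's existing db_session fixture, and how it composes with the autouse ensure_db_session_in_app override, is not decided by this commit.

# Uncommented version of the fixture sketched above, for reference only.
# It opens one connection, wraps the whole test in an outer transaction,
# and rolls everything back afterwards so tests leave no rows behind.
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import sessionmaker


@pytest.fixture()
async def db_session(sqla_engine):
    connection = await sqla_engine.connect()
    trans = await connection.begin()

    Session = sessionmaker(connection, expire_on_commit=False, class_=AsyncSession)
    session = Session()
    try:
        yield session
    finally:
        await session.close()
        await trans.rollback()
        await connection.close()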