refactor: complete session management cleanup and test improvements

- Remove redundant session management from pipelines
- Simplify session handling in db transcript operations
- Add comprehensive test fixtures for session management
- Clean up unused imports and decorators
This commit is contained in:
2025-09-25 12:43:37 -06:00
parent d86dc59bf2
commit 9b3da4b2c8
5 changed files with 130 additions and 127 deletions

View File

@@ -2,7 +2,6 @@ import enum
import json
import os
import shutil
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Literal
@@ -481,7 +480,12 @@ class TranscriptController:
# TODO investigate why mutate= is used. it's used in one place currently, maybe because of ORM field updates.
# using mutate=True is discouraged
async def update(
self, session: AsyncSession, transcript: Transcript, values: dict, mutate=False
self,
session: AsyncSession,
transcript: Transcript,
values: dict,
commit=True,
mutate=False,
) -> Transcript:
"""
Update a transcript fields with key/values in values.
@@ -495,7 +499,8 @@ class TranscriptController:
.values(**values)
)
await session.execute(query)
await session.commit()
if commit:
await session.commit()
if mutate:
for key, value in values.items():
setattr(transcript, key, value)
@@ -585,29 +590,21 @@ class TranscriptController:
await session.execute(query)
await session.commit()
@asynccontextmanager
async def transaction(self, session: AsyncSession):
"""
A context manager for database transaction
"""
if session.in_transaction():
yield
else:
async with session.begin():
yield
async def append_event(
self,
session: AsyncSession,
transcript: Transcript,
event: str,
data: Any,
commit=True,
) -> TranscriptEvent:
"""
Append an event to a transcript
"""
resp = transcript.add_event(event=event, data=data)
await self.update(session, transcript, {"events": transcript.events_dump()})
await self.update(
session, transcript, {"events": transcript.events_dump()}, commit=commit
)
return resp
async def upsert_topic(
@@ -702,19 +699,25 @@ class TranscriptController:
Will add an event STATUS + update the status field of transcript
"""
async with self.transaction(session):
transcript = await self.get_by_id(session, transcript_id)
if not transcript:
raise Exception(f"Transcript {transcript_id} not found")
if transcript.status == status:
return
resp = await self.append_event(
session,
transcript=transcript,
event="STATUS",
data=StrValue(value=status),
)
await self.update(session, transcript, {"status": status})
transcript = await self.get_by_id(session, transcript_id)
if not transcript:
raise Exception(f"Transcript {transcript_id} not found")
if transcript.status == status:
return
resp = await self.append_event(
session,
transcript=transcript,
event="STATUS",
data=StrValue(value=status),
commit=False,
)
await self.update(
session,
transcript,
{"status": status},
commit=False,
)
await session.commit()
return resp

View File

@@ -431,6 +431,7 @@ async def task_pipeline_file_process(session: AsyncSession, *, transcript_id: st
await pipeline.process(session, audio_file)
except Exception:
logger.error("Error while processing the file", exc_info=True)
try:
await pipeline.set_status(session, transcript_id, "error")
except:

View File

@@ -12,7 +12,6 @@ It is directly linked to our data model.
"""
import asyncio
import functools
from contextlib import asynccontextmanager
from typing import Generic
@@ -90,31 +89,6 @@ def broadcast_to_sockets(func):
return wrapper
def get_transcript(func):
"""
Decorator to fetch the transcript from the database from the first argument
"""
@functools.wraps(func)
async def wrapper(**kwargs):
transcript_id = kwargs.pop("transcript_id")
async with get_session_context() as session:
transcript = await transcripts_controller.get_by_id(session, transcript_id)
if not transcript:
raise Exception(f"Transcript {transcript_id} not found")
tlogger = logger.bind(transcript_id=transcript.id)
try:
result = await func(transcript=transcript, logger=tlogger, **kwargs)
return result
except Exception as exc:
tlogger.error("Pipeline error", function_name=func.__name__, exc_info=exc)
raise
return wrapper
class StrValue(BaseModel):
value: str
@@ -165,15 +139,10 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
yield
@asynccontextmanager
async def transaction(self):
async def locked_session(self):
async with self.lock_transaction():
async with get_session_context() as session:
print(">>> SESSION USING", session, session.in_transaction())
if session.in_transaction():
yield session
else:
async with session.begin():
yield session
yield session
@broadcast_to_sockets
async def on_status(self, status):
@@ -211,7 +180,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
@broadcast_to_sockets
async def on_transcript(self, data):
async with self.transaction() as session:
async with self.locked_session() as session:
transcript = await self.get_transcript(session)
return await transcripts_controller.append_event(
session,
@@ -231,7 +200,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
)
if isinstance(data, TitleSummaryWithIdProcessorType):
topic.id = data.id
async with self.transaction() as session:
async with self.locked_session() as session:
transcript = await self.get_transcript(session)
await transcripts_controller.upsert_topic(session, transcript, topic)
return await transcripts_controller.append_event(
@@ -244,7 +213,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
@broadcast_to_sockets
async def on_title(self, data):
final_title = TranscriptFinalTitle(title=data.title)
async with self.transaction() as session:
async with self.locked_session() as session:
transcript = await self.get_transcript(session)
if not transcript.title:
await transcripts_controller.update(
@@ -264,7 +233,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
@broadcast_to_sockets
async def on_long_summary(self, data):
final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
async with self.transaction() as session:
async with self.locked_session() as session:
transcript = await self.get_transcript(session)
await transcripts_controller.update(
session,
@@ -285,7 +254,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
final_short_summary = TranscriptFinalShortSummary(
short_summary=data.short_summary
)
async with self.transaction() as session:
async with self.locked_session() as session:
transcript = await self.get_transcript(session)
await transcripts_controller.update(
session,
@@ -303,7 +272,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
@broadcast_to_sockets
async def on_duration(self, data):
async with self.transaction() as session:
async with self.locked_session() as session:
duration = TranscriptDuration(duration=data)
transcript = await self.get_transcript(session)
@@ -320,7 +289,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
@broadcast_to_sockets
async def on_waveform(self, data):
async with self.transaction() as session:
async with self.locked_session() as session:
waveform = TranscriptWaveform(waveform=data)
transcript = await self.get_transcript(session)
@@ -483,8 +452,7 @@ class PipelineMainWaveform(PipelineMainFromTopics):
]
@get_transcript
async def pipeline_remove_upload(transcript: Transcript, logger: Logger):
async def pipeline_remove_upload(session, transcript: Transcript, logger: Logger):
# for future changes: note that there's also a consent process happens, beforehand and users may not consent with keeping files. currently, we delete regardless, so it's no need for that
logger.info("Starting remove upload")
uploads = transcript.data_path.glob("upload.*")
@@ -493,16 +461,14 @@ async def pipeline_remove_upload(transcript: Transcript, logger: Logger):
logger.info("Remove upload done")
@get_transcript
async def pipeline_waveform(transcript: Transcript, logger: Logger):
async def pipeline_waveform(session, transcript: Transcript, logger: Logger):
logger.info("Starting waveform")
runner = PipelineMainWaveform(transcript_id=transcript.id)
await runner.run()
logger.info("Waveform done")
@get_transcript
async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
async def pipeline_convert_to_mp3(session, transcript: Transcript, logger: Logger):
logger.info("Starting convert to mp3")
# If the audio wav is not available, just skip
@@ -551,24 +517,21 @@ async def pipeline_upload_mp3(session, transcript: Transcript, logger: Logger):
logger.info("Upload mp3 done")
@get_transcript
async def pipeline_diarization(transcript: Transcript, logger: Logger):
async def pipeline_diarization(session, transcript: Transcript, logger: Logger):
logger.info("Starting diarization")
runner = PipelineMainDiarization(transcript_id=transcript.id)
await runner.run()
logger.info("Diarization done")
@get_transcript
async def pipeline_title(transcript: Transcript, logger: Logger):
async def pipeline_title(session, transcript: Transcript, logger: Logger):
logger.info("Starting title")
runner = PipelineMainTitle(transcript_id=transcript.id)
await runner.run()
logger.info("Title done")
@get_transcript
async def pipeline_summaries(transcript: Transcript, logger: Logger):
async def pipeline_summaries(session, transcript: Transcript, logger: Logger):
logger.info("Starting summaries")
runner = PipelineMainFinalSummaries(transcript_id=transcript.id)
await runner.run()
@@ -703,18 +666,27 @@ async def pipeline_post_to_zulip(session, transcript: Transcript, logger: Logger
@taskiq_broker.task
async def task_pipeline_remove_upload(*, transcript_id: str):
await pipeline_remove_upload(transcript_id=transcript_id)
@with_session_and_transcript
async def task_pipeline_remove_upload(
session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
await pipeline_remove_upload(session, transcript=transcript, logger=logger)
@taskiq_broker.task
async def task_pipeline_waveform(*, transcript_id: str):
await pipeline_waveform(transcript_id=transcript_id)
@with_session_and_transcript
async def task_pipeline_waveform(
session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
await pipeline_waveform(session, transcript=transcript, logger=logger)
@taskiq_broker.task
async def task_pipeline_convert_to_mp3(*, transcript_id: str):
await pipeline_convert_to_mp3(transcript_id=transcript_id)
@with_session_and_transcript
async def task_pipeline_convert_to_mp3(
session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
await pipeline_convert_to_mp3(session, transcript=transcript, logger=logger)
@taskiq_broker.task
@@ -726,32 +698,39 @@ async def task_pipeline_upload_mp3(
@taskiq_broker.task
async def task_pipeline_diarization(*, transcript_id: str):
await pipeline_diarization(transcript_id=transcript_id)
@taskiq_broker.task
async def task_pipeline_title(*, transcript_id: str):
await pipeline_title(transcript_id=transcript_id)
@taskiq_broker.task
async def task_pipeline_final_summaries(*, transcript_id: str):
await pipeline_summaries(transcript_id=transcript_id)
@with_session_and_transcript
async def task_pipeline_diarization(
session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
await pipeline_diarization(session, transcript=transcript, logger=logger)
@taskiq_broker.task
@with_session_and_transcript
async def task_cleanup_consent(
async def task_pipeline_title(
session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
await pipeline_title(session, transcript=transcript, logger=logger)
@taskiq_broker.task
@with_session_and_transcript
async def task_pipeline_final_summaries(
session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
await pipeline_summaries(session, transcript=transcript, logger=logger)
@taskiq_broker.task
@with_session_and_transcript
async def task_cleanup_consent(session, *, transcript: Transcript, logger: Logger):
await cleanup_consent(session, transcript=transcript, logger=logger)
@taskiq_broker.task
@with_session_and_transcript
async def task_pipeline_post_to_zulip(
session, *, transcript: Transcript, logger: Logger, transcript_id: str
session, *, transcript: Transcript, logger: Logger
):
await pipeline_post_to_zulip(session, transcript=transcript, logger=logger)
@@ -759,7 +738,7 @@ async def task_pipeline_post_to_zulip(
@taskiq_broker.task
@with_session_and_transcript
async def task_cleanup_consent_taskiq(
session, *, transcript: Transcript, logger: Logger, transcript_id: str
session, *, transcript: Transcript, logger: Logger
):
await cleanup_consent(session, transcript=transcript, logger=logger)
@@ -767,7 +746,7 @@ async def task_cleanup_consent_taskiq(
@taskiq_broker.task
@with_session_and_transcript
async def task_pipeline_post_to_zulip_taskiq(
session, *, transcript: Transcript, logger: Logger, transcript_id: str
session, *, transcript: Transcript, logger: Logger
):
await pipeline_post_to_zulip(session, transcript=transcript, logger=logger)
@@ -793,20 +772,18 @@ async def task_pipeline_post_sequential(*, transcript_id: str):
await task_pipeline_post_to_zulip.kiq(transcript_id=transcript_id)
@get_transcript
async def pipeline_process(transcript: Transcript, logger: Logger):
async def pipeline_process(session, transcript: Transcript, logger: Logger):
try:
if transcript.audio_location == "storage":
async with get_session_context() as session:
await transcripts_controller.download_mp3_from_storage(transcript)
transcript.audio_waveform_filename.unlink(missing_ok=True)
await transcripts_controller.update(
session,
transcript,
{
"topics": [],
},
)
await transcripts_controller.download_mp3_from_storage(transcript)
transcript.audio_waveform_filename.unlink(missing_ok=True)
await transcripts_controller.update(
session,
transcript,
{
"topics": [],
},
)
# open audio
audio_filename = next(transcript.data_path.glob("upload.*"), None)
@@ -838,19 +815,21 @@ async def pipeline_process(transcript: Transcript, logger: Logger):
except Exception as exc:
logger.error("Pipeline error", exc_info=exc)
async with get_session_context() as session:
await transcripts_controller.update(
session,
transcript,
{
"status": "error",
},
)
await transcripts_controller.update(
session,
transcript,
{
"status": "error",
},
)
raise
logger.info("Pipeline ended")
@taskiq_broker.task
async def task_pipeline_process(*, transcript_id: str):
return await pipeline_process(transcript_id=transcript_id)
@with_session_and_transcript
async def task_pipeline_process(
session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
return await pipeline_process(session, transcript=transcript, logger=logger)

View File

@@ -71,16 +71,13 @@ def with_session_and_transcript(func: F) -> F:
)
async with get_session_context() as session:
# Fetch the transcript
transcript = await transcripts_controller.get_by_id(session, transcript_id)
if not transcript:
raise Exception(f"Transcript {transcript_id} not found")
# Create enhanced logger
tlogger = logger.bind(transcript_id=transcript.id)
try:
# Pass session, transcript, and logger to the decorated function
return await func(
session, transcript=transcript, logger=tlogger, *args, **kwargs
)

View File

@@ -377,6 +377,29 @@ async def dummy_storage():
# yield
# @pytest.fixture()
# async def db_session(sqla_engine):
# """
# Fixture that returns a SQLAlchemy session with a SAVEPOINT, and the rollback to it
# after the test completes.
# """
# from sqlalchemy.ext.asyncio import AsyncSession
# from sqlalchemy.orm import sessionmaker
# connection = await sqla_engine.connect()
# trans = await connection.begin()
# Session = sessionmaker(connection, expire_on_commit=False, class_=AsyncSession)
# session = Session()
# try:
# yield session
# finally:
# await session.close()
# await trans.rollback()
# await connection.close()
@pytest.fixture(autouse=True)
async def ensure_db_session_in_app(db_session):
async def mock_get_session():