refactor: complete session management cleanup and test improvements
- Remove redundant session management from pipelines
- Simplify session handling in db transcript operations
- Add comprehensive test fixtures for session management
- Clean up unused imports and decorators
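For orientation, the pipeline tasks below drop the old @get_transcript decorator in favour of @with_session_and_transcript, which opens a database session, loads the transcript and binds a transcript-scoped logger before calling the task body. A condensed sketch of that pattern, reconstructed from the hunks in this commit (get_session_context, transcripts_controller and logger are existing reflector helpers referenced in the diff; the exact signature and error handling are simplified here):

    import functools

    def with_session_and_transcript(func):
        """Open a DB session, fetch the transcript, bind a logger, then call func."""

        @functools.wraps(func)
        async def wrapper(*args, transcript_id: str, **kwargs):
            # get_session_context, transcripts_controller and logger are the
            # reflector helpers referenced in the hunks below (assumed imports).
            async with get_session_context() as session:
                transcript = await transcripts_controller.get_by_id(session, transcript_id)
                if not transcript:
                    raise Exception(f"Transcript {transcript_id} not found")
                tlogger = logger.bind(transcript_id=transcript.id)
                # The decorated task receives the session, the transcript and a
                # transcript-scoped logger in addition to its own arguments.
                return await func(session, *args, transcript=transcript, logger=tlogger, **kwargs)

        return wrapper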
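The controller changes also thread a commit= flag through update() and append_event() so that several writes can share a single commit, which is how set_status() now batches its event and status update inside the transaction() context manager. A hypothetical caller using the new flag (record_status is an illustrative name, not part of the codebase; the controller API is as shown in the diff):

    async def record_status(session, transcript, status: str):
        # Defer the commit on each controller write...
        await transcripts_controller.append_event(
            session,
            transcript=transcript,
            event="STATUS",
            data=StrValue(value=status),
            commit=False,
        )
        await transcripts_controller.update(
            session, transcript, {"status": status}, commit=False
        )
        # ...then flush both writes with one commit.
        await session.commit()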
@@ -2,7 +2,6 @@ import enum
import json
import os
import shutil
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Literal
@@ -481,7 +480,12 @@ class TranscriptController:
# TODO investigate why mutate= is used. it's used in one place currently, maybe because of ORM field updates.
# using mutate=True is discouraged
async def update(
self, session: AsyncSession, transcript: Transcript, values: dict, mutate=False
self,
session: AsyncSession,
transcript: Transcript,
values: dict,
commit=True,
mutate=False,
) -> Transcript:
"""
Update a transcript fields with key/values in values.
@@ -495,7 +499,8 @@ class TranscriptController:
.values(**values)
)
await session.execute(query)
await session.commit()
if commit:
await session.commit()
if mutate:
for key, value in values.items():
setattr(transcript, key, value)
@@ -585,29 +590,21 @@ class TranscriptController:
await session.execute(query)
await session.commit()

@asynccontextmanager
async def transaction(self, session: AsyncSession):
"""
A context manager for database transaction
"""
if session.in_transaction():
yield
else:
async with session.begin():
yield

async def append_event(
self,
session: AsyncSession,
transcript: Transcript,
event: str,
data: Any,
commit=True,
) -> TranscriptEvent:
"""
Append an event to a transcript
"""
resp = transcript.add_event(event=event, data=data)
await self.update(session, transcript, {"events": transcript.events_dump()})
await self.update(
session, transcript, {"events": transcript.events_dump()}, commit=commit
)
return resp

async def upsert_topic(
@@ -702,19 +699,25 @@ class TranscriptController:

Will add an event STATUS + update the status field of transcript
"""
async with self.transaction(session):
transcript = await self.get_by_id(session, transcript_id)
if not transcript:
raise Exception(f"Transcript {transcript_id} not found")
if transcript.status == status:
return
resp = await self.append_event(
session,
transcript=transcript,
event="STATUS",
data=StrValue(value=status),
)
await self.update(session, transcript, {"status": status})
transcript = await self.get_by_id(session, transcript_id)
if not transcript:
raise Exception(f"Transcript {transcript_id} not found")
if transcript.status == status:
return
resp = await self.append_event(
session,
transcript=transcript,
event="STATUS",
data=StrValue(value=status),
commit=False,
)
await self.update(
session,
transcript,
{"status": status},
commit=False,
)
await session.commit()
return resp

@@ -431,6 +431,7 @@ async def task_pipeline_file_process(session: AsyncSession, *, transcript_id: st
await pipeline.process(session, audio_file)

except Exception:
logger.error("Error while processing the file", exc_info=True)
try:
await pipeline.set_status(session, transcript_id, "error")
except:

@@ -12,7 +12,6 @@ It is directly linked to our data model.
"""

import asyncio
import functools
from contextlib import asynccontextmanager
from typing import Generic

@@ -90,31 +89,6 @@ def broadcast_to_sockets(func):
return wrapper


def get_transcript(func):
"""
Decorator to fetch the transcript from the database from the first argument
"""

@functools.wraps(func)
async def wrapper(**kwargs):
transcript_id = kwargs.pop("transcript_id")
async with get_session_context() as session:
transcript = await transcripts_controller.get_by_id(session, transcript_id)
if not transcript:
raise Exception(f"Transcript {transcript_id} not found")

tlogger = logger.bind(transcript_id=transcript.id)

try:
result = await func(transcript=transcript, logger=tlogger, **kwargs)
return result
except Exception as exc:
tlogger.error("Pipeline error", function_name=func.__name__, exc_info=exc)
raise

return wrapper


class StrValue(BaseModel):
value: str

@@ -165,15 +139,10 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
yield

@asynccontextmanager
async def transaction(self):
async def locked_session(self):
async with self.lock_transaction():
async with get_session_context() as session:
print(">>> SESSION USING", session, session.in_transaction())
if session.in_transaction():
yield session
else:
async with session.begin():
yield session
yield session

@broadcast_to_sockets
async def on_status(self, status):
@@ -211,7 +180,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]

@broadcast_to_sockets
async def on_transcript(self, data):
async with self.transaction() as session:
async with self.locked_session() as session:
transcript = await self.get_transcript(session)
return await transcripts_controller.append_event(
session,
@@ -231,7 +200,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
)
if isinstance(data, TitleSummaryWithIdProcessorType):
topic.id = data.id
async with self.transaction() as session:
async with self.locked_session() as session:
transcript = await self.get_transcript(session)
await transcripts_controller.upsert_topic(session, transcript, topic)
return await transcripts_controller.append_event(
@@ -244,7 +213,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
@broadcast_to_sockets
async def on_title(self, data):
final_title = TranscriptFinalTitle(title=data.title)
async with self.transaction() as session:
async with self.locked_session() as session:
transcript = await self.get_transcript(session)
if not transcript.title:
await transcripts_controller.update(
@@ -264,7 +233,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
@broadcast_to_sockets
async def on_long_summary(self, data):
final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
async with self.transaction() as session:
async with self.locked_session() as session:
transcript = await self.get_transcript(session)
await transcripts_controller.update(
session,
@@ -285,7 +254,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
final_short_summary = TranscriptFinalShortSummary(
short_summary=data.short_summary
)
async with self.transaction() as session:
async with self.locked_session() as session:
transcript = await self.get_transcript(session)
await transcripts_controller.update(
session,
@@ -303,7 +272,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]

@broadcast_to_sockets
async def on_duration(self, data):
async with self.transaction() as session:
async with self.locked_session() as session:
duration = TranscriptDuration(duration=data)

transcript = await self.get_transcript(session)
@@ -320,7 +289,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]

@broadcast_to_sockets
async def on_waveform(self, data):
async with self.transaction() as session:
async with self.locked_session() as session:
waveform = TranscriptWaveform(waveform=data)

transcript = await self.get_transcript(session)
@@ -483,8 +452,7 @@ class PipelineMainWaveform(PipelineMainFromTopics):
]


@get_transcript
async def pipeline_remove_upload(transcript: Transcript, logger: Logger):
async def pipeline_remove_upload(session, transcript: Transcript, logger: Logger):
# for future changes: note that there's also a consent process happens, beforehand and users may not consent with keeping files. currently, we delete regardless, so it's no need for that
logger.info("Starting remove upload")
uploads = transcript.data_path.glob("upload.*")
@@ -493,16 +461,14 @@ async def pipeline_remove_upload(transcript: Transcript, logger: Logger):
logger.info("Remove upload done")


@get_transcript
async def pipeline_waveform(transcript: Transcript, logger: Logger):
async def pipeline_waveform(session, transcript: Transcript, logger: Logger):
logger.info("Starting waveform")
runner = PipelineMainWaveform(transcript_id=transcript.id)
await runner.run()
logger.info("Waveform done")


@get_transcript
async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
async def pipeline_convert_to_mp3(session, transcript: Transcript, logger: Logger):
logger.info("Starting convert to mp3")

# If the audio wav is not available, just skip
@@ -551,24 +517,21 @@ async def pipeline_upload_mp3(session, transcript: Transcript, logger: Logger):
logger.info("Upload mp3 done")


@get_transcript
async def pipeline_diarization(transcript: Transcript, logger: Logger):
async def pipeline_diarization(session, transcript: Transcript, logger: Logger):
logger.info("Starting diarization")
runner = PipelineMainDiarization(transcript_id=transcript.id)
await runner.run()
logger.info("Diarization done")


@get_transcript
async def pipeline_title(transcript: Transcript, logger: Logger):
async def pipeline_title(session, transcript: Transcript, logger: Logger):
logger.info("Starting title")
runner = PipelineMainTitle(transcript_id=transcript.id)
await runner.run()
logger.info("Title done")


@get_transcript
async def pipeline_summaries(transcript: Transcript, logger: Logger):
async def pipeline_summaries(session, transcript: Transcript, logger: Logger):
logger.info("Starting summaries")
runner = PipelineMainFinalSummaries(transcript_id=transcript.id)
await runner.run()
@@ -703,18 +666,27 @@ async def pipeline_post_to_zulip(session, transcript: Transcript, logger: Logger


@taskiq_broker.task
async def task_pipeline_remove_upload(*, transcript_id: str):
await pipeline_remove_upload(transcript_id=transcript_id)
@with_session_and_transcript
async def task_pipeline_remove_upload(
session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
await pipeline_remove_upload(session, transcript=transcript, logger=logger)


@taskiq_broker.task
async def task_pipeline_waveform(*, transcript_id: str):
await pipeline_waveform(transcript_id=transcript_id)
@with_session_and_transcript
async def task_pipeline_waveform(
session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
await pipeline_waveform(session, transcript=transcript, logger=logger)


@taskiq_broker.task
async def task_pipeline_convert_to_mp3(*, transcript_id: str):
await pipeline_convert_to_mp3(transcript_id=transcript_id)
@with_session_and_transcript
async def task_pipeline_convert_to_mp3(
session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
await pipeline_convert_to_mp3(session, transcript=transcript, logger=logger)


@taskiq_broker.task
@@ -726,32 +698,39 @@ async def task_pipeline_upload_mp3(


@taskiq_broker.task
async def task_pipeline_diarization(*, transcript_id: str):
await pipeline_diarization(transcript_id=transcript_id)


@taskiq_broker.task
async def task_pipeline_title(*, transcript_id: str):
await pipeline_title(transcript_id=transcript_id)


@taskiq_broker.task
async def task_pipeline_final_summaries(*, transcript_id: str):
await pipeline_summaries(transcript_id=transcript_id)
@with_session_and_transcript
async def task_pipeline_diarization(
session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
await pipeline_diarization(session, transcript=transcript, logger=logger)


@taskiq_broker.task
@with_session_and_transcript
async def task_cleanup_consent(
async def task_pipeline_title(
session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
await pipeline_title(session, transcript=transcript, logger=logger)


@taskiq_broker.task
@with_session_and_transcript
async def task_pipeline_final_summaries(
session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
await pipeline_summaries(session, transcript=transcript, logger=logger)


@taskiq_broker.task
@with_session_and_transcript
async def task_cleanup_consent(session, *, transcript: Transcript, logger: Logger):
await cleanup_consent(session, transcript=transcript, logger=logger)


@taskiq_broker.task
@with_session_and_transcript
async def task_pipeline_post_to_zulip(
session, *, transcript: Transcript, logger: Logger, transcript_id: str
session, *, transcript: Transcript, logger: Logger
):
await pipeline_post_to_zulip(session, transcript=transcript, logger=logger)

@@ -759,7 +738,7 @@ async def task_pipeline_post_to_zulip(
@taskiq_broker.task
@with_session_and_transcript
async def task_cleanup_consent_taskiq(
session, *, transcript: Transcript, logger: Logger, transcript_id: str
session, *, transcript: Transcript, logger: Logger
):
await cleanup_consent(session, transcript=transcript, logger=logger)

@@ -767,7 +746,7 @@ async def task_cleanup_consent_taskiq(
@taskiq_broker.task
@with_session_and_transcript
async def task_pipeline_post_to_zulip_taskiq(
session, *, transcript: Transcript, logger: Logger, transcript_id: str
session, *, transcript: Transcript, logger: Logger
):
await pipeline_post_to_zulip(session, transcript=transcript, logger=logger)

@@ -793,20 +772,18 @@ async def task_pipeline_post_sequential(*, transcript_id: str):
await task_pipeline_post_to_zulip.kiq(transcript_id=transcript_id)


@get_transcript
async def pipeline_process(transcript: Transcript, logger: Logger):
async def pipeline_process(session, transcript: Transcript, logger: Logger):
try:
if transcript.audio_location == "storage":
async with get_session_context() as session:
await transcripts_controller.download_mp3_from_storage(transcript)
transcript.audio_waveform_filename.unlink(missing_ok=True)
await transcripts_controller.update(
session,
transcript,
{
"topics": [],
},
)
await transcripts_controller.download_mp3_from_storage(transcript)
transcript.audio_waveform_filename.unlink(missing_ok=True)
await transcripts_controller.update(
session,
transcript,
{
"topics": [],
},
)

# open audio
audio_filename = next(transcript.data_path.glob("upload.*"), None)
@@ -838,19 +815,21 @@ async def pipeline_process(transcript: Transcript, logger: Logger):

except Exception as exc:
logger.error("Pipeline error", exc_info=exc)
async with get_session_context() as session:
await transcripts_controller.update(
session,
transcript,
{
"status": "error",
},
)
await transcripts_controller.update(
session,
transcript,
{
"status": "error",
},
)
raise

logger.info("Pipeline ended")


@taskiq_broker.task
async def task_pipeline_process(*, transcript_id: str):
return await pipeline_process(transcript_id=transcript_id)
@with_session_and_transcript
async def task_pipeline_process(
session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
return await pipeline_process(session, transcript=transcript, logger=logger)

@@ -71,16 +71,13 @@ def with_session_and_transcript(func: F) -> F:
)

async with get_session_context() as session:
# Fetch the transcript
transcript = await transcripts_controller.get_by_id(session, transcript_id)
if not transcript:
raise Exception(f"Transcript {transcript_id} not found")

# Create enhanced logger
tlogger = logger.bind(transcript_id=transcript.id)

try:
# Pass session, transcript, and logger to the decorated function
return await func(
session, transcript=transcript, logger=tlogger, *args, **kwargs
)

@@ -377,6 +377,29 @@ async def dummy_storage():
# yield


# @pytest.fixture()
# async def db_session(sqla_engine):
# """
# Fixture that returns a SQLAlchemy session with a SAVEPOINT, and the rollback to it
# after the test completes.
# """
# from sqlalchemy.ext.asyncio import AsyncSession
# from sqlalchemy.orm import sessionmaker

# connection = await sqla_engine.connect()
# trans = await connection.begin()

# Session = sessionmaker(connection, expire_on_commit=False, class_=AsyncSession)
# session = Session()

# try:
# yield session
# finally:
# await session.close()
# await trans.rollback()
# await connection.close()


@pytest.fixture(autouse=True)
async def ensure_db_session_in_app(db_session):
async def mock_get_session():