mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-21 04:39:06 +00:00
refactor: complete session management cleanup and test improvements
- Remove redundant session management from pipelines - Simplify session handling in db transcript operations - Add comprehensive test fixtures for session management - Clean up unused imports and decorators
This commit is contained in:
@@ -2,7 +2,6 @@ import enum
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
@@ -481,7 +480,12 @@ class TranscriptController:
|
|||||||
# TODO investigate why mutate= is used. it's used in one place currently, maybe because of ORM field updates.
|
# TODO investigate why mutate= is used. it's used in one place currently, maybe because of ORM field updates.
|
||||||
# using mutate=True is discouraged
|
# using mutate=True is discouraged
|
||||||
async def update(
|
async def update(
|
||||||
self, session: AsyncSession, transcript: Transcript, values: dict, mutate=False
|
self,
|
||||||
|
session: AsyncSession,
|
||||||
|
transcript: Transcript,
|
||||||
|
values: dict,
|
||||||
|
commit=True,
|
||||||
|
mutate=False,
|
||||||
) -> Transcript:
|
) -> Transcript:
|
||||||
"""
|
"""
|
||||||
Update a transcript fields with key/values in values.
|
Update a transcript fields with key/values in values.
|
||||||
@@ -495,6 +499,7 @@ class TranscriptController:
|
|||||||
.values(**values)
|
.values(**values)
|
||||||
)
|
)
|
||||||
await session.execute(query)
|
await session.execute(query)
|
||||||
|
if commit:
|
||||||
await session.commit()
|
await session.commit()
|
||||||
if mutate:
|
if mutate:
|
||||||
for key, value in values.items():
|
for key, value in values.items():
|
||||||
@@ -585,29 +590,21 @@ class TranscriptController:
|
|||||||
await session.execute(query)
|
await session.execute(query)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def transaction(self, session: AsyncSession):
|
|
||||||
"""
|
|
||||||
A context manager for database transaction
|
|
||||||
"""
|
|
||||||
if session.in_transaction():
|
|
||||||
yield
|
|
||||||
else:
|
|
||||||
async with session.begin():
|
|
||||||
yield
|
|
||||||
|
|
||||||
async def append_event(
|
async def append_event(
|
||||||
self,
|
self,
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
transcript: Transcript,
|
transcript: Transcript,
|
||||||
event: str,
|
event: str,
|
||||||
data: Any,
|
data: Any,
|
||||||
|
commit=True,
|
||||||
) -> TranscriptEvent:
|
) -> TranscriptEvent:
|
||||||
"""
|
"""
|
||||||
Append an event to a transcript
|
Append an event to a transcript
|
||||||
"""
|
"""
|
||||||
resp = transcript.add_event(event=event, data=data)
|
resp = transcript.add_event(event=event, data=data)
|
||||||
await self.update(session, transcript, {"events": transcript.events_dump()})
|
await self.update(
|
||||||
|
session, transcript, {"events": transcript.events_dump()}, commit=commit
|
||||||
|
)
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
async def upsert_topic(
|
async def upsert_topic(
|
||||||
@@ -702,7 +699,6 @@ class TranscriptController:
|
|||||||
|
|
||||||
Will add an event STATUS + update the status field of transcript
|
Will add an event STATUS + update the status field of transcript
|
||||||
"""
|
"""
|
||||||
async with self.transaction(session):
|
|
||||||
transcript = await self.get_by_id(session, transcript_id)
|
transcript = await self.get_by_id(session, transcript_id)
|
||||||
if not transcript:
|
if not transcript:
|
||||||
raise Exception(f"Transcript {transcript_id} not found")
|
raise Exception(f"Transcript {transcript_id} not found")
|
||||||
@@ -713,8 +709,15 @@ class TranscriptController:
|
|||||||
transcript=transcript,
|
transcript=transcript,
|
||||||
event="STATUS",
|
event="STATUS",
|
||||||
data=StrValue(value=status),
|
data=StrValue(value=status),
|
||||||
|
commit=False,
|
||||||
)
|
)
|
||||||
await self.update(session, transcript, {"status": status})
|
await self.update(
|
||||||
|
session,
|
||||||
|
transcript,
|
||||||
|
{"status": status},
|
||||||
|
commit=False,
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -431,6 +431,7 @@ async def task_pipeline_file_process(session: AsyncSession, *, transcript_id: st
|
|||||||
await pipeline.process(session, audio_file)
|
await pipeline.process(session, audio_file)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
|
logger.error("Error while processing the file", exc_info=True)
|
||||||
try:
|
try:
|
||||||
await pipeline.set_status(session, transcript_id, "error")
|
await pipeline.set_status(session, transcript_id, "error")
|
||||||
except:
|
except:
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ It is directly linked to our data model.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import functools
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import Generic
|
from typing import Generic
|
||||||
|
|
||||||
@@ -90,31 +89,6 @@ def broadcast_to_sockets(func):
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def get_transcript(func):
|
|
||||||
"""
|
|
||||||
Decorator to fetch the transcript from the database from the first argument
|
|
||||||
"""
|
|
||||||
|
|
||||||
@functools.wraps(func)
|
|
||||||
async def wrapper(**kwargs):
|
|
||||||
transcript_id = kwargs.pop("transcript_id")
|
|
||||||
async with get_session_context() as session:
|
|
||||||
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
|
||||||
if not transcript:
|
|
||||||
raise Exception(f"Transcript {transcript_id} not found")
|
|
||||||
|
|
||||||
tlogger = logger.bind(transcript_id=transcript.id)
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = await func(transcript=transcript, logger=tlogger, **kwargs)
|
|
||||||
return result
|
|
||||||
except Exception as exc:
|
|
||||||
tlogger.error("Pipeline error", function_name=func.__name__, exc_info=exc)
|
|
||||||
raise
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
class StrValue(BaseModel):
|
class StrValue(BaseModel):
|
||||||
value: str
|
value: str
|
||||||
|
|
||||||
@@ -165,14 +139,9 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
|||||||
yield
|
yield
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def transaction(self):
|
async def locked_session(self):
|
||||||
async with self.lock_transaction():
|
async with self.lock_transaction():
|
||||||
async with get_session_context() as session:
|
async with get_session_context() as session:
|
||||||
print(">>> SESSION USING", session, session.in_transaction())
|
|
||||||
if session.in_transaction():
|
|
||||||
yield session
|
|
||||||
else:
|
|
||||||
async with session.begin():
|
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
@broadcast_to_sockets
|
@broadcast_to_sockets
|
||||||
@@ -211,7 +180,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
|||||||
|
|
||||||
@broadcast_to_sockets
|
@broadcast_to_sockets
|
||||||
async def on_transcript(self, data):
|
async def on_transcript(self, data):
|
||||||
async with self.transaction() as session:
|
async with self.locked_session() as session:
|
||||||
transcript = await self.get_transcript(session)
|
transcript = await self.get_transcript(session)
|
||||||
return await transcripts_controller.append_event(
|
return await transcripts_controller.append_event(
|
||||||
session,
|
session,
|
||||||
@@ -231,7 +200,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
|||||||
)
|
)
|
||||||
if isinstance(data, TitleSummaryWithIdProcessorType):
|
if isinstance(data, TitleSummaryWithIdProcessorType):
|
||||||
topic.id = data.id
|
topic.id = data.id
|
||||||
async with self.transaction() as session:
|
async with self.locked_session() as session:
|
||||||
transcript = await self.get_transcript(session)
|
transcript = await self.get_transcript(session)
|
||||||
await transcripts_controller.upsert_topic(session, transcript, topic)
|
await transcripts_controller.upsert_topic(session, transcript, topic)
|
||||||
return await transcripts_controller.append_event(
|
return await transcripts_controller.append_event(
|
||||||
@@ -244,7 +213,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
|||||||
@broadcast_to_sockets
|
@broadcast_to_sockets
|
||||||
async def on_title(self, data):
|
async def on_title(self, data):
|
||||||
final_title = TranscriptFinalTitle(title=data.title)
|
final_title = TranscriptFinalTitle(title=data.title)
|
||||||
async with self.transaction() as session:
|
async with self.locked_session() as session:
|
||||||
transcript = await self.get_transcript(session)
|
transcript = await self.get_transcript(session)
|
||||||
if not transcript.title:
|
if not transcript.title:
|
||||||
await transcripts_controller.update(
|
await transcripts_controller.update(
|
||||||
@@ -264,7 +233,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
|||||||
@broadcast_to_sockets
|
@broadcast_to_sockets
|
||||||
async def on_long_summary(self, data):
|
async def on_long_summary(self, data):
|
||||||
final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
|
final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
|
||||||
async with self.transaction() as session:
|
async with self.locked_session() as session:
|
||||||
transcript = await self.get_transcript(session)
|
transcript = await self.get_transcript(session)
|
||||||
await transcripts_controller.update(
|
await transcripts_controller.update(
|
||||||
session,
|
session,
|
||||||
@@ -285,7 +254,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
|||||||
final_short_summary = TranscriptFinalShortSummary(
|
final_short_summary = TranscriptFinalShortSummary(
|
||||||
short_summary=data.short_summary
|
short_summary=data.short_summary
|
||||||
)
|
)
|
||||||
async with self.transaction() as session:
|
async with self.locked_session() as session:
|
||||||
transcript = await self.get_transcript(session)
|
transcript = await self.get_transcript(session)
|
||||||
await transcripts_controller.update(
|
await transcripts_controller.update(
|
||||||
session,
|
session,
|
||||||
@@ -303,7 +272,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
|||||||
|
|
||||||
@broadcast_to_sockets
|
@broadcast_to_sockets
|
||||||
async def on_duration(self, data):
|
async def on_duration(self, data):
|
||||||
async with self.transaction() as session:
|
async with self.locked_session() as session:
|
||||||
duration = TranscriptDuration(duration=data)
|
duration = TranscriptDuration(duration=data)
|
||||||
|
|
||||||
transcript = await self.get_transcript(session)
|
transcript = await self.get_transcript(session)
|
||||||
@@ -320,7 +289,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
|||||||
|
|
||||||
@broadcast_to_sockets
|
@broadcast_to_sockets
|
||||||
async def on_waveform(self, data):
|
async def on_waveform(self, data):
|
||||||
async with self.transaction() as session:
|
async with self.locked_session() as session:
|
||||||
waveform = TranscriptWaveform(waveform=data)
|
waveform = TranscriptWaveform(waveform=data)
|
||||||
|
|
||||||
transcript = await self.get_transcript(session)
|
transcript = await self.get_transcript(session)
|
||||||
@@ -483,8 +452,7 @@ class PipelineMainWaveform(PipelineMainFromTopics):
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@get_transcript
|
async def pipeline_remove_upload(session, transcript: Transcript, logger: Logger):
|
||||||
async def pipeline_remove_upload(transcript: Transcript, logger: Logger):
|
|
||||||
# for future changes: note that there's also a consent process happens, beforehand and users may not consent with keeping files. currently, we delete regardless, so it's no need for that
|
# for future changes: note that there's also a consent process happens, beforehand and users may not consent with keeping files. currently, we delete regardless, so it's no need for that
|
||||||
logger.info("Starting remove upload")
|
logger.info("Starting remove upload")
|
||||||
uploads = transcript.data_path.glob("upload.*")
|
uploads = transcript.data_path.glob("upload.*")
|
||||||
@@ -493,16 +461,14 @@ async def pipeline_remove_upload(transcript: Transcript, logger: Logger):
|
|||||||
logger.info("Remove upload done")
|
logger.info("Remove upload done")
|
||||||
|
|
||||||
|
|
||||||
@get_transcript
|
async def pipeline_waveform(session, transcript: Transcript, logger: Logger):
|
||||||
async def pipeline_waveform(transcript: Transcript, logger: Logger):
|
|
||||||
logger.info("Starting waveform")
|
logger.info("Starting waveform")
|
||||||
runner = PipelineMainWaveform(transcript_id=transcript.id)
|
runner = PipelineMainWaveform(transcript_id=transcript.id)
|
||||||
await runner.run()
|
await runner.run()
|
||||||
logger.info("Waveform done")
|
logger.info("Waveform done")
|
||||||
|
|
||||||
|
|
||||||
@get_transcript
|
async def pipeline_convert_to_mp3(session, transcript: Transcript, logger: Logger):
|
||||||
async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
|
|
||||||
logger.info("Starting convert to mp3")
|
logger.info("Starting convert to mp3")
|
||||||
|
|
||||||
# If the audio wav is not available, just skip
|
# If the audio wav is not available, just skip
|
||||||
@@ -551,24 +517,21 @@ async def pipeline_upload_mp3(session, transcript: Transcript, logger: Logger):
|
|||||||
logger.info("Upload mp3 done")
|
logger.info("Upload mp3 done")
|
||||||
|
|
||||||
|
|
||||||
@get_transcript
|
async def pipeline_diarization(session, transcript: Transcript, logger: Logger):
|
||||||
async def pipeline_diarization(transcript: Transcript, logger: Logger):
|
|
||||||
logger.info("Starting diarization")
|
logger.info("Starting diarization")
|
||||||
runner = PipelineMainDiarization(transcript_id=transcript.id)
|
runner = PipelineMainDiarization(transcript_id=transcript.id)
|
||||||
await runner.run()
|
await runner.run()
|
||||||
logger.info("Diarization done")
|
logger.info("Diarization done")
|
||||||
|
|
||||||
|
|
||||||
@get_transcript
|
async def pipeline_title(session, transcript: Transcript, logger: Logger):
|
||||||
async def pipeline_title(transcript: Transcript, logger: Logger):
|
|
||||||
logger.info("Starting title")
|
logger.info("Starting title")
|
||||||
runner = PipelineMainTitle(transcript_id=transcript.id)
|
runner = PipelineMainTitle(transcript_id=transcript.id)
|
||||||
await runner.run()
|
await runner.run()
|
||||||
logger.info("Title done")
|
logger.info("Title done")
|
||||||
|
|
||||||
|
|
||||||
@get_transcript
|
async def pipeline_summaries(session, transcript: Transcript, logger: Logger):
|
||||||
async def pipeline_summaries(transcript: Transcript, logger: Logger):
|
|
||||||
logger.info("Starting summaries")
|
logger.info("Starting summaries")
|
||||||
runner = PipelineMainFinalSummaries(transcript_id=transcript.id)
|
runner = PipelineMainFinalSummaries(transcript_id=transcript.id)
|
||||||
await runner.run()
|
await runner.run()
|
||||||
@@ -703,18 +666,27 @@ async def pipeline_post_to_zulip(session, transcript: Transcript, logger: Logger
|
|||||||
|
|
||||||
|
|
||||||
@taskiq_broker.task
|
@taskiq_broker.task
|
||||||
async def task_pipeline_remove_upload(*, transcript_id: str):
|
@with_session_and_transcript
|
||||||
await pipeline_remove_upload(transcript_id=transcript_id)
|
async def task_pipeline_remove_upload(
|
||||||
|
session, *, transcript: Transcript, logger: Logger, transcript_id: str
|
||||||
|
):
|
||||||
|
await pipeline_remove_upload(session, transcript=transcript, logger=logger)
|
||||||
|
|
||||||
|
|
||||||
@taskiq_broker.task
|
@taskiq_broker.task
|
||||||
async def task_pipeline_waveform(*, transcript_id: str):
|
@with_session_and_transcript
|
||||||
await pipeline_waveform(transcript_id=transcript_id)
|
async def task_pipeline_waveform(
|
||||||
|
session, *, transcript: Transcript, logger: Logger, transcript_id: str
|
||||||
|
):
|
||||||
|
await pipeline_waveform(session, transcript=transcript, logger=logger)
|
||||||
|
|
||||||
|
|
||||||
@taskiq_broker.task
|
@taskiq_broker.task
|
||||||
async def task_pipeline_convert_to_mp3(*, transcript_id: str):
|
@with_session_and_transcript
|
||||||
await pipeline_convert_to_mp3(transcript_id=transcript_id)
|
async def task_pipeline_convert_to_mp3(
|
||||||
|
session, *, transcript: Transcript, logger: Logger, transcript_id: str
|
||||||
|
):
|
||||||
|
await pipeline_convert_to_mp3(session, transcript=transcript, logger=logger)
|
||||||
|
|
||||||
|
|
||||||
@taskiq_broker.task
|
@taskiq_broker.task
|
||||||
@@ -726,32 +698,39 @@ async def task_pipeline_upload_mp3(
|
|||||||
|
|
||||||
|
|
||||||
@taskiq_broker.task
|
@taskiq_broker.task
|
||||||
async def task_pipeline_diarization(*, transcript_id: str):
|
@with_session_and_transcript
|
||||||
await pipeline_diarization(transcript_id=transcript_id)
|
async def task_pipeline_diarization(
|
||||||
|
session, *, transcript: Transcript, logger: Logger, transcript_id: str
|
||||||
|
):
|
||||||
@taskiq_broker.task
|
await pipeline_diarization(session, transcript=transcript, logger=logger)
|
||||||
async def task_pipeline_title(*, transcript_id: str):
|
|
||||||
await pipeline_title(transcript_id=transcript_id)
|
|
||||||
|
|
||||||
|
|
||||||
@taskiq_broker.task
|
|
||||||
async def task_pipeline_final_summaries(*, transcript_id: str):
|
|
||||||
await pipeline_summaries(transcript_id=transcript_id)
|
|
||||||
|
|
||||||
|
|
||||||
@taskiq_broker.task
|
@taskiq_broker.task
|
||||||
@with_session_and_transcript
|
@with_session_and_transcript
|
||||||
async def task_cleanup_consent(
|
async def task_pipeline_title(
|
||||||
session, *, transcript: Transcript, logger: Logger, transcript_id: str
|
session, *, transcript: Transcript, logger: Logger, transcript_id: str
|
||||||
):
|
):
|
||||||
|
await pipeline_title(session, transcript=transcript, logger=logger)
|
||||||
|
|
||||||
|
|
||||||
|
@taskiq_broker.task
|
||||||
|
@with_session_and_transcript
|
||||||
|
async def task_pipeline_final_summaries(
|
||||||
|
session, *, transcript: Transcript, logger: Logger, transcript_id: str
|
||||||
|
):
|
||||||
|
await pipeline_summaries(session, transcript=transcript, logger=logger)
|
||||||
|
|
||||||
|
|
||||||
|
@taskiq_broker.task
|
||||||
|
@with_session_and_transcript
|
||||||
|
async def task_cleanup_consent(session, *, transcript: Transcript, logger: Logger):
|
||||||
await cleanup_consent(session, transcript=transcript, logger=logger)
|
await cleanup_consent(session, transcript=transcript, logger=logger)
|
||||||
|
|
||||||
|
|
||||||
@taskiq_broker.task
|
@taskiq_broker.task
|
||||||
@with_session_and_transcript
|
@with_session_and_transcript
|
||||||
async def task_pipeline_post_to_zulip(
|
async def task_pipeline_post_to_zulip(
|
||||||
session, *, transcript: Transcript, logger: Logger, transcript_id: str
|
session, *, transcript: Transcript, logger: Logger
|
||||||
):
|
):
|
||||||
await pipeline_post_to_zulip(session, transcript=transcript, logger=logger)
|
await pipeline_post_to_zulip(session, transcript=transcript, logger=logger)
|
||||||
|
|
||||||
@@ -759,7 +738,7 @@ async def task_pipeline_post_to_zulip(
|
|||||||
@taskiq_broker.task
|
@taskiq_broker.task
|
||||||
@with_session_and_transcript
|
@with_session_and_transcript
|
||||||
async def task_cleanup_consent_taskiq(
|
async def task_cleanup_consent_taskiq(
|
||||||
session, *, transcript: Transcript, logger: Logger, transcript_id: str
|
session, *, transcript: Transcript, logger: Logger
|
||||||
):
|
):
|
||||||
await cleanup_consent(session, transcript=transcript, logger=logger)
|
await cleanup_consent(session, transcript=transcript, logger=logger)
|
||||||
|
|
||||||
@@ -767,7 +746,7 @@ async def task_cleanup_consent_taskiq(
|
|||||||
@taskiq_broker.task
|
@taskiq_broker.task
|
||||||
@with_session_and_transcript
|
@with_session_and_transcript
|
||||||
async def task_pipeline_post_to_zulip_taskiq(
|
async def task_pipeline_post_to_zulip_taskiq(
|
||||||
session, *, transcript: Transcript, logger: Logger, transcript_id: str
|
session, *, transcript: Transcript, logger: Logger
|
||||||
):
|
):
|
||||||
await pipeline_post_to_zulip(session, transcript=transcript, logger=logger)
|
await pipeline_post_to_zulip(session, transcript=transcript, logger=logger)
|
||||||
|
|
||||||
@@ -793,11 +772,9 @@ async def task_pipeline_post_sequential(*, transcript_id: str):
|
|||||||
await task_pipeline_post_to_zulip.kiq(transcript_id=transcript_id)
|
await task_pipeline_post_to_zulip.kiq(transcript_id=transcript_id)
|
||||||
|
|
||||||
|
|
||||||
@get_transcript
|
async def pipeline_process(session, transcript: Transcript, logger: Logger):
|
||||||
async def pipeline_process(transcript: Transcript, logger: Logger):
|
|
||||||
try:
|
try:
|
||||||
if transcript.audio_location == "storage":
|
if transcript.audio_location == "storage":
|
||||||
async with get_session_context() as session:
|
|
||||||
await transcripts_controller.download_mp3_from_storage(transcript)
|
await transcripts_controller.download_mp3_from_storage(transcript)
|
||||||
transcript.audio_waveform_filename.unlink(missing_ok=True)
|
transcript.audio_waveform_filename.unlink(missing_ok=True)
|
||||||
await transcripts_controller.update(
|
await transcripts_controller.update(
|
||||||
@@ -838,7 +815,6 @@ async def pipeline_process(transcript: Transcript, logger: Logger):
|
|||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("Pipeline error", exc_info=exc)
|
logger.error("Pipeline error", exc_info=exc)
|
||||||
async with get_session_context() as session:
|
|
||||||
await transcripts_controller.update(
|
await transcripts_controller.update(
|
||||||
session,
|
session,
|
||||||
transcript,
|
transcript,
|
||||||
@@ -852,5 +828,8 @@ async def pipeline_process(transcript: Transcript, logger: Logger):
|
|||||||
|
|
||||||
|
|
||||||
@taskiq_broker.task
|
@taskiq_broker.task
|
||||||
async def task_pipeline_process(*, transcript_id: str):
|
@with_session_and_transcript
|
||||||
return await pipeline_process(transcript_id=transcript_id)
|
async def task_pipeline_process(
|
||||||
|
session, *, transcript: Transcript, logger: Logger, transcript_id: str
|
||||||
|
):
|
||||||
|
return await pipeline_process(session, transcript=transcript, logger=logger)
|
||||||
|
|||||||
@@ -71,16 +71,13 @@ def with_session_and_transcript(func: F) -> F:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async with get_session_context() as session:
|
async with get_session_context() as session:
|
||||||
# Fetch the transcript
|
|
||||||
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
||||||
if not transcript:
|
if not transcript:
|
||||||
raise Exception(f"Transcript {transcript_id} not found")
|
raise Exception(f"Transcript {transcript_id} not found")
|
||||||
|
|
||||||
# Create enhanced logger
|
|
||||||
tlogger = logger.bind(transcript_id=transcript.id)
|
tlogger = logger.bind(transcript_id=transcript.id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Pass session, transcript, and logger to the decorated function
|
|
||||||
return await func(
|
return await func(
|
||||||
session, transcript=transcript, logger=tlogger, *args, **kwargs
|
session, transcript=transcript, logger=tlogger, *args, **kwargs
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -377,6 +377,29 @@ async def dummy_storage():
|
|||||||
# yield
|
# yield
|
||||||
|
|
||||||
|
|
||||||
|
# @pytest.fixture()
|
||||||
|
# async def db_session(sqla_engine):
|
||||||
|
# """
|
||||||
|
# Fixture that returns a SQLAlchemy session with a SAVEPOINT, and the rollback to it
|
||||||
|
# after the test completes.
|
||||||
|
# """
|
||||||
|
# from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
# from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
# connection = await sqla_engine.connect()
|
||||||
|
# trans = await connection.begin()
|
||||||
|
|
||||||
|
# Session = sessionmaker(connection, expire_on_commit=False, class_=AsyncSession)
|
||||||
|
# session = Session()
|
||||||
|
|
||||||
|
# try:
|
||||||
|
# yield session
|
||||||
|
# finally:
|
||||||
|
# await session.close()
|
||||||
|
# await trans.rollback()
|
||||||
|
# await connection.close()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
async def ensure_db_session_in_app(db_session):
|
async def ensure_db_session_in_app(db_session):
|
||||||
async def mock_get_session():
|
async def mock_get_session():
|
||||||
|
|||||||
Reference in New Issue
Block a user