diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py index 5f2647e7..882e8cf9 100644 --- a/server/reflector/pipelines/main_live_pipeline.py +++ b/server/reflector/pipelines/main_live_pipeline.py @@ -799,14 +799,16 @@ def pipeline_post(*, transcript_id: str): async def pipeline_process(transcript: Transcript, logger: Logger): try: if transcript.audio_location == "storage": - await transcripts_controller.download_mp3_from_storage(transcript) - transcript.audio_waveform_filename.unlink(missing_ok=True) - await transcripts_controller.update( - transcript, - { - "topics": [], - }, - ) + async with get_session_factory()() as session: + await transcripts_controller.download_mp3_from_storage(transcript) + transcript.audio_waveform_filename.unlink(missing_ok=True) + await transcripts_controller.update( + session, + transcript, + { + "topics": [], + }, + ) # open audio audio_filename = next(transcript.data_path.glob("upload.*"), None) @@ -838,12 +840,14 @@ async def pipeline_process(transcript: Transcript, logger: Logger): except Exception as exc: logger.error("Pipeline error", exc_info=exc) - await transcripts_controller.update( - transcript, - { - "status": "error", - }, - ) + async with get_session_factory()() as session: + await transcripts_controller.update( + session, + transcript, + { + "status": "error", + }, + ) raise logger.info("Pipeline ended") diff --git a/server/reflector/tools/exportdanswer.py b/server/reflector/tools/exportdanswer.py index 6d335079..2f84242a 100644 --- a/server/reflector/tools/exportdanswer.py +++ b/server/reflector/tools/exportdanswer.py @@ -9,12 +9,12 @@ async def export_db(filename: str) -> None: filename = pathlib.Path(filename).resolve() settings.DATABASE_URL = f"sqlite:///{filename}" - from reflector.db import get_database, transcripts + from reflector.db import get_session_factory + from reflector.db.transcripts import transcripts_controller - database = get_database() - await database.connect() - transcripts = await database.fetch_all(transcripts.select()) - await database.disconnect() + session_factory = get_session_factory() + async with session_factory() as session: + transcripts = await transcripts_controller.get_all(session) def export_transcript(transcript, output_dir): for topic in transcript.topics: diff --git a/server/reflector/tools/exportdb.py b/server/reflector/tools/exportdb.py index 3f37c79e..2948813c 100644 --- a/server/reflector/tools/exportdb.py +++ b/server/reflector/tools/exportdb.py @@ -8,12 +8,12 @@ async def export_db(filename: str) -> None: filename = pathlib.Path(filename).resolve() settings.DATABASE_URL = f"sqlite:///{filename}" - from reflector.db import get_database, transcripts + from reflector.db import get_session_factory + from reflector.db.transcripts import transcripts_controller - database = get_database() - await database.connect() - transcripts = await database.fetch_all(transcripts.select()) - await database.disconnect() + session_factory = get_session_factory() + async with session_factory() as session: + transcripts = await transcripts_controller.get_all(session) def export_transcript(transcript): tid = transcript.id diff --git a/server/reflector/tools/process.py b/server/reflector/tools/process.py index eb770f76..e5b9ba4d 100644 --- a/server/reflector/tools/process.py +++ b/server/reflector/tools/process.py @@ -11,6 +11,9 @@ import time from pathlib import Path from typing import Any, Dict, List, Literal +from sqlalchemy.ext.asyncio import AsyncSession + +from reflector.db import get_session_factory from reflector.db.transcripts import SourceKind, TranscriptTopic, transcripts_controller from reflector.logger import logger from reflector.pipelines.main_file_pipeline import ( @@ -50,6 +53,7 @@ TranscriptId = str # common interface for every flow: it needs an Entry in db with specific ceremony (file path + status + actual file in file system) # ideally we want to get rid of it at some point async def prepare_entry( + session: AsyncSession, source_path: str, source_language: str, target_language: str, @@ -57,6 +61,7 @@ async def prepare_entry( file_path = Path(source_path) transcript = await transcripts_controller.add( + session, file_path.name, # note that the real file upload has SourceKind: LIVE for the reason of it's an error source_kind=SourceKind.FILE, @@ -78,16 +83,20 @@ async def prepare_entry( logger.info(f"Copied {source_path} to {upload_path}") # pipelines expect entity status "uploaded" - await transcripts_controller.update(transcript, {"status": "uploaded"}) + await transcripts_controller.update(session, transcript, {"status": "uploaded"}) return transcript.id # same reason as prepare_entry async def extract_result_from_entry( - transcript_id: TranscriptId, output_path: str + session: AsyncSession, + transcript_id: TranscriptId, + output_path: str, ) -> None: - post_final_transcript = await transcripts_controller.get_by_id(transcript_id) + post_final_transcript = await transcripts_controller.get_by_id( + session, transcript_id + ) # assert post_final_transcript.status == "ended" # File pipeline doesn't set status to "ended", only live pipeline does https://github.com/Monadical-SAS/reflector/issues/582 @@ -115,6 +124,7 @@ async def extract_result_from_entry( async def process_live_pipeline( + session: AsyncSession, transcript_id: TranscriptId, ): """Process transcript_id with transcription and diarization""" @@ -123,7 +133,9 @@ async def process_live_pipeline( await live_pipeline_process(transcript_id=transcript_id) print(f"Processing complete for transcript {transcript_id}", file=sys.stderr) - pre_final_transcript = await transcripts_controller.get_by_id(transcript_id) + pre_final_transcript = await transcripts_controller.get_by_id( + session, transcript_id + ) # assert documented behaviour: after process, the pipeline isn't ended. this is the reason of calling pipeline_post assert pre_final_transcript.status != "ended" @@ -160,21 +172,17 @@ async def process( pipeline: Literal["live", "file"], output_path: str = None, ): - from reflector.db import get_database - - database = get_database() - # db connect is a part of ceremony - await database.connect() - - try: + session_factory = get_session_factory() + async with session_factory() as session: transcript_id = await prepare_entry( + session, source_path, source_language, target_language, ) pipeline_handlers = { - "live": process_live_pipeline, + "live": lambda tid: process_live_pipeline(session, tid), "file": process_file_pipeline, } @@ -184,9 +192,7 @@ async def process( await handler(transcript_id) - await extract_result_from_entry(transcript_id, output_path) - finally: - await database.disconnect() + await extract_result_from_entry(session, transcript_id, output_path) if __name__ == "__main__": diff --git a/server/reflector/worker/webhook.py b/server/reflector/worker/webhook.py index 64368b2e..81c2ecb2 100644 --- a/server/reflector/worker/webhook.py +++ b/server/reflector/worker/webhook.py @@ -10,12 +10,14 @@ import httpx import structlog from celery import shared_task from celery.utils.log import get_task_logger +from sqlalchemy.ext.asyncio import AsyncSession from reflector.db.rooms import rooms_controller from reflector.db.transcripts import transcripts_controller from reflector.pipelines.main_live_pipeline import asynctask from reflector.settings import settings from reflector.utils.webvtt import topics_to_webvtt +from reflector.worker.session_decorator import with_session logger = structlog.wrap_logger(get_task_logger(__name__)) @@ -39,11 +41,13 @@ def generate_webhook_signature(payload: bytes, secret: str, timestamp: str) -> s retry_backoff_max=3600, # Max 1 hour between retries ) @asynctask +@with_session async def send_transcript_webhook( self, transcript_id: str, room_id: str, event_id: str, + session: AsyncSession, ): log = logger.bind( transcript_id=transcript_id, @@ -53,12 +57,12 @@ async def send_transcript_webhook( try: # Fetch transcript and room - transcript = await transcripts_controller.get_by_id(transcript_id) + transcript = await transcripts_controller.get_by_id(session, transcript_id) if not transcript: log.error("Transcript not found, skipping webhook") return - room = await rooms_controller.get_by_id(room_id) + room = await rooms_controller.get_by_id(session, room_id) if not room: log.error("Room not found, skipping webhook") return diff --git a/server/tests/conftest.py b/server/tests/conftest.py index c11831c6..84d4d3ec 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -383,7 +383,7 @@ async def fake_transcript_with_topics(tmpdir, client, db_session): transcript = await transcripts_controller.get_by_id(db_session, tid) assert transcript is not None - await transcripts_controller.update(transcript, {"status": "ended"}) + await transcripts_controller.update(db_session, transcript, {"status": "ended"}) # manually copy a file at the expected location audio_filename = transcript.audio_mp3_filename