fix: add missing db_session parameters across codebase

- Add @with_session decorator to webhook.py send_transcript_webhook task
- Update tools/process.py to use get_session_factory instead of deprecated get_database
- Fix tests/conftest.py fixture to pass db_session to controller update
- Fix main_live_pipeline.py to create sessions for controller update calls
- Update exportdanswer.py and exportdb.py to use new session pattern with get_session_factory
- Ensure all transcripts_controller and rooms_controller calls include session parameter
This commit is contained in:
2025-09-23 19:12:34 -06:00
parent df909363f5
commit 2aa99fe846
6 changed files with 56 additions and 42 deletions

View File

@@ -799,9 +799,11 @@ def pipeline_post(*, transcript_id: str):
async def pipeline_process(transcript: Transcript, logger: Logger): async def pipeline_process(transcript: Transcript, logger: Logger):
try: try:
if transcript.audio_location == "storage": if transcript.audio_location == "storage":
async with get_session_factory()() as session:
await transcripts_controller.download_mp3_from_storage(transcript) await transcripts_controller.download_mp3_from_storage(transcript)
transcript.audio_waveform_filename.unlink(missing_ok=True) transcript.audio_waveform_filename.unlink(missing_ok=True)
await transcripts_controller.update( await transcripts_controller.update(
session,
transcript, transcript,
{ {
"topics": [], "topics": [],
@@ -838,7 +840,9 @@ async def pipeline_process(transcript: Transcript, logger: Logger):
except Exception as exc: except Exception as exc:
logger.error("Pipeline error", exc_info=exc) logger.error("Pipeline error", exc_info=exc)
async with get_session_factory()() as session:
await transcripts_controller.update( await transcripts_controller.update(
session,
transcript, transcript,
{ {
"status": "error", "status": "error",

View File

@@ -9,12 +9,12 @@ async def export_db(filename: str) -> None:
filename = pathlib.Path(filename).resolve() filename = pathlib.Path(filename).resolve()
settings.DATABASE_URL = f"sqlite:///{filename}" settings.DATABASE_URL = f"sqlite:///{filename}"
from reflector.db import get_database, transcripts from reflector.db import get_session_factory
from reflector.db.transcripts import transcripts_controller
database = get_database() session_factory = get_session_factory()
await database.connect() async with session_factory() as session:
transcripts = await database.fetch_all(transcripts.select()) transcripts = await transcripts_controller.get_all(session)
await database.disconnect()
def export_transcript(transcript, output_dir): def export_transcript(transcript, output_dir):
for topic in transcript.topics: for topic in transcript.topics:

View File

@@ -8,12 +8,12 @@ async def export_db(filename: str) -> None:
filename = pathlib.Path(filename).resolve() filename = pathlib.Path(filename).resolve()
settings.DATABASE_URL = f"sqlite:///{filename}" settings.DATABASE_URL = f"sqlite:///{filename}"
from reflector.db import get_database, transcripts from reflector.db import get_session_factory
from reflector.db.transcripts import transcripts_controller
database = get_database() session_factory = get_session_factory()
await database.connect() async with session_factory() as session:
transcripts = await database.fetch_all(transcripts.select()) transcripts = await transcripts_controller.get_all(session)
await database.disconnect()
def export_transcript(transcript): def export_transcript(transcript):
tid = transcript.id tid = transcript.id

View File

@@ -11,6 +11,9 @@ import time
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Literal from typing import Any, Dict, List, Literal
from sqlalchemy.ext.asyncio import AsyncSession
from reflector.db import get_session_factory
from reflector.db.transcripts import SourceKind, TranscriptTopic, transcripts_controller from reflector.db.transcripts import SourceKind, TranscriptTopic, transcripts_controller
from reflector.logger import logger from reflector.logger import logger
from reflector.pipelines.main_file_pipeline import ( from reflector.pipelines.main_file_pipeline import (
@@ -50,6 +53,7 @@ TranscriptId = str
# common interface for every flow: it needs an Entry in db with specific ceremony (file path + status + actual file in file system) # common interface for every flow: it needs an Entry in db with specific ceremony (file path + status + actual file in file system)
# ideally we want to get rid of it at some point # ideally we want to get rid of it at some point
async def prepare_entry( async def prepare_entry(
session: AsyncSession,
source_path: str, source_path: str,
source_language: str, source_language: str,
target_language: str, target_language: str,
@@ -57,6 +61,7 @@ async def prepare_entry(
file_path = Path(source_path) file_path = Path(source_path)
transcript = await transcripts_controller.add( transcript = await transcripts_controller.add(
session,
file_path.name, file_path.name,
# note that the real file upload has SourceKind: LIVE for the reason of it's an error # note that the real file upload has SourceKind: LIVE for the reason of it's an error
source_kind=SourceKind.FILE, source_kind=SourceKind.FILE,
@@ -78,16 +83,20 @@ async def prepare_entry(
logger.info(f"Copied {source_path} to {upload_path}") logger.info(f"Copied {source_path} to {upload_path}")
# pipelines expect entity status "uploaded" # pipelines expect entity status "uploaded"
await transcripts_controller.update(transcript, {"status": "uploaded"}) await transcripts_controller.update(session, transcript, {"status": "uploaded"})
return transcript.id return transcript.id
# same reason as prepare_entry # same reason as prepare_entry
async def extract_result_from_entry( async def extract_result_from_entry(
transcript_id: TranscriptId, output_path: str session: AsyncSession,
transcript_id: TranscriptId,
output_path: str,
) -> None: ) -> None:
post_final_transcript = await transcripts_controller.get_by_id(transcript_id) post_final_transcript = await transcripts_controller.get_by_id(
session, transcript_id
)
# assert post_final_transcript.status == "ended" # assert post_final_transcript.status == "ended"
# File pipeline doesn't set status to "ended", only live pipeline does https://github.com/Monadical-SAS/reflector/issues/582 # File pipeline doesn't set status to "ended", only live pipeline does https://github.com/Monadical-SAS/reflector/issues/582
@@ -115,6 +124,7 @@ async def extract_result_from_entry(
async def process_live_pipeline( async def process_live_pipeline(
session: AsyncSession,
transcript_id: TranscriptId, transcript_id: TranscriptId,
): ):
"""Process transcript_id with transcription and diarization""" """Process transcript_id with transcription and diarization"""
@@ -123,7 +133,9 @@ async def process_live_pipeline(
await live_pipeline_process(transcript_id=transcript_id) await live_pipeline_process(transcript_id=transcript_id)
print(f"Processing complete for transcript {transcript_id}", file=sys.stderr) print(f"Processing complete for transcript {transcript_id}", file=sys.stderr)
pre_final_transcript = await transcripts_controller.get_by_id(transcript_id) pre_final_transcript = await transcripts_controller.get_by_id(
session, transcript_id
)
# assert documented behaviour: after process, the pipeline isn't ended. this is the reason of calling pipeline_post # assert documented behaviour: after process, the pipeline isn't ended. this is the reason of calling pipeline_post
assert pre_final_transcript.status != "ended" assert pre_final_transcript.status != "ended"
@@ -160,21 +172,17 @@ async def process(
pipeline: Literal["live", "file"], pipeline: Literal["live", "file"],
output_path: str = None, output_path: str = None,
): ):
from reflector.db import get_database session_factory = get_session_factory()
async with session_factory() as session:
database = get_database()
# db connect is a part of ceremony
await database.connect()
try:
transcript_id = await prepare_entry( transcript_id = await prepare_entry(
session,
source_path, source_path,
source_language, source_language,
target_language, target_language,
) )
pipeline_handlers = { pipeline_handlers = {
"live": process_live_pipeline, "live": lambda tid: process_live_pipeline(session, tid),
"file": process_file_pipeline, "file": process_file_pipeline,
} }
@@ -184,9 +192,7 @@ async def process(
await handler(transcript_id) await handler(transcript_id)
await extract_result_from_entry(transcript_id, output_path) await extract_result_from_entry(session, transcript_id, output_path)
finally:
await database.disconnect()
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -10,12 +10,14 @@ import httpx
import structlog import structlog
from celery import shared_task from celery import shared_task
from celery.utils.log import get_task_logger from celery.utils.log import get_task_logger
from sqlalchemy.ext.asyncio import AsyncSession
from reflector.db.rooms import rooms_controller from reflector.db.rooms import rooms_controller
from reflector.db.transcripts import transcripts_controller from reflector.db.transcripts import transcripts_controller
from reflector.pipelines.main_live_pipeline import asynctask from reflector.pipelines.main_live_pipeline import asynctask
from reflector.settings import settings from reflector.settings import settings
from reflector.utils.webvtt import topics_to_webvtt from reflector.utils.webvtt import topics_to_webvtt
from reflector.worker.session_decorator import with_session
logger = structlog.wrap_logger(get_task_logger(__name__)) logger = structlog.wrap_logger(get_task_logger(__name__))
@@ -39,11 +41,13 @@ def generate_webhook_signature(payload: bytes, secret: str, timestamp: str) -> s
retry_backoff_max=3600, # Max 1 hour between retries retry_backoff_max=3600, # Max 1 hour between retries
) )
@asynctask @asynctask
@with_session
async def send_transcript_webhook( async def send_transcript_webhook(
self, self,
transcript_id: str, transcript_id: str,
room_id: str, room_id: str,
event_id: str, event_id: str,
session: AsyncSession,
): ):
log = logger.bind( log = logger.bind(
transcript_id=transcript_id, transcript_id=transcript_id,
@@ -53,12 +57,12 @@ async def send_transcript_webhook(
try: try:
# Fetch transcript and room # Fetch transcript and room
transcript = await transcripts_controller.get_by_id(transcript_id) transcript = await transcripts_controller.get_by_id(session, transcript_id)
if not transcript: if not transcript:
log.error("Transcript not found, skipping webhook") log.error("Transcript not found, skipping webhook")
return return
room = await rooms_controller.get_by_id(room_id) room = await rooms_controller.get_by_id(session, room_id)
if not room: if not room:
log.error("Room not found, skipping webhook") log.error("Room not found, skipping webhook")
return return

View File

@@ -383,7 +383,7 @@ async def fake_transcript_with_topics(tmpdir, client, db_session):
transcript = await transcripts_controller.get_by_id(db_session, tid) transcript = await transcripts_controller.get_by_id(db_session, tid)
assert transcript is not None assert transcript is not None
await transcripts_controller.update(transcript, {"status": "ended"}) await transcripts_controller.update(db_session, transcript, {"status": "ended"})
# manually copy a file at the expected location # manually copy a file at the expected location
audio_filename = transcript.audio_mp3_filename audio_filename = transcript.audio_mp3_filename