fix: add missing db_session parameters across codebase

- Add @with_session decorator to webhook.py send_transcript_webhook task
- Update tools/process.py to use get_session_factory instead of deprecated get_database
- Fix tests/conftest.py fixture to pass db_session to controller update
- Fix main_live_pipeline.py to create sessions for controller update calls
- Update exportdanswer.py and exportdb.py to use new session pattern with get_session_factory
- Ensure all transcripts_controller and rooms_controller calls include session parameter
This commit is contained in:
2025-09-23 19:12:34 -06:00
parent df909363f5
commit 2aa99fe846
6 changed files with 56 additions and 42 deletions

View File

@@ -799,14 +799,16 @@ def pipeline_post(*, transcript_id: str):
async def pipeline_process(transcript: Transcript, logger: Logger):
try:
if transcript.audio_location == "storage":
await transcripts_controller.download_mp3_from_storage(transcript)
transcript.audio_waveform_filename.unlink(missing_ok=True)
await transcripts_controller.update(
transcript,
{
"topics": [],
},
)
async with get_session_factory()() as session:
await transcripts_controller.download_mp3_from_storage(transcript)
transcript.audio_waveform_filename.unlink(missing_ok=True)
await transcripts_controller.update(
session,
transcript,
{
"topics": [],
},
)
# open audio
audio_filename = next(transcript.data_path.glob("upload.*"), None)
@@ -838,12 +840,14 @@ async def pipeline_process(transcript: Transcript, logger: Logger):
except Exception as exc:
logger.error("Pipeline error", exc_info=exc)
await transcripts_controller.update(
transcript,
{
"status": "error",
},
)
async with get_session_factory()() as session:
await transcripts_controller.update(
session,
transcript,
{
"status": "error",
},
)
raise
logger.info("Pipeline ended")

View File

@@ -9,12 +9,12 @@ async def export_db(filename: str) -> None:
filename = pathlib.Path(filename).resolve()
settings.DATABASE_URL = f"sqlite:///{filename}"
from reflector.db import get_database, transcripts
from reflector.db import get_session_factory
from reflector.db.transcripts import transcripts_controller
database = get_database()
await database.connect()
transcripts = await database.fetch_all(transcripts.select())
await database.disconnect()
session_factory = get_session_factory()
async with session_factory() as session:
transcripts = await transcripts_controller.get_all(session)
def export_transcript(transcript, output_dir):
for topic in transcript.topics:

View File

@@ -8,12 +8,12 @@ async def export_db(filename: str) -> None:
filename = pathlib.Path(filename).resolve()
settings.DATABASE_URL = f"sqlite:///{filename}"
from reflector.db import get_database, transcripts
from reflector.db import get_session_factory
from reflector.db.transcripts import transcripts_controller
database = get_database()
await database.connect()
transcripts = await database.fetch_all(transcripts.select())
await database.disconnect()
session_factory = get_session_factory()
async with session_factory() as session:
transcripts = await transcripts_controller.get_all(session)
def export_transcript(transcript):
tid = transcript.id

View File

@@ -11,6 +11,9 @@ import time
from pathlib import Path
from typing import Any, Dict, List, Literal
from sqlalchemy.ext.asyncio import AsyncSession
from reflector.db import get_session_factory
from reflector.db.transcripts import SourceKind, TranscriptTopic, transcripts_controller
from reflector.logger import logger
from reflector.pipelines.main_file_pipeline import (
@@ -50,6 +53,7 @@ TranscriptId = str
# common interface for every flow: it needs an Entry in db with specific ceremony (file path + status + actual file in file system)
# ideally we want to get rid of it at some point
async def prepare_entry(
session: AsyncSession,
source_path: str,
source_language: str,
target_language: str,
@@ -57,6 +61,7 @@ async def prepare_entry(
file_path = Path(source_path)
transcript = await transcripts_controller.add(
session,
file_path.name,
# note that the real file upload has SourceKind: LIVE, which is an error
source_kind=SourceKind.FILE,
@@ -78,16 +83,20 @@ async def prepare_entry(
logger.info(f"Copied {source_path} to {upload_path}")
# pipelines expect entity status "uploaded"
await transcripts_controller.update(transcript, {"status": "uploaded"})
await transcripts_controller.update(session, transcript, {"status": "uploaded"})
return transcript.id
# same reason as prepare_entry
async def extract_result_from_entry(
transcript_id: TranscriptId, output_path: str
session: AsyncSession,
transcript_id: TranscriptId,
output_path: str,
) -> None:
post_final_transcript = await transcripts_controller.get_by_id(transcript_id)
post_final_transcript = await transcripts_controller.get_by_id(
session, transcript_id
)
# assert post_final_transcript.status == "ended"
# File pipeline doesn't set status to "ended", only live pipeline does https://github.com/Monadical-SAS/reflector/issues/582
@@ -115,6 +124,7 @@ async def extract_result_from_entry(
async def process_live_pipeline(
session: AsyncSession,
transcript_id: TranscriptId,
):
"""Process transcript_id with transcription and diarization"""
@@ -123,7 +133,9 @@ async def process_live_pipeline(
await live_pipeline_process(transcript_id=transcript_id)
print(f"Processing complete for transcript {transcript_id}", file=sys.stderr)
pre_final_transcript = await transcripts_controller.get_by_id(transcript_id)
pre_final_transcript = await transcripts_controller.get_by_id(
session, transcript_id
)
# assert documented behaviour: after processing, the pipeline has not ended; this is the reason for calling pipeline_post
assert pre_final_transcript.status != "ended"
@@ -160,21 +172,17 @@ async def process(
pipeline: Literal["live", "file"],
output_path: str = None,
):
from reflector.db import get_database
database = get_database()
# db connect is a part of ceremony
await database.connect()
try:
session_factory = get_session_factory()
async with session_factory() as session:
transcript_id = await prepare_entry(
session,
source_path,
source_language,
target_language,
)
pipeline_handlers = {
"live": process_live_pipeline,
"live": lambda tid: process_live_pipeline(session, tid),
"file": process_file_pipeline,
}
@@ -184,9 +192,7 @@ async def process(
await handler(transcript_id)
await extract_result_from_entry(transcript_id, output_path)
finally:
await database.disconnect()
await extract_result_from_entry(session, transcript_id, output_path)
if __name__ == "__main__":

View File

@@ -10,12 +10,14 @@ import httpx
import structlog
from celery import shared_task
from celery.utils.log import get_task_logger
from sqlalchemy.ext.asyncio import AsyncSession
from reflector.db.rooms import rooms_controller
from reflector.db.transcripts import transcripts_controller
from reflector.pipelines.main_live_pipeline import asynctask
from reflector.settings import settings
from reflector.utils.webvtt import topics_to_webvtt
from reflector.worker.session_decorator import with_session
logger = structlog.wrap_logger(get_task_logger(__name__))
@@ -39,11 +41,13 @@ def generate_webhook_signature(payload: bytes, secret: str, timestamp: str) -> s
retry_backoff_max=3600, # Max 1 hour between retries
)
@asynctask
@with_session
async def send_transcript_webhook(
self,
transcript_id: str,
room_id: str,
event_id: str,
session: AsyncSession,
):
log = logger.bind(
transcript_id=transcript_id,
@@ -53,12 +57,12 @@ async def send_transcript_webhook(
try:
# Fetch transcript and room
transcript = await transcripts_controller.get_by_id(transcript_id)
transcript = await transcripts_controller.get_by_id(session, transcript_id)
if not transcript:
log.error("Transcript not found, skipping webhook")
return
room = await rooms_controller.get_by_id(room_id)
room = await rooms_controller.get_by_id(session, room_id)
if not room:
log.error("Room not found, skipping webhook")
return

View File

@@ -383,7 +383,7 @@ async def fake_transcript_with_topics(tmpdir, client, db_session):
transcript = await transcripts_controller.get_by_id(db_session, tid)
assert transcript is not None
await transcripts_controller.update(transcript, {"status": "ended"})
await transcripts_controller.update(db_session, transcript, {"status": "ended"})
# manually copy a file at the expected location
audio_filename = transcript.audio_mp3_filename