mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
fix: add missing db_session parameters across codebase
- Add @with_session decorator to webhook.py send_transcript_webhook task
- Update tools/process.py to use get_session_factory instead of deprecated get_database
- Fix tests/conftest.py fixture to pass db_session to controller update
- Fix main_live_pipeline.py to create sessions for controller update calls
- Update exportdanswer.py and exportdb.py to use new session pattern with get_session_factory
- Ensure all transcripts_controller and rooms_controller calls include session parameter
This commit is contained in:
@@ -799,14 +799,16 @@ def pipeline_post(*, transcript_id: str):
|
||||
async def pipeline_process(transcript: Transcript, logger: Logger):
|
||||
try:
|
||||
if transcript.audio_location == "storage":
|
||||
await transcripts_controller.download_mp3_from_storage(transcript)
|
||||
transcript.audio_waveform_filename.unlink(missing_ok=True)
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"topics": [],
|
||||
},
|
||||
)
|
||||
async with get_session_factory()() as session:
|
||||
await transcripts_controller.download_mp3_from_storage(transcript)
|
||||
transcript.audio_waveform_filename.unlink(missing_ok=True)
|
||||
await transcripts_controller.update(
|
||||
session,
|
||||
transcript,
|
||||
{
|
||||
"topics": [],
|
||||
},
|
||||
)
|
||||
|
||||
# open audio
|
||||
audio_filename = next(transcript.data_path.glob("upload.*"), None)
|
||||
@@ -838,12 +840,14 @@ async def pipeline_process(transcript: Transcript, logger: Logger):
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Pipeline error", exc_info=exc)
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"status": "error",
|
||||
},
|
||||
)
|
||||
async with get_session_factory()() as session:
|
||||
await transcripts_controller.update(
|
||||
session,
|
||||
transcript,
|
||||
{
|
||||
"status": "error",
|
||||
},
|
||||
)
|
||||
raise
|
||||
|
||||
logger.info("Pipeline ended")
|
||||
|
||||
@@ -9,12 +9,12 @@ async def export_db(filename: str) -> None:
|
||||
filename = pathlib.Path(filename).resolve()
|
||||
settings.DATABASE_URL = f"sqlite:///{filename}"
|
||||
|
||||
from reflector.db import get_database, transcripts
|
||||
from reflector.db import get_session_factory
|
||||
from reflector.db.transcripts import transcripts_controller
|
||||
|
||||
database = get_database()
|
||||
await database.connect()
|
||||
transcripts = await database.fetch_all(transcripts.select())
|
||||
await database.disconnect()
|
||||
session_factory = get_session_factory()
|
||||
async with session_factory() as session:
|
||||
transcripts = await transcripts_controller.get_all(session)
|
||||
|
||||
def export_transcript(transcript, output_dir):
|
||||
for topic in transcript.topics:
|
||||
|
||||
@@ -8,12 +8,12 @@ async def export_db(filename: str) -> None:
|
||||
filename = pathlib.Path(filename).resolve()
|
||||
settings.DATABASE_URL = f"sqlite:///{filename}"
|
||||
|
||||
from reflector.db import get_database, transcripts
|
||||
from reflector.db import get_session_factory
|
||||
from reflector.db.transcripts import transcripts_controller
|
||||
|
||||
database = get_database()
|
||||
await database.connect()
|
||||
transcripts = await database.fetch_all(transcripts.select())
|
||||
await database.disconnect()
|
||||
session_factory = get_session_factory()
|
||||
async with session_factory() as session:
|
||||
transcripts = await transcripts_controller.get_all(session)
|
||||
|
||||
def export_transcript(transcript):
|
||||
tid = transcript.id
|
||||
|
||||
@@ -11,6 +11,9 @@ import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from reflector.db import get_session_factory
|
||||
from reflector.db.transcripts import SourceKind, TranscriptTopic, transcripts_controller
|
||||
from reflector.logger import logger
|
||||
from reflector.pipelines.main_file_pipeline import (
|
||||
@@ -50,6 +53,7 @@ TranscriptId = str
|
||||
# common interface for every flow: it needs an Entry in db with specific ceremony (file path + status + actual file in file system)
|
||||
# ideally we want to get rid of it at some point
|
||||
async def prepare_entry(
|
||||
session: AsyncSession,
|
||||
source_path: str,
|
||||
source_language: str,
|
||||
target_language: str,
|
||||
@@ -57,6 +61,7 @@ async def prepare_entry(
|
||||
file_path = Path(source_path)
|
||||
|
||||
transcript = await transcripts_controller.add(
|
||||
session,
|
||||
file_path.name,
|
||||
# note that the real file upload has SourceKind: LIVE, which is itself an error
|
||||
source_kind=SourceKind.FILE,
|
||||
@@ -78,16 +83,20 @@ async def prepare_entry(
|
||||
logger.info(f"Copied {source_path} to {upload_path}")
|
||||
|
||||
# pipelines expect entity status "uploaded"
|
||||
await transcripts_controller.update(transcript, {"status": "uploaded"})
|
||||
await transcripts_controller.update(session, transcript, {"status": "uploaded"})
|
||||
|
||||
return transcript.id
|
||||
|
||||
|
||||
# same reason as prepare_entry
|
||||
async def extract_result_from_entry(
|
||||
transcript_id: TranscriptId, output_path: str
|
||||
session: AsyncSession,
|
||||
transcript_id: TranscriptId,
|
||||
output_path: str,
|
||||
) -> None:
|
||||
post_final_transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
post_final_transcript = await transcripts_controller.get_by_id(
|
||||
session, transcript_id
|
||||
)
|
||||
|
||||
# assert post_final_transcript.status == "ended"
|
||||
# File pipeline doesn't set status to "ended", only live pipeline does https://github.com/Monadical-SAS/reflector/issues/582
|
||||
@@ -115,6 +124,7 @@ async def extract_result_from_entry(
|
||||
|
||||
|
||||
async def process_live_pipeline(
|
||||
session: AsyncSession,
|
||||
transcript_id: TranscriptId,
|
||||
):
|
||||
"""Process transcript_id with transcription and diarization"""
|
||||
@@ -123,7 +133,9 @@ async def process_live_pipeline(
|
||||
await live_pipeline_process(transcript_id=transcript_id)
|
||||
print(f"Processing complete for transcript {transcript_id}", file=sys.stderr)
|
||||
|
||||
pre_final_transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
pre_final_transcript = await transcripts_controller.get_by_id(
|
||||
session, transcript_id
|
||||
)
|
||||
|
||||
# assert documented behaviour: after process, the pipeline isn't ended. This is the reason for calling pipeline_post
|
||||
assert pre_final_transcript.status != "ended"
|
||||
@@ -160,21 +172,17 @@ async def process(
|
||||
pipeline: Literal["live", "file"],
|
||||
output_path: str = None,
|
||||
):
|
||||
from reflector.db import get_database
|
||||
|
||||
database = get_database()
|
||||
# db connect is a part of ceremony
|
||||
await database.connect()
|
||||
|
||||
try:
|
||||
session_factory = get_session_factory()
|
||||
async with session_factory() as session:
|
||||
transcript_id = await prepare_entry(
|
||||
session,
|
||||
source_path,
|
||||
source_language,
|
||||
target_language,
|
||||
)
|
||||
|
||||
pipeline_handlers = {
|
||||
"live": process_live_pipeline,
|
||||
"live": lambda tid: process_live_pipeline(session, tid),
|
||||
"file": process_file_pipeline,
|
||||
}
|
||||
|
||||
@@ -184,9 +192,7 @@ async def process(
|
||||
|
||||
await handler(transcript_id)
|
||||
|
||||
await extract_result_from_entry(transcript_id, output_path)
|
||||
finally:
|
||||
await database.disconnect()
|
||||
await extract_result_from_entry(session, transcript_id, output_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -10,12 +10,14 @@ import httpx
|
||||
import structlog
|
||||
from celery import shared_task
|
||||
from celery.utils.log import get_task_logger
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from reflector.db.rooms import rooms_controller
|
||||
from reflector.db.transcripts import transcripts_controller
|
||||
from reflector.pipelines.main_live_pipeline import asynctask
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.webvtt import topics_to_webvtt
|
||||
from reflector.worker.session_decorator import with_session
|
||||
|
||||
logger = structlog.wrap_logger(get_task_logger(__name__))
|
||||
|
||||
@@ -39,11 +41,13 @@ def generate_webhook_signature(payload: bytes, secret: str, timestamp: str) -> s
|
||||
retry_backoff_max=3600, # Max 1 hour between retries
|
||||
)
|
||||
@asynctask
|
||||
@with_session
|
||||
async def send_transcript_webhook(
|
||||
self,
|
||||
transcript_id: str,
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
session: AsyncSession,
|
||||
):
|
||||
log = logger.bind(
|
||||
transcript_id=transcript_id,
|
||||
@@ -53,12 +57,12 @@ async def send_transcript_webhook(
|
||||
|
||||
try:
|
||||
# Fetch transcript and room
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
||||
if not transcript:
|
||||
log.error("Transcript not found, skipping webhook")
|
||||
return
|
||||
|
||||
room = await rooms_controller.get_by_id(room_id)
|
||||
room = await rooms_controller.get_by_id(session, room_id)
|
||||
if not room:
|
||||
log.error("Room not found, skipping webhook")
|
||||
return
|
||||
|
||||
@@ -383,7 +383,7 @@ async def fake_transcript_with_topics(tmpdir, client, db_session):
|
||||
transcript = await transcripts_controller.get_by_id(db_session, tid)
|
||||
assert transcript is not None
|
||||
|
||||
await transcripts_controller.update(transcript, {"status": "ended"})
|
||||
await transcripts_controller.update(db_session, transcript, {"status": "ended"})
|
||||
|
||||
# manually copy a file at the expected location
|
||||
audio_filename = transcript.audio_mp3_filename
|
||||
|
||||
Reference in New Issue
Block a user