mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
fix: add missing db_session parameters across codebase
- Add @with_session decorator to webhook.py send_transcript_webhook task - Update tools/process.py to use get_session_factory instead of deprecated get_database - Fix tests/conftest.py fixture to pass db_session to controller update - Fix main_live_pipeline.py to create sessions for controller update calls - Update exportdanswer.py and exportdb.py to use new session pattern with get_session_factory - Ensure all transcripts_controller and rooms_controller calls include session parameter
This commit is contained in:
@@ -799,14 +799,16 @@ def pipeline_post(*, transcript_id: str):
|
|||||||
async def pipeline_process(transcript: Transcript, logger: Logger):
|
async def pipeline_process(transcript: Transcript, logger: Logger):
|
||||||
try:
|
try:
|
||||||
if transcript.audio_location == "storage":
|
if transcript.audio_location == "storage":
|
||||||
await transcripts_controller.download_mp3_from_storage(transcript)
|
async with get_session_factory()() as session:
|
||||||
transcript.audio_waveform_filename.unlink(missing_ok=True)
|
await transcripts_controller.download_mp3_from_storage(transcript)
|
||||||
await transcripts_controller.update(
|
transcript.audio_waveform_filename.unlink(missing_ok=True)
|
||||||
transcript,
|
await transcripts_controller.update(
|
||||||
{
|
session,
|
||||||
"topics": [],
|
transcript,
|
||||||
},
|
{
|
||||||
)
|
"topics": [],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# open audio
|
# open audio
|
||||||
audio_filename = next(transcript.data_path.glob("upload.*"), None)
|
audio_filename = next(transcript.data_path.glob("upload.*"), None)
|
||||||
@@ -838,12 +840,14 @@ async def pipeline_process(transcript: Transcript, logger: Logger):
|
|||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("Pipeline error", exc_info=exc)
|
logger.error("Pipeline error", exc_info=exc)
|
||||||
await transcripts_controller.update(
|
async with get_session_factory()() as session:
|
||||||
transcript,
|
await transcripts_controller.update(
|
||||||
{
|
session,
|
||||||
"status": "error",
|
transcript,
|
||||||
},
|
{
|
||||||
)
|
"status": "error",
|
||||||
|
},
|
||||||
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
logger.info("Pipeline ended")
|
logger.info("Pipeline ended")
|
||||||
|
|||||||
@@ -9,12 +9,12 @@ async def export_db(filename: str) -> None:
|
|||||||
filename = pathlib.Path(filename).resolve()
|
filename = pathlib.Path(filename).resolve()
|
||||||
settings.DATABASE_URL = f"sqlite:///{filename}"
|
settings.DATABASE_URL = f"sqlite:///{filename}"
|
||||||
|
|
||||||
from reflector.db import get_database, transcripts
|
from reflector.db import get_session_factory
|
||||||
|
from reflector.db.transcripts import transcripts_controller
|
||||||
|
|
||||||
database = get_database()
|
session_factory = get_session_factory()
|
||||||
await database.connect()
|
async with session_factory() as session:
|
||||||
transcripts = await database.fetch_all(transcripts.select())
|
transcripts = await transcripts_controller.get_all(session)
|
||||||
await database.disconnect()
|
|
||||||
|
|
||||||
def export_transcript(transcript, output_dir):
|
def export_transcript(transcript, output_dir):
|
||||||
for topic in transcript.topics:
|
for topic in transcript.topics:
|
||||||
|
|||||||
@@ -8,12 +8,12 @@ async def export_db(filename: str) -> None:
|
|||||||
filename = pathlib.Path(filename).resolve()
|
filename = pathlib.Path(filename).resolve()
|
||||||
settings.DATABASE_URL = f"sqlite:///{filename}"
|
settings.DATABASE_URL = f"sqlite:///{filename}"
|
||||||
|
|
||||||
from reflector.db import get_database, transcripts
|
from reflector.db import get_session_factory
|
||||||
|
from reflector.db.transcripts import transcripts_controller
|
||||||
|
|
||||||
database = get_database()
|
session_factory = get_session_factory()
|
||||||
await database.connect()
|
async with session_factory() as session:
|
||||||
transcripts = await database.fetch_all(transcripts.select())
|
transcripts = await transcripts_controller.get_all(session)
|
||||||
await database.disconnect()
|
|
||||||
|
|
||||||
def export_transcript(transcript):
|
def export_transcript(transcript):
|
||||||
tid = transcript.id
|
tid = transcript.id
|
||||||
|
|||||||
@@ -11,6 +11,9 @@ import time
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Literal
|
from typing import Any, Dict, List, Literal
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from reflector.db import get_session_factory
|
||||||
from reflector.db.transcripts import SourceKind, TranscriptTopic, transcripts_controller
|
from reflector.db.transcripts import SourceKind, TranscriptTopic, transcripts_controller
|
||||||
from reflector.logger import logger
|
from reflector.logger import logger
|
||||||
from reflector.pipelines.main_file_pipeline import (
|
from reflector.pipelines.main_file_pipeline import (
|
||||||
@@ -50,6 +53,7 @@ TranscriptId = str
|
|||||||
# common interface for every flow: it needs an Entry in db with specific ceremony (file path + status + actual file in file system)
|
# common interface for every flow: it needs an Entry in db with specific ceremony (file path + status + actual file in file system)
|
||||||
# ideally we want to get rid of it at some point
|
# ideally we want to get rid of it at some point
|
||||||
async def prepare_entry(
|
async def prepare_entry(
|
||||||
|
session: AsyncSession,
|
||||||
source_path: str,
|
source_path: str,
|
||||||
source_language: str,
|
source_language: str,
|
||||||
target_language: str,
|
target_language: str,
|
||||||
@@ -57,6 +61,7 @@ async def prepare_entry(
|
|||||||
file_path = Path(source_path)
|
file_path = Path(source_path)
|
||||||
|
|
||||||
transcript = await transcripts_controller.add(
|
transcript = await transcripts_controller.add(
|
||||||
|
session,
|
||||||
file_path.name,
|
file_path.name,
|
||||||
# note that the real file upload has SourceKind: LIVE for the reason of it's an error
|
# note that the real file upload has SourceKind: LIVE for the reason of it's an error
|
||||||
source_kind=SourceKind.FILE,
|
source_kind=SourceKind.FILE,
|
||||||
@@ -78,16 +83,20 @@ async def prepare_entry(
|
|||||||
logger.info(f"Copied {source_path} to {upload_path}")
|
logger.info(f"Copied {source_path} to {upload_path}")
|
||||||
|
|
||||||
# pipelines expect entity status "uploaded"
|
# pipelines expect entity status "uploaded"
|
||||||
await transcripts_controller.update(transcript, {"status": "uploaded"})
|
await transcripts_controller.update(session, transcript, {"status": "uploaded"})
|
||||||
|
|
||||||
return transcript.id
|
return transcript.id
|
||||||
|
|
||||||
|
|
||||||
# same reason as prepare_entry
|
# same reason as prepare_entry
|
||||||
async def extract_result_from_entry(
|
async def extract_result_from_entry(
|
||||||
transcript_id: TranscriptId, output_path: str
|
session: AsyncSession,
|
||||||
|
transcript_id: TranscriptId,
|
||||||
|
output_path: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
post_final_transcript = await transcripts_controller.get_by_id(transcript_id)
|
post_final_transcript = await transcripts_controller.get_by_id(
|
||||||
|
session, transcript_id
|
||||||
|
)
|
||||||
|
|
||||||
# assert post_final_transcript.status == "ended"
|
# assert post_final_transcript.status == "ended"
|
||||||
# File pipeline doesn't set status to "ended", only live pipeline does https://github.com/Monadical-SAS/reflector/issues/582
|
# File pipeline doesn't set status to "ended", only live pipeline does https://github.com/Monadical-SAS/reflector/issues/582
|
||||||
@@ -115,6 +124,7 @@ async def extract_result_from_entry(
|
|||||||
|
|
||||||
|
|
||||||
async def process_live_pipeline(
|
async def process_live_pipeline(
|
||||||
|
session: AsyncSession,
|
||||||
transcript_id: TranscriptId,
|
transcript_id: TranscriptId,
|
||||||
):
|
):
|
||||||
"""Process transcript_id with transcription and diarization"""
|
"""Process transcript_id with transcription and diarization"""
|
||||||
@@ -123,7 +133,9 @@ async def process_live_pipeline(
|
|||||||
await live_pipeline_process(transcript_id=transcript_id)
|
await live_pipeline_process(transcript_id=transcript_id)
|
||||||
print(f"Processing complete for transcript {transcript_id}", file=sys.stderr)
|
print(f"Processing complete for transcript {transcript_id}", file=sys.stderr)
|
||||||
|
|
||||||
pre_final_transcript = await transcripts_controller.get_by_id(transcript_id)
|
pre_final_transcript = await transcripts_controller.get_by_id(
|
||||||
|
session, transcript_id
|
||||||
|
)
|
||||||
|
|
||||||
# assert documented behaviour: after process, the pipeline isn't ended. this is the reason of calling pipeline_post
|
# assert documented behaviour: after process, the pipeline isn't ended. this is the reason of calling pipeline_post
|
||||||
assert pre_final_transcript.status != "ended"
|
assert pre_final_transcript.status != "ended"
|
||||||
@@ -160,21 +172,17 @@ async def process(
|
|||||||
pipeline: Literal["live", "file"],
|
pipeline: Literal["live", "file"],
|
||||||
output_path: str = None,
|
output_path: str = None,
|
||||||
):
|
):
|
||||||
from reflector.db import get_database
|
session_factory = get_session_factory()
|
||||||
|
async with session_factory() as session:
|
||||||
database = get_database()
|
|
||||||
# db connect is a part of ceremony
|
|
||||||
await database.connect()
|
|
||||||
|
|
||||||
try:
|
|
||||||
transcript_id = await prepare_entry(
|
transcript_id = await prepare_entry(
|
||||||
|
session,
|
||||||
source_path,
|
source_path,
|
||||||
source_language,
|
source_language,
|
||||||
target_language,
|
target_language,
|
||||||
)
|
)
|
||||||
|
|
||||||
pipeline_handlers = {
|
pipeline_handlers = {
|
||||||
"live": process_live_pipeline,
|
"live": lambda tid: process_live_pipeline(session, tid),
|
||||||
"file": process_file_pipeline,
|
"file": process_file_pipeline,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -184,9 +192,7 @@ async def process(
|
|||||||
|
|
||||||
await handler(transcript_id)
|
await handler(transcript_id)
|
||||||
|
|
||||||
await extract_result_from_entry(transcript_id, output_path)
|
await extract_result_from_entry(session, transcript_id, output_path)
|
||||||
finally:
|
|
||||||
await database.disconnect()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -10,12 +10,14 @@ import httpx
|
|||||||
import structlog
|
import structlog
|
||||||
from celery import shared_task
|
from celery import shared_task
|
||||||
from celery.utils.log import get_task_logger
|
from celery.utils.log import get_task_logger
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from reflector.db.rooms import rooms_controller
|
from reflector.db.rooms import rooms_controller
|
||||||
from reflector.db.transcripts import transcripts_controller
|
from reflector.db.transcripts import transcripts_controller
|
||||||
from reflector.pipelines.main_live_pipeline import asynctask
|
from reflector.pipelines.main_live_pipeline import asynctask
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
from reflector.utils.webvtt import topics_to_webvtt
|
from reflector.utils.webvtt import topics_to_webvtt
|
||||||
|
from reflector.worker.session_decorator import with_session
|
||||||
|
|
||||||
logger = structlog.wrap_logger(get_task_logger(__name__))
|
logger = structlog.wrap_logger(get_task_logger(__name__))
|
||||||
|
|
||||||
@@ -39,11 +41,13 @@ def generate_webhook_signature(payload: bytes, secret: str, timestamp: str) -> s
|
|||||||
retry_backoff_max=3600, # Max 1 hour between retries
|
retry_backoff_max=3600, # Max 1 hour between retries
|
||||||
)
|
)
|
||||||
@asynctask
|
@asynctask
|
||||||
|
@with_session
|
||||||
async def send_transcript_webhook(
|
async def send_transcript_webhook(
|
||||||
self,
|
self,
|
||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
event_id: str,
|
event_id: str,
|
||||||
|
session: AsyncSession,
|
||||||
):
|
):
|
||||||
log = logger.bind(
|
log = logger.bind(
|
||||||
transcript_id=transcript_id,
|
transcript_id=transcript_id,
|
||||||
@@ -53,12 +57,12 @@ async def send_transcript_webhook(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Fetch transcript and room
|
# Fetch transcript and room
|
||||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
||||||
if not transcript:
|
if not transcript:
|
||||||
log.error("Transcript not found, skipping webhook")
|
log.error("Transcript not found, skipping webhook")
|
||||||
return
|
return
|
||||||
|
|
||||||
room = await rooms_controller.get_by_id(room_id)
|
room = await rooms_controller.get_by_id(session, room_id)
|
||||||
if not room:
|
if not room:
|
||||||
log.error("Room not found, skipping webhook")
|
log.error("Room not found, skipping webhook")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -383,7 +383,7 @@ async def fake_transcript_with_topics(tmpdir, client, db_session):
|
|||||||
transcript = await transcripts_controller.get_by_id(db_session, tid)
|
transcript = await transcripts_controller.get_by_id(db_session, tid)
|
||||||
assert transcript is not None
|
assert transcript is not None
|
||||||
|
|
||||||
await transcripts_controller.update(transcript, {"status": "ended"})
|
await transcripts_controller.update(db_session, transcript, {"status": "ended"})
|
||||||
|
|
||||||
# manually copy a file at the expected location
|
# manually copy a file at the expected location
|
||||||
audio_filename = transcript.audio_mp3_filename
|
audio_filename = transcript.audio_mp3_filename
|
||||||
|
|||||||
Reference in New Issue
Block a user