mirror of https://github.com/Monadical-SAS/reflector.git
synced 2025-12-21 04:39:06 +00:00
fix: Complete SQLAlchemy 2.0 migration - add session parameters to all controller calls
- Add session parameter to all view functions and controller calls
- Fix pipeline files to use get_session_factory() for background tasks
- Update PipelineMainBase and PipelineMainFile to handle sessions properly
- Add missing on_* methods to PipelineMainFile class
- Fix test fixtures to handle docker services availability
- Add docker_ip fixture for test database connections
- Import fixes for transcripts_controller in tests

All controller calls now properly use sessions as first parameter per SQLAlchemy 2.0 async patterns.
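The diff below applies this convention throughout the file pipeline. As background, here is a minimal sketch of the SQLAlchemy 2.0 async pattern the message refers to: controllers receive an AsyncSession as their first argument, and background tasks open their own session by calling a sessionmaker factory (hence the get_session_factory()() double call seen in the diff). The model, controller, and factory names below are simplified stand-ins for illustration, not reflector's actual implementations.

```python
# Hypothetical, simplified stand-ins -- not reflector's actual code.
import asyncio

from sqlalchemy import select
from sqlalchemy.ext.asyncio import (
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Transcript(Base):
    __tablename__ = "transcript"
    id: Mapped[str] = mapped_column(primary_key=True)
    title: Mapped[str | None] = mapped_column(default=None)


engine = create_async_engine("sqlite+aiosqlite:///:memory:")


def get_session_factory() -> async_sessionmaker[AsyncSession]:
    # Returns a sessionmaker, not a session: call sites therefore read
    # get_session_factory()() -- the second call opens the actual session.
    return async_sessionmaker(engine, expire_on_commit=False)


class TranscriptsController:
    """Controllers never create sessions; callers pass one as the first argument."""

    async def get_by_id(
        self, session: AsyncSession, transcript_id: str
    ) -> Transcript | None:
        result = await session.execute(
            select(Transcript).where(Transcript.id == transcript_id)
        )
        return result.scalar_one_or_none()

    async def update(
        self, session: AsyncSession, transcript: Transcript, values: dict
    ) -> None:
        for key, value in values.items():
            setattr(transcript, key, value)
        await session.commit()


transcripts_controller = TranscriptsController()


async def background_task(transcript_id: str) -> None:
    # Background (Celery) tasks have no request-scoped session, so they open
    # one from the factory and pass it to every controller call.
    async with get_session_factory()() as session:
        transcript = await transcripts_controller.get_by_id(session, transcript_id)
        if transcript is None:
            return
        await transcripts_controller.update(session, transcript, {"title": "processed"})


async def main() -> None:
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    async with get_session_factory()() as session:
        session.add(Transcript(id="t1"))
        await session.commit()
    await background_task("t1")


if __name__ == "__main__":
    asyncio.run(main())
```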
@@ -8,18 +8,22 @@ Uses parallel processing for transcription, diarization, and waveform generation
import asyncio
import uuid
from contextlib import asynccontextmanager
from pathlib import Path

import av
import structlog
from celery import chain, shared_task
from sqlalchemy.ext.asyncio import AsyncSession

from reflector.asynctask import asynctask
from reflector.db import get_session_factory
from reflector.db.rooms import rooms_controller
from reflector.db.transcripts import (
    SourceKind,
    Transcript,
    TranscriptStatus,
    TranscriptTopic,
    transcripts_controller,
)
from reflector.logger import logger
@@ -83,6 +87,32 @@ class PipelineMainFile(PipelineMainBase):
        self.logger = logger.bind(transcript_id=self.transcript_id)
        self.empty_pipeline = EmptyPipeline(logger=self.logger)

    async def get_transcript(self, session: AsyncSession = None) -> Transcript:
        """Get transcript with session"""
        if session:
            result = await transcripts_controller.get_by_id(session, self.transcript_id)
        else:
            async with get_session_factory()() as session:
                result = await transcripts_controller.get_by_id(
                    session, self.transcript_id
                )
        if not result:
            raise Exception("Transcript not found")
        return result

    @asynccontextmanager
    async def lock_transaction(self):
        # This lock is to prevent multiple processor starting adding
        # into event array at the same time
        async with asyncio.Lock():
            yield

    @asynccontextmanager
    async def transaction(self):
        async with self.lock_transaction():
            async with get_session_factory()() as session:
                yield session

    def _handle_gather_exceptions(self, results: list, operation: str) -> None:
        """Handle exceptions from asyncio.gather with return_exceptions=True"""
        for i, result in enumerate(results):
@@ -97,17 +127,23 @@ class PipelineMainFile(PipelineMainBase):
    @broadcast_to_sockets
    async def set_status(self, transcript_id: str, status: TranscriptStatus):
        async with self.lock_transaction():
            return await transcripts_controller.set_status(transcript_id, status)
            async with get_session_factory()() as session:
                return await transcripts_controller.set_status(
                    session, transcript_id, status
                )

    async def process(self, file_path: Path):
        """Main entry point for file processing"""
        self.logger.info(f"Starting file pipeline for {file_path}")

        transcript = await self.get_transcript()
        async with get_session_factory()() as session:
            transcript = await transcripts_controller.get_by_id(
                session, self.transcript_id
            )

        # Clear transcript as we're going to regenerate everything
        async with self.transaction():
            # Clear transcript as we're going to regenerate everything
            await transcripts_controller.update(
                session,
                transcript,
                {
                    "events": [],
@@ -131,7 +167,8 @@ class PipelineMainFile(PipelineMainBase):
        self.logger.info("File pipeline complete")

        await transcripts_controller.set_status(transcript.id, "ended")
        async with get_session_factory()() as session:
            await transcripts_controller.set_status(session, transcript.id, "ended")

    async def extract_and_write_audio(
        self, file_path: Path, transcript: Transcript
@@ -308,7 +345,10 @@ class PipelineMainFile(PipelineMainBase):
    async def generate_waveform(self, audio_path: Path):
        """Generate and save waveform"""
        transcript = await self.get_transcript()
        async with get_session_factory()() as session:
            transcript = await transcripts_controller.get_by_id(
                session, self.transcript_id
            )

        processor = AudioWaveformProcessor(
            audio_path=audio_path,
@@ -367,7 +407,10 @@ class PipelineMainFile(PipelineMainBase):
            self.logger.warning("No topics for summary generation")
            return

        transcript = await self.get_transcript()
        async with get_session_factory()() as session:
            transcript = await transcripts_controller.get_by_id(
                session, self.transcript_id
            )
        processor = TranscriptFinalSummaryProcessor(
            transcript=transcript,
            callback=self.on_long_summary,
@@ -380,37 +423,144 @@ class PipelineMainFile(PipelineMainBase):
        await processor.flush()

    async def on_topic(self, topic: TitleSummary):
        """Handle topic event - save to database"""
        async with get_session_factory()() as session:
            transcript = await transcripts_controller.get_by_id(
                session, self.transcript_id
            )
            topic_obj = TranscriptTopic(
                title=topic.title,
                summary=topic.summary,
                timestamp=topic.timestamp,
                duration=topic.duration,
            )
            await transcripts_controller.upsert_topic(session, transcript, topic_obj)
            await transcripts_controller.append_event(
                session,
                transcript=transcript,
                event="TOPIC",
                data=topic_obj,
            )

    async def on_title(self, data):
        """Handle title event"""
        async with get_session_factory()() as session:
            transcript = await transcripts_controller.get_by_id(
                session, self.transcript_id
            )
            if not transcript.title:
                await transcripts_controller.update(
                    session,
                    transcript,
                    {"title": data.title},
                )
            await transcripts_controller.append_event(
                session,
                transcript=transcript,
                event="FINAL_TITLE",
                data={"title": data.title},
            )

    async def on_long_summary(self, data):
        """Handle long summary event"""
        async with get_session_factory()() as session:
            transcript = await transcripts_controller.get_by_id(
                session, self.transcript_id
            )
            await transcripts_controller.update(
                session,
                transcript,
                {"long_summary": data.long_summary},
            )
            await transcripts_controller.append_event(
                session,
                transcript=transcript,
                event="FINAL_LONG_SUMMARY",
                data={"long_summary": data.long_summary},
            )

    async def on_short_summary(self, data):
        """Handle short summary event"""
        async with get_session_factory()() as session:
            transcript = await transcripts_controller.get_by_id(
                session, self.transcript_id
            )
            await transcripts_controller.update(
                session,
                transcript,
                {"short_summary": data.short_summary},
            )
            await transcripts_controller.append_event(
                session,
                transcript=transcript,
                event="FINAL_SHORT_SUMMARY",
                data={"short_summary": data.short_summary},
            )

    async def on_duration(self, duration):
        """Handle duration event"""
        async with get_session_factory()() as session:
            transcript = await transcripts_controller.get_by_id(
                session, self.transcript_id
            )
            await transcripts_controller.update(
                session,
                transcript,
                {"duration": duration},
            )
            await transcripts_controller.append_event(
                session,
                transcript=transcript,
                event="DURATION",
                data={"duration": duration},
            )

    async def on_waveform(self, waveform):
        """Handle waveform event"""
        async with get_session_factory()() as session:
            transcript = await transcripts_controller.get_by_id(
                session, self.transcript_id
            )
            await transcripts_controller.append_event(
                session,
                transcript=transcript,
                event="WAVEFORM",
                data={"waveform": waveform},
            )


@shared_task
@asynctask
async def task_send_webhook_if_needed(*, transcript_id: str):
    """Send webhook if this is a room recording with webhook configured"""
    transcript = await transcripts_controller.get_by_id(transcript_id)
    if not transcript:
        return
    async with get_session_factory()() as session:
        transcript = await transcripts_controller.get_by_id(session, transcript_id)
        if not transcript:
            return

    if transcript.source_kind == SourceKind.ROOM and transcript.room_id:
        room = await rooms_controller.get_by_id(transcript.room_id)
        if room and room.webhook_url:
            logger.info(
                "Dispatching webhook",
                transcript_id=transcript_id,
                room_id=room.id,
                webhook_url=room.webhook_url,
            )
            send_transcript_webhook.delay(
                transcript_id, room.id, event_id=uuid.uuid4().hex
            )
        if transcript.source_kind == SourceKind.ROOM and transcript.room_id:
            room = await rooms_controller.get_by_id(session, transcript.room_id)
            if room and room.webhook_url:
                logger.info(
                    "Dispatching webhook",
                    transcript_id=transcript_id,
                    room_id=room.id,
                    webhook_url=room.webhook_url,
                )
                send_transcript_webhook.delay(
                    transcript_id, room.id, event_id=uuid.uuid4().hex
                )


@shared_task
@asynctask
async def task_pipeline_file_process(*, transcript_id: str):
    """Celery task for file pipeline processing"""

    transcript = await transcripts_controller.get_by_id(transcript_id)
    if not transcript:
        raise Exception(f"Transcript {transcript_id} not found")
    async with get_session_factory()() as session:
        transcript = await transcripts_controller.get_by_id(session, transcript_id)
        if not transcript:
            raise Exception(f"Transcript {transcript_id} not found")

    pipeline = PipelineMainFile(transcript_id=transcript_id)
    try: