fix: Complete SQLAlchemy 2.0 migration - add session parameters to all controller calls

- Add session parameter to all view functions and controller calls
- Fix pipeline files to use get_session_factory() for background tasks
- Update PipelineMainBase and PipelineMainFile to handle sessions properly
- Add missing on_* methods to PipelineMainFile class
- Fix test fixtures to handle docker services availability
- Add docker_ip fixture for test database connections (see the fixture sketch below)
- Fix imports for transcripts_controller in tests

All controller calls now take the session as their first parameter, per SQLAlchemy 2.0 async patterns (sketched below).
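
For illustration only, a minimal sketch of the pattern (the model, controller, and connection URL below are simplified stand-ins, not the actual reflector code):

from sqlalchemy.ext.asyncio import (
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Transcript(Base):  # stand-in for the real model
    __tablename__ = "transcript"
    id: Mapped[str] = mapped_column(primary_key=True)
    status: Mapped[str] = mapped_column(default="idle")


engine = create_async_engine("postgresql+asyncpg://localhost/reflector")


def get_session_factory() -> async_sessionmaker[AsyncSession]:
    # One factory shared by request handlers and Celery background tasks.
    return async_sessionmaker(engine, expire_on_commit=False)


class TranscriptsController:  # stand-in for the real controller
    async def get_by_id(self, session: AsyncSession, transcript_id: str):
        # The session is always the first argument; the controller never
        # opens one itself.
        return await session.get(Transcript, transcript_id)

    async def set_status(self, session: AsyncSession, transcript_id: str, status: str):
        transcript = await self.get_by_id(session, transcript_id)
        if transcript is not None:
            transcript.status = status
            await session.commit()


transcripts_controller = TranscriptsController()


async def background_task(transcript_id: str):
    # Background tasks have no request-scoped session, so they open one
    # from the factory and pass it to every controller call.
    async with get_session_factory()() as session:
        await transcripts_controller.set_status(session, transcript_id, "ended")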
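
The docker_ip and docker-availability fixtures live in the test files, which are not part of the diff shown here; a hypothetical conftest.py sketch of that kind of fixture, with illustrative port, credentials, and helper names:

import os
import socket

import pytest


@pytest.fixture(scope="session")
def docker_ip():
    # Resolve the docker host: honor DOCKER_HOST when set, otherwise localhost.
    host = os.environ.get("DOCKER_HOST", "")
    if host.startswith("tcp://"):
        return host.split("://", 1)[1].split(":", 1)[0]
    return "127.0.0.1"


def _port_open(ip: str, port: int) -> bool:
    # Cheap reachability check so tests can skip instead of erroring out.
    with socket.socket() as sock:
        sock.settimeout(0.5)
        return sock.connect_ex((ip, port)) == 0


@pytest.fixture(scope="session")
def test_database_url(docker_ip):
    # Skip database-backed tests when the docker services are not running.
    if not _port_open(docker_ip, 5432):
        pytest.skip("docker database service not available")
    return f"postgresql+asyncpg://postgres:postgres@{docker_ip}:5432/reflector_test"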
commit d21b65e4e8
parent 45d1608950
Date: 2025-09-18 13:08:19 -06:00
13 changed files with 593 additions and 550 deletions


@@ -8,18 +8,22 @@ Uses parallel processing for transcription, diarization, and waveform generation
import asyncio
import uuid
from contextlib import asynccontextmanager
from pathlib import Path
import av
import structlog
from celery import chain, shared_task
from sqlalchemy.ext.asyncio import AsyncSession
from reflector.asynctask import asynctask
from reflector.db import get_session_factory
from reflector.db.rooms import rooms_controller
from reflector.db.transcripts import (
    SourceKind,
    Transcript,
    TranscriptStatus,
    TranscriptTopic,
    transcripts_controller,
)
from reflector.logger import logger
@@ -83,6 +87,32 @@ class PipelineMainFile(PipelineMainBase):
        self.logger = logger.bind(transcript_id=self.transcript_id)
        self.empty_pipeline = EmptyPipeline(logger=self.logger)

    async def get_transcript(self, session: AsyncSession = None) -> Transcript:
        """Get transcript with session"""
        if session:
            result = await transcripts_controller.get_by_id(session, self.transcript_id)
        else:
            async with get_session_factory()() as session:
                result = await transcripts_controller.get_by_id(
                    session, self.transcript_id
                )
        if not result:
            raise Exception("Transcript not found")
        return result

    @asynccontextmanager
    async def lock_transaction(self):
        # This lock prevents multiple processors from appending to the
        # event array at the same time
        async with asyncio.Lock():
            yield

    @asynccontextmanager
    async def transaction(self):
        async with self.lock_transaction():
            async with get_session_factory()() as session:
                yield session

    def _handle_gather_exceptions(self, results: list, operation: str) -> None:
        """Handle exceptions from asyncio.gather with return_exceptions=True"""
        for i, result in enumerate(results):
@@ -97,17 +127,23 @@ class PipelineMainFile(PipelineMainBase):
    @broadcast_to_sockets
    async def set_status(self, transcript_id: str, status: TranscriptStatus):
        async with self.lock_transaction():
            return await transcripts_controller.set_status(transcript_id, status)
            async with get_session_factory()() as session:
                return await transcripts_controller.set_status(
                    session, transcript_id, status
                )

    async def process(self, file_path: Path):
        """Main entry point for file processing"""
        self.logger.info(f"Starting file pipeline for {file_path}")

        transcript = await self.get_transcript()
        async with get_session_factory()() as session:
            transcript = await transcripts_controller.get_by_id(
                session, self.transcript_id
            )

        # Clear transcript as we're going to regenerate everything
        async with self.transaction():
            # Clear transcript as we're going to regenerate everything
            await transcripts_controller.update(
                session,
                transcript,
                {
                    "events": [],
@@ -131,7 +167,8 @@ class PipelineMainFile(PipelineMainBase):
self.logger.info("File pipeline complete")
await transcripts_controller.set_status(transcript.id, "ended")
async with get_session_factory()() as session:
await transcripts_controller.set_status(session, transcript.id, "ended")
async def extract_and_write_audio(
self, file_path: Path, transcript: Transcript
@@ -308,7 +345,10 @@ class PipelineMainFile(PipelineMainBase):
    async def generate_waveform(self, audio_path: Path):
        """Generate and save waveform"""
        transcript = await self.get_transcript()
        async with get_session_factory()() as session:
            transcript = await transcripts_controller.get_by_id(
                session, self.transcript_id
            )

        processor = AudioWaveformProcessor(
            audio_path=audio_path,
@@ -367,7 +407,10 @@ class PipelineMainFile(PipelineMainBase):
            self.logger.warning("No topics for summary generation")
            return

        transcript = await self.get_transcript()
        async with get_session_factory()() as session:
            transcript = await transcripts_controller.get_by_id(
                session, self.transcript_id
            )

        processor = TranscriptFinalSummaryProcessor(
            transcript=transcript,
            callback=self.on_long_summary,
@@ -380,37 +423,144 @@ class PipelineMainFile(PipelineMainBase):
        await processor.flush()

    async def on_topic(self, topic: TitleSummary):
        """Handle topic event - save to database"""
        async with get_session_factory()() as session:
            transcript = await transcripts_controller.get_by_id(
                session, self.transcript_id
            )
            topic_obj = TranscriptTopic(
                title=topic.title,
                summary=topic.summary,
                timestamp=topic.timestamp,
                duration=topic.duration,
            )
            await transcripts_controller.upsert_topic(session, transcript, topic_obj)
            await transcripts_controller.append_event(
                session,
                transcript=transcript,
                event="TOPIC",
                data=topic_obj,
            )

    async def on_title(self, data):
        """Handle title event"""
        async with get_session_factory()() as session:
            transcript = await transcripts_controller.get_by_id(
                session, self.transcript_id
            )
            if not transcript.title:
                await transcripts_controller.update(
                    session,
                    transcript,
                    {"title": data.title},
                )
            await transcripts_controller.append_event(
                session,
                transcript=transcript,
                event="FINAL_TITLE",
                data={"title": data.title},
            )
    async def on_long_summary(self, data):
        """Handle long summary event"""
        async with get_session_factory()() as session:
            transcript = await transcripts_controller.get_by_id(
                session, self.transcript_id
            )
            await transcripts_controller.update(
                session,
                transcript,
                {"long_summary": data.long_summary},
            )
            await transcripts_controller.append_event(
                session,
                transcript=transcript,
                event="FINAL_LONG_SUMMARY",
                data={"long_summary": data.long_summary},
            )

    async def on_short_summary(self, data):
        """Handle short summary event"""
        async with get_session_factory()() as session:
            transcript = await transcripts_controller.get_by_id(
                session, self.transcript_id
            )
            await transcripts_controller.update(
                session,
                transcript,
                {"short_summary": data.short_summary},
            )
            await transcripts_controller.append_event(
                session,
                transcript=transcript,
                event="FINAL_SHORT_SUMMARY",
                data={"short_summary": data.short_summary},
            )

    async def on_duration(self, duration):
        """Handle duration event"""
        async with get_session_factory()() as session:
            transcript = await transcripts_controller.get_by_id(
                session, self.transcript_id
            )
            await transcripts_controller.update(
                session,
                transcript,
                {"duration": duration},
            )
            await transcripts_controller.append_event(
                session,
                transcript=transcript,
                event="DURATION",
                data={"duration": duration},
            )

    async def on_waveform(self, waveform):
        """Handle waveform event"""
        async with get_session_factory()() as session:
            transcript = await transcripts_controller.get_by_id(
                session, self.transcript_id
            )
            await transcripts_controller.append_event(
                session,
                transcript=transcript,
                event="WAVEFORM",
                data={"waveform": waveform},
            )

@shared_task
@asynctask
async def task_send_webhook_if_needed(*, transcript_id: str):
    """Send webhook if this is a room recording with webhook configured"""
    transcript = await transcripts_controller.get_by_id(transcript_id)
    if not transcript:
        return
    async with get_session_factory()() as session:
        transcript = await transcripts_controller.get_by_id(session, transcript_id)
        if not transcript:
            return
    if transcript.source_kind == SourceKind.ROOM and transcript.room_id:
        room = await rooms_controller.get_by_id(transcript.room_id)
        if room and room.webhook_url:
            logger.info(
                "Dispatching webhook",
                transcript_id=transcript_id,
                room_id=room.id,
                webhook_url=room.webhook_url,
            )
            send_transcript_webhook.delay(
                transcript_id, room.id, event_id=uuid.uuid4().hex
            )
        if transcript.source_kind == SourceKind.ROOM and transcript.room_id:
            room = await rooms_controller.get_by_id(session, transcript.room_id)
            if room and room.webhook_url:
                logger.info(
                    "Dispatching webhook",
                    transcript_id=transcript_id,
                    room_id=room.id,
                    webhook_url=room.webhook_url,
                )
                send_transcript_webhook.delay(
                    transcript_id, room.id, event_id=uuid.uuid4().hex
                )

@shared_task
@asynctask
async def task_pipeline_file_process(*, transcript_id: str):
    """Celery task for file pipeline processing"""
    transcript = await transcripts_controller.get_by_id(transcript_id)
    if not transcript:
        raise Exception(f"Transcript {transcript_id} not found")
    async with get_session_factory()() as session:
        transcript = await transcripts_controller.get_by_id(session, transcript_id)
        if not transcript:
            raise Exception(f"Transcript {transcript_id} not found")

    pipeline = PipelineMainFile(transcript_id=transcript_id)
    try: