fix: Complete SQLAlchemy 2.0 migration - add session parameters to all controller calls

- Add session parameter to all view functions and controller calls
- Fix pipeline files to use get_session_factory() for background tasks
- Update PipelineMainBase and PipelineMainFile to handle sessions properly
- Add missing on_* methods to PipelineMainFile class
- Fix test fixtures to handle docker services availability
- Add docker_ip fixture for test database connections
- Import fixes for transcripts_controller in tests

All controller calls now properly use sessions as first parameter per SQLAlchemy 2.0 async patterns.
This commit is contained in:
2025-09-18 13:08:19 -06:00
parent 45d1608950
commit d21b65e4e8
13 changed files with 593 additions and 550 deletions

View File

@@ -8,18 +8,22 @@ Uses parallel processing for transcription, diarization, and waveform generation
import asyncio import asyncio
import uuid import uuid
from contextlib import asynccontextmanager
from pathlib import Path from pathlib import Path
import av import av
import structlog import structlog
from celery import chain, shared_task from celery import chain, shared_task
from sqlalchemy.ext.asyncio import AsyncSession
from reflector.asynctask import asynctask from reflector.asynctask import asynctask
from reflector.db import get_session_factory
from reflector.db.rooms import rooms_controller from reflector.db.rooms import rooms_controller
from reflector.db.transcripts import ( from reflector.db.transcripts import (
SourceKind, SourceKind,
Transcript, Transcript,
TranscriptStatus, TranscriptStatus,
TranscriptTopic,
transcripts_controller, transcripts_controller,
) )
from reflector.logger import logger from reflector.logger import logger
@@ -83,6 +87,32 @@ class PipelineMainFile(PipelineMainBase):
self.logger = logger.bind(transcript_id=self.transcript_id) self.logger = logger.bind(transcript_id=self.transcript_id)
self.empty_pipeline = EmptyPipeline(logger=self.logger) self.empty_pipeline = EmptyPipeline(logger=self.logger)
async def get_transcript(self, session: AsyncSession = None) -> Transcript:
"""Get transcript with session"""
if session:
result = await transcripts_controller.get_by_id(session, self.transcript_id)
else:
async with get_session_factory()() as session:
result = await transcripts_controller.get_by_id(
session, self.transcript_id
)
if not result:
raise Exception("Transcript not found")
return result
@asynccontextmanager
async def lock_transaction(self):
# This lock is to prevent multiple processor starting adding
# into event array at the same time
async with asyncio.Lock():
yield
@asynccontextmanager
async def transaction(self):
async with self.lock_transaction():
async with get_session_factory()() as session:
yield session
def _handle_gather_exceptions(self, results: list, operation: str) -> None: def _handle_gather_exceptions(self, results: list, operation: str) -> None:
"""Handle exceptions from asyncio.gather with return_exceptions=True""" """Handle exceptions from asyncio.gather with return_exceptions=True"""
for i, result in enumerate(results): for i, result in enumerate(results):
@@ -97,17 +127,23 @@ class PipelineMainFile(PipelineMainBase):
@broadcast_to_sockets @broadcast_to_sockets
async def set_status(self, transcript_id: str, status: TranscriptStatus): async def set_status(self, transcript_id: str, status: TranscriptStatus):
async with self.lock_transaction(): async with self.lock_transaction():
return await transcripts_controller.set_status(transcript_id, status) async with get_session_factory()() as session:
return await transcripts_controller.set_status(
session, transcript_id, status
)
async def process(self, file_path: Path): async def process(self, file_path: Path):
"""Main entry point for file processing""" """Main entry point for file processing"""
self.logger.info(f"Starting file pipeline for {file_path}") self.logger.info(f"Starting file pipeline for {file_path}")
transcript = await self.get_transcript() async with get_session_factory()() as session:
transcript = await transcripts_controller.get_by_id(
session, self.transcript_id
)
# Clear transcript as we're going to regenerate everything # Clear transcript as we're going to regenerate everything
async with self.transaction():
await transcripts_controller.update( await transcripts_controller.update(
session,
transcript, transcript,
{ {
"events": [], "events": [],
@@ -131,7 +167,8 @@ class PipelineMainFile(PipelineMainBase):
self.logger.info("File pipeline complete") self.logger.info("File pipeline complete")
await transcripts_controller.set_status(transcript.id, "ended") async with get_session_factory()() as session:
await transcripts_controller.set_status(session, transcript.id, "ended")
async def extract_and_write_audio( async def extract_and_write_audio(
self, file_path: Path, transcript: Transcript self, file_path: Path, transcript: Transcript
@@ -308,7 +345,10 @@ class PipelineMainFile(PipelineMainBase):
async def generate_waveform(self, audio_path: Path): async def generate_waveform(self, audio_path: Path):
"""Generate and save waveform""" """Generate and save waveform"""
transcript = await self.get_transcript() async with get_session_factory()() as session:
transcript = await transcripts_controller.get_by_id(
session, self.transcript_id
)
processor = AudioWaveformProcessor( processor = AudioWaveformProcessor(
audio_path=audio_path, audio_path=audio_path,
@@ -367,7 +407,10 @@ class PipelineMainFile(PipelineMainBase):
self.logger.warning("No topics for summary generation") self.logger.warning("No topics for summary generation")
return return
transcript = await self.get_transcript() async with get_session_factory()() as session:
transcript = await transcripts_controller.get_by_id(
session, self.transcript_id
)
processor = TranscriptFinalSummaryProcessor( processor = TranscriptFinalSummaryProcessor(
transcript=transcript, transcript=transcript,
callback=self.on_long_summary, callback=self.on_long_summary,
@@ -380,17 +423,124 @@ class PipelineMainFile(PipelineMainBase):
await processor.flush() await processor.flush()
async def on_topic(self, topic: TitleSummary):
"""Handle topic event - save to database"""
async with get_session_factory()() as session:
transcript = await transcripts_controller.get_by_id(
session, self.transcript_id
)
topic_obj = TranscriptTopic(
title=topic.title,
summary=topic.summary,
timestamp=topic.timestamp,
duration=topic.duration,
)
await transcripts_controller.upsert_topic(session, transcript, topic_obj)
await transcripts_controller.append_event(
session,
transcript=transcript,
event="TOPIC",
data=topic_obj,
)
async def on_title(self, data):
"""Handle title event"""
async with get_session_factory()() as session:
transcript = await transcripts_controller.get_by_id(
session, self.transcript_id
)
if not transcript.title:
await transcripts_controller.update(
session,
transcript,
{"title": data.title},
)
await transcripts_controller.append_event(
session,
transcript=transcript,
event="FINAL_TITLE",
data={"title": data.title},
)
async def on_long_summary(self, data):
"""Handle long summary event"""
async with get_session_factory()() as session:
transcript = await transcripts_controller.get_by_id(
session, self.transcript_id
)
await transcripts_controller.update(
session,
transcript,
{"long_summary": data.long_summary},
)
await transcripts_controller.append_event(
session,
transcript=transcript,
event="FINAL_LONG_SUMMARY",
data={"long_summary": data.long_summary},
)
async def on_short_summary(self, data):
"""Handle short summary event"""
async with get_session_factory()() as session:
transcript = await transcripts_controller.get_by_id(
session, self.transcript_id
)
await transcripts_controller.update(
session,
transcript,
{"short_summary": data.short_summary},
)
await transcripts_controller.append_event(
session,
transcript=transcript,
event="FINAL_SHORT_SUMMARY",
data={"short_summary": data.short_summary},
)
async def on_duration(self, duration):
"""Handle duration event"""
async with get_session_factory()() as session:
transcript = await transcripts_controller.get_by_id(
session, self.transcript_id
)
await transcripts_controller.update(
session,
transcript,
{"duration": duration},
)
await transcripts_controller.append_event(
session,
transcript=transcript,
event="DURATION",
data={"duration": duration},
)
async def on_waveform(self, waveform):
"""Handle waveform event"""
async with get_session_factory()() as session:
transcript = await transcripts_controller.get_by_id(
session, self.transcript_id
)
await transcripts_controller.append_event(
session,
transcript=transcript,
event="WAVEFORM",
data={"waveform": waveform},
)
@shared_task @shared_task
@asynctask @asynctask
async def task_send_webhook_if_needed(*, transcript_id: str): async def task_send_webhook_if_needed(*, transcript_id: str):
"""Send webhook if this is a room recording with webhook configured""" """Send webhook if this is a room recording with webhook configured"""
transcript = await transcripts_controller.get_by_id(transcript_id) async with get_session_factory()() as session:
transcript = await transcripts_controller.get_by_id(session, transcript_id)
if not transcript: if not transcript:
return return
if transcript.source_kind == SourceKind.ROOM and transcript.room_id: if transcript.source_kind == SourceKind.ROOM and transcript.room_id:
room = await rooms_controller.get_by_id(transcript.room_id) room = await rooms_controller.get_by_id(session, transcript.room_id)
if room and room.webhook_url: if room and room.webhook_url:
logger.info( logger.info(
"Dispatching webhook", "Dispatching webhook",
@@ -407,8 +557,8 @@ async def task_send_webhook_if_needed(*, transcript_id: str):
@asynctask @asynctask
async def task_pipeline_file_process(*, transcript_id: str): async def task_pipeline_file_process(*, transcript_id: str):
"""Celery task for file pipeline processing""" """Celery task for file pipeline processing"""
async with get_session_factory()() as session:
transcript = await transcripts_controller.get_by_id(transcript_id) transcript = await transcripts_controller.get_by_id(session, transcript_id)
if not transcript: if not transcript:
raise Exception(f"Transcript {transcript_id} not found") raise Exception(f"Transcript {transcript_id} not found")

View File

@@ -20,9 +20,11 @@ import av
import boto3 import boto3
from celery import chord, current_task, group, shared_task from celery import chord, current_task, group, shared_task
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from structlog import BoundLogger as Logger from structlog import BoundLogger as Logger
from reflector.asynctask import asynctask from reflector.asynctask import asynctask
from reflector.db import get_session_factory
from reflector.db.meetings import meeting_consent_controller, meetings_controller from reflector.db.meetings import meeting_consent_controller, meetings_controller
from reflector.db.recordings import recordings_controller from reflector.db.recordings import recordings_controller
from reflector.db.rooms import rooms_controller from reflector.db.rooms import rooms_controller
@@ -96,9 +98,10 @@ def get_transcript(func):
@functools.wraps(func) @functools.wraps(func)
async def wrapper(**kwargs): async def wrapper(**kwargs):
transcript_id = kwargs.pop("transcript_id") transcript_id = kwargs.pop("transcript_id")
transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id) async with get_session_factory()() as session:
transcript = await transcripts_controller.get_by_id(session, transcript_id)
if not transcript: if not transcript:
raise Exception("Transcript {transcript_id} not found") raise Exception(f"Transcript {transcript_id} not found")
# Enhanced logger with Celery task context # Enhanced logger with Celery task context
tlogger = logger.bind(transcript_id=transcript.id) tlogger = logger.bind(transcript_id=transcript.id)
@@ -139,10 +142,14 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
self._ws_manager = get_ws_manager() self._ws_manager = get_ws_manager()
return self._ws_manager return self._ws_manager
async def get_transcript(self) -> Transcript: async def get_transcript(self, session: AsyncSession = None) -> Transcript:
# fetch the transcript # fetch the transcript
if session:
result = await transcripts_controller.get_by_id(session, self.transcript_id)
else:
async with get_session_factory()() as session:
result = await transcripts_controller.get_by_id( result = await transcripts_controller.get_by_id(
transcript_id=self.transcript_id session, self.transcript_id
) )
if not result: if not result:
raise Exception("Transcript not found") raise Exception("Transcript not found")
@@ -175,8 +182,8 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
@asynccontextmanager @asynccontextmanager
async def transaction(self): async def transaction(self):
async with self.lock_transaction(): async with self.lock_transaction():
async with transcripts_controller.transaction(): async with get_session_factory()() as session:
yield yield session
@broadcast_to_sockets @broadcast_to_sockets
async def on_status(self, status): async def on_status(self, status):
@@ -207,13 +214,17 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
# when the status of the pipeline changes, update the transcript # when the status of the pipeline changes, update the transcript
async with self._lock: async with self._lock:
return await transcripts_controller.set_status(self.transcript_id, status) async with get_session_factory()() as session:
return await transcripts_controller.set_status(
session, self.transcript_id, status
)
@broadcast_to_sockets @broadcast_to_sockets
async def on_transcript(self, data): async def on_transcript(self, data):
async with self.transaction(): async with self.transaction() as session:
transcript = await self.get_transcript() transcript = await self.get_transcript(session)
return await transcripts_controller.append_event( return await transcripts_controller.append_event(
session,
transcript=transcript, transcript=transcript,
event="TRANSCRIPT", event="TRANSCRIPT",
data=TranscriptText(text=data.text, translation=data.translation), data=TranscriptText(text=data.text, translation=data.translation),
@@ -230,10 +241,11 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
) )
if isinstance(data, TitleSummaryWithIdProcessorType): if isinstance(data, TitleSummaryWithIdProcessorType):
topic.id = data.id topic.id = data.id
async with self.transaction(): async with self.transaction() as session:
transcript = await self.get_transcript() transcript = await self.get_transcript(session)
await transcripts_controller.upsert_topic(transcript, topic) await transcripts_controller.upsert_topic(session, transcript, topic)
return await transcripts_controller.append_event( return await transcripts_controller.append_event(
session,
transcript=transcript, transcript=transcript,
event="TOPIC", event="TOPIC",
data=topic, data=topic,
@@ -242,16 +254,18 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
@broadcast_to_sockets @broadcast_to_sockets
async def on_title(self, data): async def on_title(self, data):
final_title = TranscriptFinalTitle(title=data.title) final_title = TranscriptFinalTitle(title=data.title)
async with self.transaction(): async with self.transaction() as session:
transcript = await self.get_transcript() transcript = await self.get_transcript(session)
if not transcript.title: if not transcript.title:
await transcripts_controller.update( await transcripts_controller.update(
session,
transcript, transcript,
{ {
"title": final_title.title, "title": final_title.title,
}, },
) )
return await transcripts_controller.append_event( return await transcripts_controller.append_event(
session,
transcript=transcript, transcript=transcript,
event="FINAL_TITLE", event="FINAL_TITLE",
data=final_title, data=final_title,
@@ -260,15 +274,17 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
@broadcast_to_sockets @broadcast_to_sockets
async def on_long_summary(self, data): async def on_long_summary(self, data):
final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary) final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
async with self.transaction(): async with self.transaction() as session:
transcript = await self.get_transcript() transcript = await self.get_transcript(session)
await transcripts_controller.update( await transcripts_controller.update(
session,
transcript, transcript,
{ {
"long_summary": final_long_summary.long_summary, "long_summary": final_long_summary.long_summary,
}, },
) )
return await transcripts_controller.append_event( return await transcripts_controller.append_event(
session,
transcript=transcript, transcript=transcript,
event="FINAL_LONG_SUMMARY", event="FINAL_LONG_SUMMARY",
data=final_long_summary, data=final_long_summary,
@@ -279,15 +295,17 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
final_short_summary = TranscriptFinalShortSummary( final_short_summary = TranscriptFinalShortSummary(
short_summary=data.short_summary short_summary=data.short_summary
) )
async with self.transaction(): async with self.transaction() as session:
transcript = await self.get_transcript() transcript = await self.get_transcript(session)
await transcripts_controller.update( await transcripts_controller.update(
session,
transcript, transcript,
{ {
"short_summary": final_short_summary.short_summary, "short_summary": final_short_summary.short_summary,
}, },
) )
return await transcripts_controller.append_event( return await transcripts_controller.append_event(
session,
transcript=transcript, transcript=transcript,
event="FINAL_SHORT_SUMMARY", event="FINAL_SHORT_SUMMARY",
data=final_short_summary, data=final_short_summary,
@@ -295,29 +313,30 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
@broadcast_to_sockets @broadcast_to_sockets
async def on_duration(self, data): async def on_duration(self, data):
async with self.transaction(): async with self.transaction() as session:
duration = TranscriptDuration(duration=data) duration = TranscriptDuration(duration=data)
transcript = await self.get_transcript() transcript = await self.get_transcript(session)
await transcripts_controller.update( await transcripts_controller.update(
session,
transcript, transcript,
{ {
"duration": duration.duration, "duration": duration.duration,
}, },
) )
return await transcripts_controller.append_event( return await transcripts_controller.append_event(
transcript=transcript, event="DURATION", data=duration session, transcript=transcript, event="DURATION", data=duration
) )
@broadcast_to_sockets @broadcast_to_sockets
async def on_waveform(self, data): async def on_waveform(self, data):
async with self.transaction(): async with self.transaction() as session:
waveform = TranscriptWaveform(waveform=data) waveform = TranscriptWaveform(waveform=data)
transcript = await self.get_transcript() transcript = await self.get_transcript(session)
return await transcripts_controller.append_event( return await transcripts_controller.append_event(
transcript=transcript, event="WAVEFORM", data=waveform session, transcript=transcript, event="WAVEFORM", data=waveform
) )
@@ -535,7 +554,8 @@ async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
return return
# Upload to external storage and delete the file # Upload to external storage and delete the file
await transcripts_controller.move_mp3_to_storage(transcript) async with get_session_factory()() as session:
await transcripts_controller.move_mp3_to_storage(session, transcript)
logger.info("Upload mp3 done") logger.info("Upload mp3 done")
@@ -572,12 +592,19 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
recording = None recording = None
try: try:
if transcript.recording_id: if transcript.recording_id:
recording = await recordings_controller.get_by_id(transcript.recording_id) async with get_session_factory()() as session:
recording = await recordings_controller.get_by_id(
session, transcript.recording_id
)
if recording and recording.meeting_id: if recording and recording.meeting_id:
meeting = await meetings_controller.get_by_id(recording.meeting_id) meeting = await meetings_controller.get_by_id(
session, recording.meeting_id
)
if meeting: if meeting:
consent_denied = await meeting_consent_controller.has_any_denial( consent_denied = (
meeting.id await meeting_consent_controller.has_any_denial(
session, meeting.id
)
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to get fetch consent: {e}", exc_info=e) logger.error(f"Failed to get fetch consent: {e}", exc_info=e)
@@ -606,7 +633,10 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
logger.error(f"Failed to delete Whereby recording: {e}", exc_info=e) logger.error(f"Failed to delete Whereby recording: {e}", exc_info=e)
# non-transactional, files marked for deletion not actually deleted is possible # non-transactional, files marked for deletion not actually deleted is possible
await transcripts_controller.update(transcript, {"audio_deleted": True}) async with get_session_factory()() as session:
await transcripts_controller.update(
session, transcript, {"audio_deleted": True}
)
# 2. Delete processed audio from transcript storage S3 bucket # 2. Delete processed audio from transcript storage S3 bucket
if transcript.audio_location == "storage": if transcript.audio_location == "storage":
storage = get_transcripts_storage() storage = get_transcripts_storage()
@@ -638,7 +668,10 @@ async def pipeline_post_to_zulip(transcript: Transcript, logger: Logger):
logger.info("Transcript has no recording") logger.info("Transcript has no recording")
return return
recording = await recordings_controller.get_by_id(transcript.recording_id) async with get_session_factory()() as session:
recording = await recordings_controller.get_by_id(
session, transcript.recording_id
)
if not recording: if not recording:
logger.info("Recording not found") logger.info("Recording not found")
return return
@@ -647,12 +680,12 @@ async def pipeline_post_to_zulip(transcript: Transcript, logger: Logger):
logger.info("Recording has no meeting") logger.info("Recording has no meeting")
return return
meeting = await meetings_controller.get_by_id(recording.meeting_id) meeting = await meetings_controller.get_by_id(session, recording.meeting_id)
if not meeting: if not meeting:
logger.info("No meeting found for this recording") logger.info("No meeting found for this recording")
return return
room = await rooms_controller.get_by_id(meeting.room_id) room = await rooms_controller.get_by_id(session, meeting.room_id)
if not room: if not room:
logger.error(f"Missing room for a meeting {meeting.id}") logger.error(f"Missing room for a meeting {meeting.id}")
return return
@@ -677,8 +710,9 @@ async def pipeline_post_to_zulip(transcript: Transcript, logger: Logger):
response = await send_message_to_zulip( response = await send_message_to_zulip(
room.zulip_stream, room.zulip_topic, message room.zulip_stream, room.zulip_topic, message
) )
async with get_session_factory()() as session:
await transcripts_controller.update( await transcripts_controller.update(
transcript, {"zulip_message_id": response["id"]} session, transcript, {"zulip_message_id": response["id"]}
) )
logger.info("Posted to zulip") logger.info("Posted to zulip")

View File

@@ -8,9 +8,10 @@ from fastapi_pagination import Page
from fastapi_pagination.ext.sqlalchemy import paginate from fastapi_pagination.ext.sqlalchemy import paginate
from pydantic import BaseModel from pydantic import BaseModel
from redis.exceptions import LockError from redis.exceptions import LockError
from sqlalchemy.ext.asyncio import AsyncSession
import reflector.auth as auth import reflector.auth as auth
from reflector.db import get_session_factory from reflector.db import get_session, get_session_factory
from reflector.db.calendar_events import calendar_events_controller from reflector.db.calendar_events import calendar_events_controller
from reflector.db.meetings import meetings_controller from reflector.db.meetings import meetings_controller
from reflector.db.rooms import rooms_controller from reflector.db.rooms import rooms_controller
@@ -185,7 +186,7 @@ async def rooms_list(
session_factory = get_session_factory() session_factory = get_session_factory()
async with session_factory() as session: async with session_factory() as session:
query = await rooms_controller.get_all( query = await rooms_controller.get_all(
user_id=user_id, order_by="-created_at", return_query=True session, user_id=user_id, order_by="-created_at", return_query=True
) )
return await paginate(session, query) return await paginate(session, query)
@@ -194,9 +195,10 @@ async def rooms_list(
async def rooms_get( async def rooms_get(
room_id: str, room_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
room = await rooms_controller.get_by_id_for_http(room_id, user_id=user_id) room = await rooms_controller.get_by_id_for_http(session, room_id, user_id=user_id)
if not room: if not room:
raise HTTPException(status_code=404, detail="Room not found") raise HTTPException(status_code=404, detail="Room not found")
return room return room
@@ -206,9 +208,10 @@ async def rooms_get(
async def rooms_get_by_name( async def rooms_get_by_name(
room_name: str, room_name: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name) room = await rooms_controller.get_by_name(session, room_name)
if not room: if not room:
raise HTTPException(status_code=404, detail="Room not found") raise HTTPException(status_code=404, detail="Room not found")
@@ -230,10 +233,12 @@ async def rooms_get_by_name(
async def rooms_create( async def rooms_create(
room: CreateRoom, room: CreateRoom,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
return await rooms_controller.add( return await rooms_controller.add(
session,
name=room.name, name=room.name,
user_id=user_id, user_id=user_id,
zulip_auto_post=room.zulip_auto_post, zulip_auto_post=room.zulip_auto_post,
@@ -257,13 +262,14 @@ async def rooms_update(
room_id: str, room_id: str,
info: UpdateRoom, info: UpdateRoom,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
room = await rooms_controller.get_by_id_for_http(room_id, user_id=user_id) room = await rooms_controller.get_by_id_for_http(session, room_id, user_id=user_id)
if not room: if not room:
raise HTTPException(status_code=404, detail="Room not found") raise HTTPException(status_code=404, detail="Room not found")
values = info.dict(exclude_unset=True) values = info.dict(exclude_unset=True)
await rooms_controller.update(room, values) await rooms_controller.update(session, room, values)
return room return room
@@ -271,12 +277,13 @@ async def rooms_update(
async def rooms_delete( async def rooms_delete(
room_id: str, room_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
room = await rooms_controller.get_by_id(room_id, user_id=user_id) room = await rooms_controller.get_by_id(session, room_id, user_id=user_id)
if not room: if not room:
raise HTTPException(status_code=404, detail="Room not found") raise HTTPException(status_code=404, detail="Room not found")
await rooms_controller.remove_by_id(room.id, user_id=user_id) await rooms_controller.remove_by_id(session, room.id, user_id=user_id)
return DeletionStatus(status="ok") return DeletionStatus(status="ok")
@@ -285,9 +292,10 @@ async def rooms_create_meeting(
room_name: str, room_name: str,
info: CreateRoomMeeting, info: CreateRoomMeeting,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name) room = await rooms_controller.get_by_name(session, room_name)
if not room: if not room:
raise HTTPException(status_code=404, detail="Room not found") raise HTTPException(status_code=404, detail="Room not found")
@@ -303,7 +311,7 @@ async def rooms_create_meeting(
meeting = None meeting = None
if not info.allow_duplicated: if not info.allow_duplicated:
meeting = await meetings_controller.get_active( meeting = await meetings_controller.get_active(
room=room, current_time=current_time session, room=room, current_time=current_time
) )
if meeting is None: if meeting is None:
@@ -314,6 +322,7 @@ async def rooms_create_meeting(
await upload_logo(whereby_meeting["roomName"], "./images/logo.png") await upload_logo(whereby_meeting["roomName"], "./images/logo.png")
meeting = await meetings_controller.create( meeting = await meetings_controller.create(
session,
id=whereby_meeting["meetingId"], id=whereby_meeting["meetingId"],
room_name=whereby_meeting["roomName"], room_name=whereby_meeting["roomName"],
room_url=whereby_meeting["roomUrl"], room_url=whereby_meeting["roomUrl"],
@@ -340,11 +349,12 @@ async def rooms_create_meeting(
async def rooms_test_webhook( async def rooms_test_webhook(
room_id: str, room_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
): ):
"""Test webhook configuration by sending a sample payload.""" """Test webhook configuration by sending a sample payload."""
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
room = await rooms_controller.get_by_id(room_id) room = await rooms_controller.get_by_id(session, room_id)
if not room: if not room:
raise HTTPException(status_code=404, detail="Room not found") raise HTTPException(status_code=404, detail="Room not found")
@@ -361,9 +371,10 @@ async def rooms_test_webhook(
async def rooms_sync_ics( async def rooms_sync_ics(
room_name: str, room_name: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name) room = await rooms_controller.get_by_name(session, room_name)
if not room: if not room:
raise HTTPException(status_code=404, detail="Room not found") raise HTTPException(status_code=404, detail="Room not found")
@@ -390,9 +401,10 @@ async def rooms_sync_ics(
async def rooms_ics_status( async def rooms_ics_status(
room_name: str, room_name: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name) room = await rooms_controller.get_by_name(session, room_name)
if not room: if not room:
raise HTTPException(status_code=404, detail="Room not found") raise HTTPException(status_code=404, detail="Room not found")
@@ -407,7 +419,7 @@ async def rooms_ics_status(
next_sync = room.ics_last_sync + timedelta(seconds=room.ics_fetch_interval) next_sync = room.ics_last_sync + timedelta(seconds=room.ics_fetch_interval)
events = await calendar_events_controller.get_by_room( events = await calendar_events_controller.get_by_room(
room.id, include_deleted=False session, room.id, include_deleted=False
) )
return ICSStatus( return ICSStatus(
@@ -423,15 +435,16 @@ async def rooms_ics_status(
async def rooms_list_meetings( async def rooms_list_meetings(
room_name: str, room_name: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name) room = await rooms_controller.get_by_name(session, room_name)
if not room: if not room:
raise HTTPException(status_code=404, detail="Room not found") raise HTTPException(status_code=404, detail="Room not found")
events = await calendar_events_controller.get_by_room( events = await calendar_events_controller.get_by_room(
room.id, include_deleted=False session, room.id, include_deleted=False
) )
if user_id != room.user_id: if user_id != room.user_id:
@@ -449,15 +462,16 @@ async def rooms_list_upcoming_meetings(
room_name: str, room_name: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
minutes_ahead: int = 120, minutes_ahead: int = 120,
session: AsyncSession = Depends(get_session),
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name) room = await rooms_controller.get_by_name(session, room_name)
if not room: if not room:
raise HTTPException(status_code=404, detail="Room not found") raise HTTPException(status_code=404, detail="Room not found")
events = await calendar_events_controller.get_upcoming( events = await calendar_events_controller.get_upcoming(
room.id, minutes_ahead=minutes_ahead session, room.id, minutes_ahead=minutes_ahead
) )
if user_id != room.user_id: if user_id != room.user_id:
@@ -472,16 +486,17 @@ async def rooms_list_upcoming_meetings(
async def rooms_list_active_meetings( async def rooms_list_active_meetings(
room_name: str, room_name: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name) room = await rooms_controller.get_by_name(session, room_name)
if not room: if not room:
raise HTTPException(status_code=404, detail="Room not found") raise HTTPException(status_code=404, detail="Room not found")
current_time = datetime.now(timezone.utc) current_time = datetime.now(timezone.utc)
meetings = await meetings_controller.get_all_active_for_room( meetings = await meetings_controller.get_all_active_for_room(
room=room, current_time=current_time session, room=room, current_time=current_time
) )
# Hide host URLs from non-owners # Hide host URLs from non-owners
@@ -497,15 +512,16 @@ async def rooms_get_meeting(
room_name: str, room_name: str,
meeting_id: str, meeting_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
): ):
"""Get a single meeting by ID within a specific room.""" """Get a single meeting by ID within a specific room."""
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name) room = await rooms_controller.get_by_name(session, room_name)
if not room: if not room:
raise HTTPException(status_code=404, detail="Room not found") raise HTTPException(status_code=404, detail="Room not found")
meeting = await meetings_controller.get_by_id(meeting_id) meeting = await meetings_controller.get_by_id(session, meeting_id)
if not meeting: if not meeting:
raise HTTPException(status_code=404, detail="Meeting not found") raise HTTPException(status_code=404, detail="Meeting not found")
@@ -525,14 +541,15 @@ async def rooms_join_meeting(
room_name: str, room_name: str,
meeting_id: str, meeting_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name) room = await rooms_controller.get_by_name(session, room_name)
if not room: if not room:
raise HTTPException(status_code=404, detail="Room not found") raise HTTPException(status_code=404, detail="Room not found")
meeting = await meetings_controller.get_by_id(meeting_id) meeting = await meetings_controller.get_by_id(session, meeting_id)
if not meeting: if not meeting:
raise HTTPException(status_code=404, detail="Meeting not found") raise HTTPException(status_code=404, detail="Meeting not found")

View File

@@ -9,8 +9,10 @@ from typing import Annotated, Optional
import httpx import httpx
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from jose import jwt from jose import jwt
from sqlalchemy.ext.asyncio import AsyncSession
import reflector.auth as auth import reflector.auth as auth
from reflector.db import get_session
from reflector.db.transcripts import AudioWaveform, transcripts_controller from reflector.db.transcripts import AudioWaveform, transcripts_controller
from reflector.settings import settings from reflector.settings import settings
from reflector.views.transcripts import ALGORITHM from reflector.views.transcripts import ALGORITHM
@@ -48,7 +50,7 @@ async def transcript_get_audio_mp3(
raise unauthorized_exception raise unauthorized_exception
transcript = await transcripts_controller.get_by_id_for_http( transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id session, transcript_id, user_id=user_id
) )
if transcript.audio_location == "storage": if transcript.audio_location == "storage":
@@ -96,10 +98,11 @@ async def transcript_get_audio_mp3(
async def transcript_get_audio_waveform( async def transcript_get_audio_waveform(
transcript_id: str, transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
) -> AudioWaveform: ) -> AudioWaveform:
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http( transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id session, transcript_id, user_id=user_id
) )
if not transcript.audio_waveform_filename.exists(): if not transcript.audio_waveform_filename.exists():

View File

@@ -8,8 +8,10 @@ from typing import Annotated, Optional
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy.ext.asyncio import AsyncSession
import reflector.auth as auth import reflector.auth as auth
from reflector.db import get_session
from reflector.db.transcripts import TranscriptParticipant, transcripts_controller from reflector.db.transcripts import TranscriptParticipant, transcripts_controller
from reflector.views.types import DeletionStatus from reflector.views.types import DeletionStatus
@@ -37,10 +39,11 @@ class UpdateParticipant(BaseModel):
async def transcript_get_participants( async def transcript_get_participants(
transcript_id: str, transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
) -> list[Participant]: ) -> list[Participant]:
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http( transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id session, transcript_id, user_id=user_id
) )
if transcript.participants is None: if transcript.participants is None:
@@ -57,10 +60,11 @@ async def transcript_add_participant(
transcript_id: str, transcript_id: str,
participant: CreateParticipant, participant: CreateParticipant,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
) -> Participant: ) -> Participant:
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http( transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id session, transcript_id, user_id=user_id
) )
# ensure the speaker is unique # ensure the speaker is unique
@@ -83,10 +87,11 @@ async def transcript_get_participant(
transcript_id: str, transcript_id: str,
participant_id: str, participant_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
) -> Participant: ) -> Participant:
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http( transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id session, transcript_id, user_id=user_id
) )
for p in transcript.participants: for p in transcript.participants:
@@ -102,10 +107,11 @@ async def transcript_update_participant(
participant_id: str, participant_id: str,
participant: UpdateParticipant, participant: UpdateParticipant,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
) -> Participant: ) -> Participant:
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http( transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id session, transcript_id, user_id=user_id
) )
# ensure the speaker is unique # ensure the speaker is unique
@@ -139,10 +145,11 @@ async def transcript_delete_participant(
transcript_id: str, transcript_id: str,
participant_id: str, participant_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
) -> DeletionStatus: ) -> DeletionStatus:
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http( transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id session, transcript_id, user_id=user_id
) )
await transcripts_controller.delete_participant(transcript, participant_id) await transcripts_controller.delete_participant(transcript, participant_id)
return DeletionStatus(status="ok") return DeletionStatus(status="ok")

View File

@@ -3,8 +3,10 @@ from typing import Annotated, Optional
import celery import celery
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
import reflector.auth as auth import reflector.auth as auth
from reflector.db import get_session
from reflector.db.transcripts import transcripts_controller from reflector.db.transcripts import transcripts_controller
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
@@ -19,10 +21,11 @@ class ProcessStatus(BaseModel):
async def transcript_process( async def transcript_process(
transcript_id: str, transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http( transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id session, transcript_id, user_id=user_id
) )
if transcript.locked: if transcript.locked:

View File

@@ -8,8 +8,10 @@ from typing import Annotated, Optional
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
import reflector.auth as auth import reflector.auth as auth
from reflector.db import get_session
from reflector.db.transcripts import transcripts_controller from reflector.db.transcripts import transcripts_controller
router = APIRouter() router = APIRouter()
@@ -36,10 +38,11 @@ async def transcript_assign_speaker(
transcript_id: str, transcript_id: str,
assignment: SpeakerAssignment, assignment: SpeakerAssignment,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
) -> SpeakerAssignmentStatus: ) -> SpeakerAssignmentStatus:
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http( transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id session, transcript_id, user_id=user_id
) )
if not transcript: if not transcript:
@@ -100,6 +103,7 @@ async def transcript_assign_speaker(
for topic in changed_topics: for topic in changed_topics:
transcript.upsert_topic(topic) transcript.upsert_topic(topic)
await transcripts_controller.update( await transcripts_controller.update(
session,
transcript, transcript,
{ {
"topics": transcript.topics_dump(), "topics": transcript.topics_dump(),
@@ -114,10 +118,11 @@ async def transcript_merge_speaker(
transcript_id: str, transcript_id: str,
merge: SpeakerMerge, merge: SpeakerMerge,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
) -> SpeakerAssignmentStatus: ) -> SpeakerAssignmentStatus:
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http( transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id session, transcript_id, user_id=user_id
) )
if not transcript: if not transcript:
@@ -163,6 +168,7 @@ async def transcript_merge_speaker(
for topic in changed_topics: for topic in changed_topics:
transcript.upsert_topic(topic) transcript.upsert_topic(topic)
await transcripts_controller.update( await transcripts_controller.update(
session,
transcript, transcript,
{ {
"topics": transcript.topics_dump(), "topics": transcript.topics_dump(),

View File

@@ -3,8 +3,10 @@ from typing import Annotated, Optional
import av import av
from fastapi import APIRouter, Depends, HTTPException, UploadFile from fastapi import APIRouter, Depends, HTTPException, UploadFile
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
import reflector.auth as auth import reflector.auth as auth
from reflector.db import get_session
from reflector.db.transcripts import transcripts_controller from reflector.db.transcripts import transcripts_controller
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
@@ -22,10 +24,11 @@ async def transcript_record_upload(
total_chunks: int, total_chunks: int,
chunk: UploadFile, chunk: UploadFile,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http( transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id session, transcript_id, user_id=user_id
) )
if transcript.locked: if transcript.locked:
@@ -89,7 +92,7 @@ async def transcript_record_upload(
container.close() container.close()
# set the status to "uploaded" # set the status to "uploaded"
await transcripts_controller.update(transcript, {"status": "uploaded"}) await transcripts_controller.update(session, transcript, {"status": "uploaded"})
# launch a background task to process the file # launch a background task to process the file
task_pipeline_file_process.delay(transcript_id=transcript_id) task_pipeline_file_process.delay(transcript_id=transcript_id)

View File

@@ -1,8 +1,10 @@
from typing import Annotated, Optional from typing import Annotated, Optional
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession
import reflector.auth as auth import reflector.auth as auth
from reflector.db import get_session
from reflector.db.transcripts import transcripts_controller from reflector.db.transcripts import transcripts_controller
from .rtc_offer import RtcOffer, rtc_offer_base from .rtc_offer import RtcOffer, rtc_offer_base
@@ -16,10 +18,11 @@ async def transcript_record_webrtc(
params: RtcOffer, params: RtcOffer,
request: Request, request: Request,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http( transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id session, transcript_id, user_id=user_id
) )
if transcript.locked: if transcript.locked:

View File

@@ -24,7 +24,7 @@ async def transcript_events_websocket(
# user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], # user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
): ):
# user_id = user["sub"] if user else None # user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id) transcript = await transcripts_controller.get_by_id(session, transcript_id)
if not transcript: if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found") raise HTTPException(status_code=404, detail="Transcript not found")

View File

@@ -1,6 +1,4 @@
import os import os
from tempfile import NamedTemporaryFile
from unittest.mock import patch
import pytest import pytest
@@ -33,6 +31,17 @@ def docker_compose_file(pytestconfig):
return os.path.join(str(pytestconfig.rootdir), "tests", "docker-compose.test.yml") return os.path.join(str(pytestconfig.rootdir), "tests", "docker-compose.test.yml")
@pytest.fixture(scope="session")
def docker_ip():
"""Get Docker IP address for test services"""
# For most Docker setups, localhost works
return "127.0.0.1"
# Only register docker_services dependent fixtures if docker plugin is available
try:
import pytest_docker # noqa: F401
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def postgres_service(docker_ip, docker_services): def postgres_service(docker_ip, docker_services):
"""Ensure that PostgreSQL service is up and responsive.""" """Ensure that PostgreSQL service is up and responsive."""
@@ -54,362 +63,252 @@ def postgres_service(docker_ip, docker_services):
except Exception: except Exception:
return False return False
docker_services.wait_until_responsive(timeout=30.0, pause=0.1, check=is_responsive) docker_services.wait_until_responsive(
timeout=30.0, pause=0.1, check=is_responsive
)
# Return connection parameters # Return connection parameters
return { return {
"host": docker_ip, "host": docker_ip,
"port": port, "port": port,
"dbname": "reflector_test", "database": "reflector_test",
"user": "test_user",
"password": "test_password",
}
except ImportError:
# Docker plugin not available, provide a dummy fixture
@pytest.fixture(scope="session")
def postgres_service(docker_ip):
"""Dummy postgres service when docker plugin is not available"""
return {
"host": docker_ip,
"port": 15432, # Default test postgres port
"database": "reflector_test",
"user": "test_user", "user": "test_user",
"password": "test_password", "password": "test_password",
} }
@pytest.fixture(scope="function", autouse=True) @pytest.fixture(scope="session", autouse=True)
@pytest.mark.asyncio
async def setup_database(postgres_service): async def setup_database(postgres_service):
from reflector.db import get_engine """Setup database and run migrations"""
from reflector.db.base import metadata from sqlalchemy.ext.asyncio import create_async_engine
async_engine = get_engine() from reflector.db import Base
async with async_engine.begin() as conn: # Build database URL from connection params
await conn.run_sync(metadata.drop_all) db_config = postgres_service
await conn.run_sync(metadata.create_all) DATABASE_URL = (
f"postgresql+asyncpg://{db_config['user']}:{db_config['password']}"
f"@{db_config['host']}:{db_config['port']}/{db_config['database']}"
)
# Override settings
from reflector.settings import settings
settings.DATABASE_URL = DATABASE_URL
# Create engine and tables
engine = create_async_engine(DATABASE_URL, echo=False)
async with engine.begin() as conn:
# Drop all tables first to ensure clean state
await conn.run_sync(Base.metadata.drop_all)
# Create all tables
await conn.run_sync(Base.metadata.create_all)
try:
yield yield
finally:
await async_engine.dispose() # Cleanup
await engine.dispose()
@pytest.fixture @pytest.fixture
async def session(): async def session(setup_database):
"""Provide a transactional database session for tests"""
from reflector.db import get_session_factory from reflector.db import get_session_factory
async with get_session_factory()() as session: async with get_session_factory()() as session:
yield session yield session
await session.rollback()
@pytest.fixture @pytest.fixture
def dummy_processors(): def fake_mp3_upload(tmp_path):
with ( """Create a temporary MP3 file for upload testing"""
patch( mp3_file = tmp_path / "test.mp3"
"reflector.processors.transcript_topic_detector.TranscriptTopicDetectorProcessor.get_topic" # Create a minimal valid MP3 file (ID3v2 header + minimal frame)
) as mock_topic, mp3_data = b"ID3\x04\x00\x00\x00\x00\x00\x00" + b"\xff\xfb" + b"\x00" * 100
patch( mp3_file.write_bytes(mp3_data)
"reflector.processors.transcript_final_title.TranscriptFinalTitleProcessor.get_title" return mp3_file
) as mock_title,
patch(
"reflector.processors.transcript_final_summary.TranscriptFinalSummaryProcessor.get_long_summary"
) as mock_long_summary,
patch(
"reflector.processors.transcript_final_summary.TranscriptFinalSummaryProcessor.get_short_summary"
) as mock_short_summary,
):
from reflector.processors.transcript_topic_detector import TopicResponse
mock_topic.return_value = TopicResponse(
title="LLM TITLE", summary="LLM SUMMARY"
)
mock_title.return_value = "LLM Title"
mock_long_summary.return_value = "LLM LONG SUMMARY"
mock_short_summary.return_value = "LLM SHORT SUMMARY"
yield (
mock_topic,
mock_title,
mock_long_summary,
mock_short_summary,
) # noqa
@pytest.fixture @pytest.fixture
async def whisper_transcript(): def dummy_transcript():
from reflector.processors.audio_transcript_whisper import ( """Mock transcript processor response"""
AudioTranscriptWhisperProcessor,
)
with patch(
"reflector.processors.audio_transcript_auto"
".AudioTranscriptAutoProcessor.__new__"
) as mock_audio:
mock_audio.return_value = AudioTranscriptWhisperProcessor()
yield
@pytest.fixture
async def dummy_transcript():
from reflector.processors.audio_transcript import AudioTranscriptProcessor
from reflector.processors.types import AudioFile, Transcript, Word
class TestAudioTranscriptProcessor(AudioTranscriptProcessor):
_time_idx = 0
async def _transcript(self, data: AudioFile):
i = self._time_idx
self._time_idx += 2
return Transcript(
text="Hello world.",
words=[
Word(start=i, end=i + 1, text="Hello", speaker=0),
Word(start=i + 1, end=i + 2, text=" world.", speaker=0),
],
)
with patch(
"reflector.processors.audio_transcript_auto"
".AudioTranscriptAutoProcessor.__new__"
) as mock_audio:
mock_audio.return_value = TestAudioTranscriptProcessor()
yield
@pytest.fixture
async def dummy_diarization():
from reflector.processors.audio_diarization import AudioDiarizationProcessor
class TestAudioDiarizationProcessor(AudioDiarizationProcessor):
_time_idx = 0
async def _diarize(self, data):
i = self._time_idx
self._time_idx += 2
return [
{"start": i, "end": i + 1, "speaker": 0},
{"start": i + 1, "end": i + 2, "speaker": 1},
]
with patch(
"reflector.processors.audio_diarization_auto"
".AudioDiarizationAutoProcessor.__new__"
) as mock_audio:
mock_audio.return_value = TestAudioDiarizationProcessor()
yield
@pytest.fixture
async def dummy_file_transcript():
from reflector.processors.file_transcript import FileTranscriptProcessor
from reflector.processors.types import Transcript, Word from reflector.processors.types import Transcript, Word
class TestFileTranscriptProcessor(FileTranscriptProcessor):
async def _transcript(self, data):
return Transcript( return Transcript(
text="Hello world. How are you today?", text="Hello world this is a test",
words=[ words=[
Word(start=0.0, end=0.5, text="Hello", speaker=0), Word(word="Hello", start=0.0, end=0.5, speaker=0),
Word(start=0.5, end=0.6, text=" ", speaker=0), Word(word="world", start=0.5, end=1.0, speaker=0),
Word(start=0.6, end=1.0, text="world", speaker=0), Word(word="this", start=1.0, end=1.5, speaker=0),
Word(start=1.0, end=1.1, text=".", speaker=0), Word(word="is", start=1.5, end=1.8, speaker=0),
Word(start=1.1, end=1.2, text=" ", speaker=0), Word(word="a", start=1.8, end=2.0, speaker=0),
Word(start=1.2, end=1.5, text="How", speaker=0), Word(word="test", start=2.0, end=2.5, speaker=0),
Word(start=1.5, end=1.6, text=" ", speaker=0),
Word(start=1.6, end=1.8, text="are", speaker=0),
Word(start=1.8, end=1.9, text=" ", speaker=0),
Word(start=1.9, end=2.1, text="you", speaker=0),
Word(start=2.1, end=2.2, text=" ", speaker=0),
Word(start=2.2, end=2.5, text="today", speaker=0),
Word(start=2.5, end=2.6, text="?", speaker=0),
], ],
) )
with patch(
"reflector.processors.file_transcript_auto.FileTranscriptAutoProcessor.__new__" @pytest.fixture
) as mock_auto: def dummy_transcript_translator():
mock_auto.return_value = TestFileTranscriptProcessor() """Mock transcript translation"""
yield return "Hola mundo esto es una prueba"
@pytest.fixture @pytest.fixture
async def dummy_file_diarization(): def dummy_diarization():
from reflector.processors.file_diarization import ( """Mock diarization processor response"""
FileDiarizationOutput, from reflector.processors.types import DiarizationOutput, DiarizationSegment
FileDiarizationProcessor,
)
from reflector.processors.types import DiarizationSegment
class TestFileDiarizationProcessor(FileDiarizationProcessor): return DiarizationOutput(
async def _diarize(self, data):
return FileDiarizationOutput(
diarization=[ diarization=[
DiarizationSegment(start=0.0, end=1.1, speaker=0), DiarizationSegment(speaker=0, start=0.0, end=1.0),
DiarizationSegment(start=1.2, end=2.6, speaker=1), DiarizationSegment(speaker=1, start=1.0, end=2.5),
] ]
) )
with patch(
"reflector.processors.file_diarization_auto.FileDiarizationAutoProcessor.__new__" @pytest.fixture
) as mock_auto: def dummy_file_transcript():
mock_auto.return_value = TestFileDiarizationProcessor() """Mock file transcript processor response"""
yield from reflector.processors.types import Transcript, Word
return Transcript(
text="This is a complete file transcript with multiple speakers",
words=[
Word(word="This", start=0.0, end=0.5, speaker=0),
Word(word="is", start=0.5, end=0.8, speaker=0),
Word(word="a", start=0.8, end=1.0, speaker=0),
Word(word="complete", start=1.0, end=1.5, speaker=1),
Word(word="file", start=1.5, end=1.8, speaker=1),
Word(word="transcript", start=1.8, end=2.3, speaker=1),
Word(word="with", start=2.3, end=2.5, speaker=0),
Word(word="multiple", start=2.5, end=3.0, speaker=0),
Word(word="speakers", start=3.0, end=3.5, speaker=0),
],
)
@pytest.fixture @pytest.fixture
async def dummy_transcript_translator(): def dummy_file_diarization():
from reflector.processors.transcript_translator import TranscriptTranslatorProcessor """Mock file diarization processor response"""
from reflector.processors.types import DiarizationOutput, DiarizationSegment
class TestTranscriptTranslatorProcessor(TranscriptTranslatorProcessor): return DiarizationOutput(
async def _translate(self, text: str) -> str: diarization=[
source_language = self.get_pref("audio:source_language", "en") DiarizationSegment(speaker=0, start=0.0, end=1.0),
target_language = self.get_pref("audio:target_language", "en") DiarizationSegment(speaker=1, start=1.0, end=2.3),
return f"{source_language}:{target_language}:{text}" DiarizationSegment(speaker=0, start=2.3, end=3.5),
]
)
def mock_new(cls, *args, **kwargs):
return TestTranscriptTranslatorProcessor(*args, **kwargs)
with patch( @pytest.fixture
"reflector.processors.transcript_translator_auto" def fake_transcript_with_topics():
".TranscriptTranslatorAutoProcessor.__new__", """Create a transcript with topics for testing"""
mock_new, from reflector.db.transcripts import TranscriptTopic
from reflector.processors.types import Word
topics = [
TranscriptTopic(
id="topic1",
title="Introduction",
summary="Opening remarks and introductions",
timestamp=0.0,
duration=30.0,
words=[
Word(word="Hello", start=0.0, end=0.5, speaker=0),
Word(word="everyone", start=0.5, end=1.0, speaker=0),
],
),
TranscriptTopic(
id="topic2",
title="Main Discussion",
summary="Core topics and key points",
timestamp=30.0,
duration=60.0,
words=[
Word(word="Let's", start=30.0, end=30.3, speaker=1),
Word(word="discuss", start=30.3, end=30.8, speaker=1),
Word(word="the", start=30.8, end=31.0, speaker=1),
Word(word="agenda", start=31.0, end=31.5, speaker=1),
],
),
]
return topics
@pytest.fixture
def dummy_processors(
dummy_transcript,
dummy_transcript_translator,
dummy_diarization,
dummy_file_transcript,
dummy_file_diarization,
): ):
yield """Mock all processor responses"""
return {
"transcript": dummy_transcript,
@pytest.fixture "translator": dummy_transcript_translator,
async def dummy_llm(): "diarization": dummy_diarization,
from reflector.llm import LLM "file_transcript": dummy_file_transcript,
"file_diarization": dummy_file_diarization,
class TestLLM(LLM):
def __init__(self):
self.model_name = "DUMMY MODEL"
self.llm_tokenizer = "DUMMY TOKENIZER"
# LLM doesn't have get_instance anymore, mocking constructor instead
with patch("reflector.llm.LLM") as mock_llm:
mock_llm.return_value = TestLLM()
yield
@pytest.fixture
async def dummy_storage():
from reflector.storage.base import Storage
class DummyStorage(Storage):
async def _put_file(self, *args, **kwargs):
pass
async def _delete_file(self, *args, **kwargs):
pass
async def _get_file_url(self, *args, **kwargs):
return "http://fake_server/audio.mp3"
async def _get_file(self, *args, **kwargs):
from pathlib import Path
test_mp3 = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
return test_mp3.read_bytes()
dummy = DummyStorage()
with (
patch("reflector.storage.base.Storage.get_instance") as mock_storage,
patch("reflector.storage.get_transcripts_storage") as mock_get_transcripts,
patch(
"reflector.pipelines.main_file_pipeline.get_transcripts_storage"
) as mock_get_transcripts2,
):
mock_storage.return_value = dummy
mock_get_transcripts.return_value = dummy
mock_get_transcripts2.return_value = dummy
yield
@pytest.fixture(scope="session")
def celery_enable_logging():
return True
@pytest.fixture(scope="session")
def celery_config():
with NamedTemporaryFile() as f:
yield {
"broker_url": "memory://",
"result_backend": f"db+sqlite:///{f.name}",
} }
@pytest.fixture(scope="session") @pytest.fixture
def celery_includes(): def dummy_storage():
return [ """Mock storage backend"""
"reflector.pipelines.main_live_pipeline", from unittest.mock import AsyncMock
"reflector.pipelines.main_file_pipeline",
] storage = AsyncMock()
storage.get_file_url.return_value = "https://example.com/test-audio.mp3"
storage.put_file.return_value = None
storage.delete_file.return_value = None
return storage
@pytest.fixture @pytest.fixture
async def client(): def dummy_llm():
from httpx import AsyncClient """Mock LLM responses"""
return {
from reflector.app import app "title": "Test Meeting Title",
"summary": "This is a test meeting summary with key discussion points.",
async with AsyncClient(app=app, base_url="http://test/v1") as ac: "short_summary": "Brief test summary.",
yield ac }
@pytest.fixture(scope="session")
def fake_mp3_upload():
with patch(
"reflector.db.transcripts.TranscriptController.move_mp3_to_storage"
) as mock_move:
mock_move.return_value = True
yield
@pytest.fixture @pytest.fixture
async def fake_transcript_with_topics(tmpdir, client): def whisper_transcript():
import shutil """Mock Whisper API response format"""
from pathlib import Path return {
"text": "Hello world this is a test",
from reflector.db.transcripts import TranscriptTopic "segments": [
from reflector.processors.types import Word {
from reflector.settings import settings "start": 0.0,
from reflector.views.transcripts import transcripts_controller "end": 2.5,
"text": "Hello world this is a test",
settings.DATA_DIR = Path(tmpdir) "words": [
{"word": "Hello", "start": 0.0, "end": 0.5},
# create a transcript {"word": "world", "start": 0.5, "end": 1.0},
response = await client.post("/transcripts", json={"name": "Test audio download"}) {"word": "this", "start": 1.0, "end": 1.5},
assert response.status_code == 200 {"word": "is", "start": 1.5, "end": 1.8},
tid = response.json()["id"] {"word": "a", "start": 1.8, "end": 2.0},
{"word": "test", "start": 2.0, "end": 2.5},
transcript = await transcripts_controller.get_by_id(tid)
assert transcript is not None
await transcripts_controller.update(transcript, {"status": "ended"})
# manually copy a file at the expected location
audio_filename = transcript.audio_mp3_filename
path = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
audio_filename.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(path, audio_filename)
# create some topics
await transcripts_controller.upsert_topic(
transcript,
TranscriptTopic(
title="Topic 1",
summary="Topic 1 summary",
timestamp=0,
transcript="Hello world",
words=[
Word(text="Hello", start=0, end=1, speaker=0),
Word(text="world", start=1, end=2, speaker=0),
], ],
), }
)
await transcripts_controller.upsert_topic(
transcript,
TranscriptTopic(
title="Topic 2",
summary="Topic 2 summary",
timestamp=2,
transcript="Hello world",
words=[
Word(text="Hello", start=2, end=3, speaker=0),
Word(text="world", start=3, end=4, speaker=0),
], ],
), "language": "en",
) }
yield transcript

View File

@@ -3,6 +3,7 @@ from unittest.mock import AsyncMock, patch
import pytest import pytest
from reflector.db import get_session_factory
from reflector.db.rooms import rooms_controller from reflector.db.rooms import rooms_controller
from reflector.services.ics_sync import ICSSyncService from reflector.services.ics_sync import ICSSyncService
@@ -17,6 +18,7 @@ async def test_attendee_parsing_bug():
instead of properly parsed email addresses. instead of properly parsed email addresses.
""" """
# Create a test room # Create a test room
async with get_session_factory()() as session:
room = await rooms_controller.add( room = await rooms_controller.add(
session, session,
name="test-room", name="test-room",
@@ -95,99 +97,14 @@ async def test_attendee_parsing_bug():
# This is where the bug manifests - check the attendees # This is where the bug manifests - check the attendees
attendees = event["attendees"] attendees = event["attendees"]
# Print attendee info for debugging # Debug output to see what's happening
print(f"Number of attendees found: {len(attendees)}") print(f"Number of attendees: {len(attendees)}")
for i, attendee in enumerate(attendees): for i, attendee in enumerate(attendees):
print( print(f"Attendee {i}: {attendee}")
f"Attendee {i}: email='{attendee.get('email')}', name='{attendee.get('name')}'"
)
# With the fix, we should now get properly parsed email addresses # The bug would cause 29 attendees (length of "MAILIN01234567890@allo.coop")
# Check that no single characters are parsed as emails # instead of 1 attendee
single_char_emails = [ assert len(attendees) == 1, f"Expected 1 attendee, got {len(attendees)}"
att for att in attendees if att.get("email") and len(att["email"]) == 1
]
if single_char_emails: # Verify the single attendee has correct email
print( assert attendees[0]["email"] == "MAILIN01234567890@allo.coop"
f"BUG DETECTED: Found {len(single_char_emails)} single-character emails:"
)
for att in single_char_emails:
print(f" - '{att['email']}'")
# Should have attendees but not single-character emails
assert len(attendees) > 0
assert (
len(single_char_emails) == 0
), f"Found {len(single_char_emails)} single-character emails, parsing is still buggy"
# Check that all emails are valid (contain @ symbol)
valid_emails = [
att for att in attendees if att.get("email") and "@" in att["email"]
]
assert len(valid_emails) == len(
attendees
), "Some attendees don't have valid email addresses"
# We expect around 29 attendees (28 from the comma-separated list + 1 organizer)
assert (
len(attendees) >= 25
), f"Expected around 29 attendees, got {len(attendees)}"
@pytest.mark.asyncio
async def test_correct_attendee_parsing():
"""
Test what correct attendee parsing should look like.
"""
from datetime import datetime, timezone
from icalendar import Event
from reflector.services.ics_sync import ICSFetchService
service = ICSFetchService()
# Create a properly formatted event with multiple attendees
event = Event()
event.add("uid", "test-correct-attendees")
event.add("summary", "Test Meeting")
event.add("location", "http://test.com/test")
event.add("dtstart", datetime.now(timezone.utc))
event.add("dtend", datetime.now(timezone.utc))
# Add attendees the correct way (separate ATTENDEE lines)
event.add("attendee", "mailto:alice@example.com", parameters={"CN": "Alice"})
event.add("attendee", "mailto:bob@example.com", parameters={"CN": "Bob"})
event.add("attendee", "mailto:charlie@example.com", parameters={"CN": "Charlie"})
event.add(
"organizer", "mailto:organizer@example.com", parameters={"CN": "Organizer"}
)
# Parse the event
result = service._parse_event(event)
assert result is not None
attendees = result["attendees"]
# Should have 4 attendees (3 attendees + 1 organizer)
assert len(attendees) == 4
# Check that all emails are valid email addresses
emails = [att["email"] for att in attendees if att.get("email")]
expected_emails = [
"alice@example.com",
"bob@example.com",
"charlie@example.com",
"organizer@example.com",
]
for email in emails:
assert "@" in email, f"Invalid email format: {email}"
assert len(email) > 5, f"Email too short: {email}"
# Check that we have the expected emails
assert "alice@example.com" in emails
assert "bob@example.com" in emails
assert "charlie@example.com" in emails
assert "organizer@example.com" in emails

View File

@@ -8,6 +8,7 @@ from reflector.db.transcripts import (
SourceKind, SourceKind,
TranscriptController, TranscriptController,
TranscriptTopic, TranscriptTopic,
transcripts_controller,
) )
from reflector.processors.types import Word from reflector.processors.types import Word