fix: Complete SQLAlchemy 2.0 migration - add session parameters to all controller calls

- Add session parameter to all view functions and controller calls
- Fix pipeline files to use get_session_factory() for background tasks
- Update PipelineMainBase and PipelineMainFile to handle sessions properly
- Add missing on_* methods to PipelineMainFile class
- Fix test fixtures to handle docker services availability
- Add docker_ip fixture for test database connections
- Import fixes for transcripts_controller in tests

All controller calls now take a session as their first parameter, following SQLAlchemy 2.0 async patterns.
This commit is contained in:
2025-09-18 13:08:19 -06:00
parent 45d1608950
commit d21b65e4e8
13 changed files with 593 additions and 550 deletions

View File

@@ -8,18 +8,22 @@ Uses parallel processing for transcription, diarization, and waveform generation
import asyncio import asyncio
import uuid import uuid
from contextlib import asynccontextmanager
from pathlib import Path from pathlib import Path
import av import av
import structlog import structlog
from celery import chain, shared_task from celery import chain, shared_task
from sqlalchemy.ext.asyncio import AsyncSession
from reflector.asynctask import asynctask from reflector.asynctask import asynctask
from reflector.db import get_session_factory
from reflector.db.rooms import rooms_controller from reflector.db.rooms import rooms_controller
from reflector.db.transcripts import ( from reflector.db.transcripts import (
SourceKind, SourceKind,
Transcript, Transcript,
TranscriptStatus, TranscriptStatus,
TranscriptTopic,
transcripts_controller, transcripts_controller,
) )
from reflector.logger import logger from reflector.logger import logger
@@ -83,6 +87,32 @@ class PipelineMainFile(PipelineMainBase):
self.logger = logger.bind(transcript_id=self.transcript_id) self.logger = logger.bind(transcript_id=self.transcript_id)
self.empty_pipeline = EmptyPipeline(logger=self.logger) self.empty_pipeline = EmptyPipeline(logger=self.logger)
async def get_transcript(self, session: AsyncSession = None) -> Transcript:
    """Fetch the transcript identified by ``self.transcript_id``.

    Args:
        session: Session to run the lookup on. When omitted (``None``), a
            short-lived session is opened from ``get_session_factory()``
            just for this lookup and closed before returning.

    Returns:
        The transcript returned by ``transcripts_controller.get_by_id``.

    Raises:
        Exception: If the controller returns no transcript.
    """
    # Compare against None explicitly: the question is "was a session
    # supplied?", not "is the session object truthy?".
    if session is not None:
        result = await transcripts_controller.get_by_id(session, self.transcript_id)
    else:
        # Use a distinct name so the caller-facing parameter is not rebound.
        async with get_session_factory()() as own_session:
            result = await transcripts_controller.get_by_id(
                own_session, self.transcript_id
            )
    if not result:
        raise Exception("Transcript not found")
    return result
@asynccontextmanager
async def lock_transaction(self):
    """Serialize event writes issued through this pipeline instance.

    This lock is to prevent multiple processors from adding into the event
    array at the same time. A lock only excludes other holders of the *same*
    lock object, so one lock is created lazily per instance and reused on
    every call; building a new ``asyncio.Lock()`` on each call would never
    block anyone.

    The lock is not reentrant: do not enter ``lock_transaction()`` (or
    ``transaction()``) again from inside the guarded section.
    """
    # NOTE(review): assumes the instance accepts a new underscore-prefixed
    # attribute outside __init__ — confirm against the base class.
    lock = getattr(self, "_transaction_lock", None)
    if lock is None:
        # No await between the check and the assignment, so concurrent
        # tasks on the same event loop cannot create two different locks.
        lock = asyncio.Lock()
        self._transaction_lock = lock
    async with lock:
        yield
@asynccontextmanager
async def transaction(self):
    """Yield a fresh database session while holding the transaction lock.

    The lock from ``lock_transaction()`` is acquired first; a session is then
    opened from ``get_session_factory()`` and handed to the ``async with``
    body. On exit the session is closed before the lock is released.
    """
    async with self.lock_transaction(), get_session_factory()() as db_session:
        yield db_session
def _handle_gather_exceptions(self, results: list, operation: str) -> None: def _handle_gather_exceptions(self, results: list, operation: str) -> None:
"""Handle exceptions from asyncio.gather with return_exceptions=True""" """Handle exceptions from asyncio.gather with return_exceptions=True"""
for i, result in enumerate(results): for i, result in enumerate(results):
@@ -97,17 +127,23 @@ class PipelineMainFile(PipelineMainBase):
@broadcast_to_sockets @broadcast_to_sockets
async def set_status(self, transcript_id: str, status: TranscriptStatus): async def set_status(self, transcript_id: str, status: TranscriptStatus):
async with self.lock_transaction(): async with self.lock_transaction():
return await transcripts_controller.set_status(transcript_id, status) async with get_session_factory()() as session:
return await transcripts_controller.set_status(
session, transcript_id, status
)
async def process(self, file_path: Path): async def process(self, file_path: Path):
"""Main entry point for file processing""" """Main entry point for file processing"""
self.logger.info(f"Starting file pipeline for {file_path}") self.logger.info(f"Starting file pipeline for {file_path}")
transcript = await self.get_transcript() async with get_session_factory()() as session:
transcript = await transcripts_controller.get_by_id(
session, self.transcript_id
)
# Clear transcript as we're going to regenerate everything # Clear transcript as we're going to regenerate everything
async with self.transaction():
await transcripts_controller.update( await transcripts_controller.update(
session,
transcript, transcript,
{ {
"events": [], "events": [],
@@ -131,7 +167,8 @@ class PipelineMainFile(PipelineMainBase):
self.logger.info("File pipeline complete") self.logger.info("File pipeline complete")
await transcripts_controller.set_status(transcript.id, "ended") async with get_session_factory()() as session:
await transcripts_controller.set_status(session, transcript.id, "ended")
async def extract_and_write_audio( async def extract_and_write_audio(
self, file_path: Path, transcript: Transcript self, file_path: Path, transcript: Transcript
@@ -308,7 +345,10 @@ class PipelineMainFile(PipelineMainBase):
async def generate_waveform(self, audio_path: Path): async def generate_waveform(self, audio_path: Path):
"""Generate and save waveform""" """Generate and save waveform"""
transcript = await self.get_transcript() async with get_session_factory()() as session:
transcript = await transcripts_controller.get_by_id(
session, self.transcript_id
)
processor = AudioWaveformProcessor( processor = AudioWaveformProcessor(
audio_path=audio_path, audio_path=audio_path,
@@ -367,7 +407,10 @@ class PipelineMainFile(PipelineMainBase):
self.logger.warning("No topics for summary generation") self.logger.warning("No topics for summary generation")
return return
transcript = await self.get_transcript() async with get_session_factory()() as session:
transcript = await transcripts_controller.get_by_id(
session, self.transcript_id
)
processor = TranscriptFinalSummaryProcessor( processor = TranscriptFinalSummaryProcessor(
transcript=transcript, transcript=transcript,
callback=self.on_long_summary, callback=self.on_long_summary,
@@ -380,37 +423,144 @@ class PipelineMainFile(PipelineMainBase):
await processor.flush() await processor.flush()
async def on_topic(self, topic: TitleSummary):
    """Handle a topic event: save the topic and append a ``TOPIC`` event.

    Args:
        topic: Processor output; its ``title``, ``summary``, ``timestamp``
            and ``duration`` attributes are copied into a ``TranscriptTopic``.

    Raises:
        Exception: If the transcript no longer exists (raised by
            ``get_transcript``).
    """
    topic_obj = TranscriptTopic(
        title=topic.title,
        summary=topic.summary,
        timestamp=topic.timestamp,
        duration=topic.duration,
    )
    # Go through transaction()/get_transcript() like the PipelineMainBase
    # handlers do: the write is guarded by the pipeline lock and a missing
    # transcript fails with a clear error instead of passing None along.
    async with self.transaction() as session:
        transcript = await self.get_transcript(session)
        await transcripts_controller.upsert_topic(session, transcript, topic_obj)
        await transcripts_controller.append_event(
            session,
            transcript=transcript,
            event="TOPIC",
            data=topic_obj,
        )
async def on_title(self, data):
    """Handle a title event: store the title and append ``FINAL_TITLE``.

    The transcript's title is only written when it does not already have
    one; the ``FINAL_TITLE`` event is appended in every case.

    Args:
        data: Processor output exposing a ``title`` attribute.

    Raises:
        Exception: If the transcript no longer exists (raised by
            ``get_transcript``).
    """
    # NOTE(review): PipelineMainBase.on_title is decorated with
    # @broadcast_to_sockets and returns the appended event; this override
    # does neither — confirm that skipping the broadcast is intended.
    async with self.transaction() as session:
        # get_transcript() raises when the row is missing, so the attribute
        # access below can no longer hit None.
        transcript = await self.get_transcript(session)
        if not transcript.title:
            await transcripts_controller.update(
                session,
                transcript,
                {"title": data.title},
            )
        await transcripts_controller.append_event(
            session,
            transcript=transcript,
            event="FINAL_TITLE",
            data={"title": data.title},
        )
async def on_long_summary(self, data):
    """Handle a long-summary event: store it and append ``FINAL_LONG_SUMMARY``.

    Args:
        data: Processor output exposing a ``long_summary`` attribute.

    Raises:
        Exception: If the transcript no longer exists (raised by
            ``get_transcript``).
    """
    # NOTE(review): PipelineMainBase.on_long_summary is decorated with
    # @broadcast_to_sockets and returns the appended event; this override
    # does neither — confirm that skipping the broadcast is intended.
    async with self.transaction() as session:
        # Locked session plus a None-checked lookup, matching the base class.
        transcript = await self.get_transcript(session)
        await transcripts_controller.update(
            session,
            transcript,
            {"long_summary": data.long_summary},
        )
        await transcripts_controller.append_event(
            session,
            transcript=transcript,
            event="FINAL_LONG_SUMMARY",
            data={"long_summary": data.long_summary},
        )
async def on_short_summary(self, data):
    """Handle a short-summary event: store it and append ``FINAL_SHORT_SUMMARY``.

    Args:
        data: Processor output exposing a ``short_summary`` attribute.

    Raises:
        Exception: If the transcript no longer exists (raised by
            ``get_transcript``).
    """
    async with self.transaction() as session:
        # Locked session plus a None-checked lookup, so a deleted transcript
        # surfaces as "Transcript not found" rather than a None being
        # handed to the controller.
        transcript = await self.get_transcript(session)
        await transcripts_controller.update(
            session,
            transcript,
            {"short_summary": data.short_summary},
        )
        await transcripts_controller.append_event(
            session,
            transcript=transcript,
            event="FINAL_SHORT_SUMMARY",
            data={"short_summary": data.short_summary},
        )
async def on_duration(self, duration):
    """Handle a duration event: store the duration and append ``DURATION``.

    Args:
        duration: Value written to the transcript's ``duration`` field and
            included in the event payload.

    Raises:
        Exception: If the transcript no longer exists (raised by
            ``get_transcript``).
    """
    # NOTE(review): PipelineMainBase.on_duration is decorated with
    # @broadcast_to_sockets and returns the appended event; this override
    # does neither — confirm that skipping the broadcast is intended.
    async with self.transaction() as session:
        # Locked session plus a None-checked lookup, matching the base class.
        transcript = await self.get_transcript(session)
        await transcripts_controller.update(
            session,
            transcript,
            {"duration": duration},
        )
        await transcripts_controller.append_event(
            session,
            transcript=transcript,
            event="DURATION",
            data={"duration": duration},
        )
async def on_waveform(self, waveform):
    """Handle a waveform event: append a ``WAVEFORM`` event to the transcript.

    Args:
        waveform: Value stored under the ``"waveform"`` key of the event
            payload.

    Raises:
        Exception: If the transcript no longer exists (raised by
            ``get_transcript``).
    """
    # NOTE(review): PipelineMainBase.on_waveform is decorated with
    # @broadcast_to_sockets and returns the appended event; this override
    # does neither — confirm that skipping the broadcast is intended.
    async with self.transaction() as session:
        # Locked session plus a None-checked lookup, matching the base class.
        transcript = await self.get_transcript(session)
        await transcripts_controller.append_event(
            session,
            transcript=transcript,
            event="WAVEFORM",
            data={"waveform": waveform},
        )
@shared_task @shared_task
@asynctask @asynctask
async def task_send_webhook_if_needed(*, transcript_id: str): async def task_send_webhook_if_needed(*, transcript_id: str):
"""Send webhook if this is a room recording with webhook configured""" """Send webhook if this is a room recording with webhook configured"""
transcript = await transcripts_controller.get_by_id(transcript_id) async with get_session_factory()() as session:
if not transcript: transcript = await transcripts_controller.get_by_id(session, transcript_id)
return if not transcript:
return
if transcript.source_kind == SourceKind.ROOM and transcript.room_id: if transcript.source_kind == SourceKind.ROOM and transcript.room_id:
room = await rooms_controller.get_by_id(transcript.room_id) room = await rooms_controller.get_by_id(session, transcript.room_id)
if room and room.webhook_url: if room and room.webhook_url:
logger.info( logger.info(
"Dispatching webhook", "Dispatching webhook",
transcript_id=transcript_id, transcript_id=transcript_id,
room_id=room.id, room_id=room.id,
webhook_url=room.webhook_url, webhook_url=room.webhook_url,
) )
send_transcript_webhook.delay( send_transcript_webhook.delay(
transcript_id, room.id, event_id=uuid.uuid4().hex transcript_id, room.id, event_id=uuid.uuid4().hex
) )
@shared_task @shared_task
@asynctask @asynctask
async def task_pipeline_file_process(*, transcript_id: str): async def task_pipeline_file_process(*, transcript_id: str):
"""Celery task for file pipeline processing""" """Celery task for file pipeline processing"""
async with get_session_factory()() as session:
transcript = await transcripts_controller.get_by_id(transcript_id) transcript = await transcripts_controller.get_by_id(session, transcript_id)
if not transcript: if not transcript:
raise Exception(f"Transcript {transcript_id} not found") raise Exception(f"Transcript {transcript_id} not found")
pipeline = PipelineMainFile(transcript_id=transcript_id) pipeline = PipelineMainFile(transcript_id=transcript_id)
try: try:

View File

@@ -20,9 +20,11 @@ import av
import boto3 import boto3
from celery import chord, current_task, group, shared_task from celery import chord, current_task, group, shared_task
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from structlog import BoundLogger as Logger from structlog import BoundLogger as Logger
from reflector.asynctask import asynctask from reflector.asynctask import asynctask
from reflector.db import get_session_factory
from reflector.db.meetings import meeting_consent_controller, meetings_controller from reflector.db.meetings import meeting_consent_controller, meetings_controller
from reflector.db.recordings import recordings_controller from reflector.db.recordings import recordings_controller
from reflector.db.rooms import rooms_controller from reflector.db.rooms import rooms_controller
@@ -96,9 +98,10 @@ def get_transcript(func):
@functools.wraps(func) @functools.wraps(func)
async def wrapper(**kwargs): async def wrapper(**kwargs):
transcript_id = kwargs.pop("transcript_id") transcript_id = kwargs.pop("transcript_id")
transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id) async with get_session_factory()() as session:
transcript = await transcripts_controller.get_by_id(session, transcript_id)
if not transcript: if not transcript:
raise Exception("Transcript {transcript_id} not found") raise Exception(f"Transcript {transcript_id} not found")
# Enhanced logger with Celery task context # Enhanced logger with Celery task context
tlogger = logger.bind(transcript_id=transcript.id) tlogger = logger.bind(transcript_id=transcript.id)
@@ -139,11 +142,15 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
self._ws_manager = get_ws_manager() self._ws_manager = get_ws_manager()
return self._ws_manager return self._ws_manager
async def get_transcript(self) -> Transcript: async def get_transcript(self, session: AsyncSession = None) -> Transcript:
# fetch the transcript # fetch the transcript
result = await transcripts_controller.get_by_id( if session:
transcript_id=self.transcript_id result = await transcripts_controller.get_by_id(session, self.transcript_id)
) else:
async with get_session_factory()() as session:
result = await transcripts_controller.get_by_id(
session, self.transcript_id
)
if not result: if not result:
raise Exception("Transcript not found") raise Exception("Transcript not found")
return result return result
@@ -175,8 +182,8 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
@asynccontextmanager @asynccontextmanager
async def transaction(self): async def transaction(self):
async with self.lock_transaction(): async with self.lock_transaction():
async with transcripts_controller.transaction(): async with get_session_factory()() as session:
yield yield session
@broadcast_to_sockets @broadcast_to_sockets
async def on_status(self, status): async def on_status(self, status):
@@ -207,13 +214,17 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
# when the status of the pipeline changes, update the transcript # when the status of the pipeline changes, update the transcript
async with self._lock: async with self._lock:
return await transcripts_controller.set_status(self.transcript_id, status) async with get_session_factory()() as session:
return await transcripts_controller.set_status(
session, self.transcript_id, status
)
@broadcast_to_sockets @broadcast_to_sockets
async def on_transcript(self, data): async def on_transcript(self, data):
async with self.transaction(): async with self.transaction() as session:
transcript = await self.get_transcript() transcript = await self.get_transcript(session)
return await transcripts_controller.append_event( return await transcripts_controller.append_event(
session,
transcript=transcript, transcript=transcript,
event="TRANSCRIPT", event="TRANSCRIPT",
data=TranscriptText(text=data.text, translation=data.translation), data=TranscriptText(text=data.text, translation=data.translation),
@@ -230,10 +241,11 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
) )
if isinstance(data, TitleSummaryWithIdProcessorType): if isinstance(data, TitleSummaryWithIdProcessorType):
topic.id = data.id topic.id = data.id
async with self.transaction(): async with self.transaction() as session:
transcript = await self.get_transcript() transcript = await self.get_transcript(session)
await transcripts_controller.upsert_topic(transcript, topic) await transcripts_controller.upsert_topic(session, transcript, topic)
return await transcripts_controller.append_event( return await transcripts_controller.append_event(
session,
transcript=transcript, transcript=transcript,
event="TOPIC", event="TOPIC",
data=topic, data=topic,
@@ -242,16 +254,18 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
@broadcast_to_sockets @broadcast_to_sockets
async def on_title(self, data): async def on_title(self, data):
final_title = TranscriptFinalTitle(title=data.title) final_title = TranscriptFinalTitle(title=data.title)
async with self.transaction(): async with self.transaction() as session:
transcript = await self.get_transcript() transcript = await self.get_transcript(session)
if not transcript.title: if not transcript.title:
await transcripts_controller.update( await transcripts_controller.update(
session,
transcript, transcript,
{ {
"title": final_title.title, "title": final_title.title,
}, },
) )
return await transcripts_controller.append_event( return await transcripts_controller.append_event(
session,
transcript=transcript, transcript=transcript,
event="FINAL_TITLE", event="FINAL_TITLE",
data=final_title, data=final_title,
@@ -260,15 +274,17 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
@broadcast_to_sockets @broadcast_to_sockets
async def on_long_summary(self, data): async def on_long_summary(self, data):
final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary) final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
async with self.transaction(): async with self.transaction() as session:
transcript = await self.get_transcript() transcript = await self.get_transcript(session)
await transcripts_controller.update( await transcripts_controller.update(
session,
transcript, transcript,
{ {
"long_summary": final_long_summary.long_summary, "long_summary": final_long_summary.long_summary,
}, },
) )
return await transcripts_controller.append_event( return await transcripts_controller.append_event(
session,
transcript=transcript, transcript=transcript,
event="FINAL_LONG_SUMMARY", event="FINAL_LONG_SUMMARY",
data=final_long_summary, data=final_long_summary,
@@ -279,15 +295,17 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
final_short_summary = TranscriptFinalShortSummary( final_short_summary = TranscriptFinalShortSummary(
short_summary=data.short_summary short_summary=data.short_summary
) )
async with self.transaction(): async with self.transaction() as session:
transcript = await self.get_transcript() transcript = await self.get_transcript(session)
await transcripts_controller.update( await transcripts_controller.update(
session,
transcript, transcript,
{ {
"short_summary": final_short_summary.short_summary, "short_summary": final_short_summary.short_summary,
}, },
) )
return await transcripts_controller.append_event( return await transcripts_controller.append_event(
session,
transcript=transcript, transcript=transcript,
event="FINAL_SHORT_SUMMARY", event="FINAL_SHORT_SUMMARY",
data=final_short_summary, data=final_short_summary,
@@ -295,29 +313,30 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
@broadcast_to_sockets @broadcast_to_sockets
async def on_duration(self, data): async def on_duration(self, data):
async with self.transaction(): async with self.transaction() as session:
duration = TranscriptDuration(duration=data) duration = TranscriptDuration(duration=data)
transcript = await self.get_transcript() transcript = await self.get_transcript(session)
await transcripts_controller.update( await transcripts_controller.update(
session,
transcript, transcript,
{ {
"duration": duration.duration, "duration": duration.duration,
}, },
) )
return await transcripts_controller.append_event( return await transcripts_controller.append_event(
transcript=transcript, event="DURATION", data=duration session, transcript=transcript, event="DURATION", data=duration
) )
@broadcast_to_sockets @broadcast_to_sockets
async def on_waveform(self, data): async def on_waveform(self, data):
async with self.transaction(): async with self.transaction() as session:
waveform = TranscriptWaveform(waveform=data) waveform = TranscriptWaveform(waveform=data)
transcript = await self.get_transcript() transcript = await self.get_transcript(session)
return await transcripts_controller.append_event( return await transcripts_controller.append_event(
transcript=transcript, event="WAVEFORM", data=waveform session, transcript=transcript, event="WAVEFORM", data=waveform
) )
@@ -535,7 +554,8 @@ async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
return return
# Upload to external storage and delete the file # Upload to external storage and delete the file
await transcripts_controller.move_mp3_to_storage(transcript) async with get_session_factory()() as session:
await transcripts_controller.move_mp3_to_storage(session, transcript)
logger.info("Upload mp3 done") logger.info("Upload mp3 done")
@@ -572,13 +592,20 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
recording = None recording = None
try: try:
if transcript.recording_id: if transcript.recording_id:
recording = await recordings_controller.get_by_id(transcript.recording_id) async with get_session_factory()() as session:
if recording and recording.meeting_id: recording = await recordings_controller.get_by_id(
meeting = await meetings_controller.get_by_id(recording.meeting_id) session, transcript.recording_id
if meeting: )
consent_denied = await meeting_consent_controller.has_any_denial( if recording and recording.meeting_id:
meeting.id meeting = await meetings_controller.get_by_id(
session, recording.meeting_id
) )
if meeting:
consent_denied = (
await meeting_consent_controller.has_any_denial(
session, meeting.id
)
)
except Exception as e: except Exception as e:
logger.error(f"Failed to get fetch consent: {e}", exc_info=e) logger.error(f"Failed to get fetch consent: {e}", exc_info=e)
consent_denied = True consent_denied = True
@@ -606,7 +633,10 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
logger.error(f"Failed to delete Whereby recording: {e}", exc_info=e) logger.error(f"Failed to delete Whereby recording: {e}", exc_info=e)
# non-transactional, files marked for deletion not actually deleted is possible # non-transactional, files marked for deletion not actually deleted is possible
await transcripts_controller.update(transcript, {"audio_deleted": True}) async with get_session_factory()() as session:
await transcripts_controller.update(
session, transcript, {"audio_deleted": True}
)
# 2. Delete processed audio from transcript storage S3 bucket # 2. Delete processed audio from transcript storage S3 bucket
if transcript.audio_location == "storage": if transcript.audio_location == "storage":
storage = get_transcripts_storage() storage = get_transcripts_storage()
@@ -638,21 +668,24 @@ async def pipeline_post_to_zulip(transcript: Transcript, logger: Logger):
logger.info("Transcript has no recording") logger.info("Transcript has no recording")
return return
recording = await recordings_controller.get_by_id(transcript.recording_id) async with get_session_factory()() as session:
if not recording: recording = await recordings_controller.get_by_id(
logger.info("Recording not found") session, transcript.recording_id
return )
if not recording:
logger.info("Recording not found")
return
if not recording.meeting_id: if not recording.meeting_id:
logger.info("Recording has no meeting") logger.info("Recording has no meeting")
return return
meeting = await meetings_controller.get_by_id(recording.meeting_id) meeting = await meetings_controller.get_by_id(session, recording.meeting_id)
if not meeting: if not meeting:
logger.info("No meeting found for this recording") logger.info("No meeting found for this recording")
return return
room = await rooms_controller.get_by_id(meeting.room_id) room = await rooms_controller.get_by_id(session, meeting.room_id)
if not room: if not room:
logger.error(f"Missing room for a meeting {meeting.id}") logger.error(f"Missing room for a meeting {meeting.id}")
return return
@@ -677,9 +710,10 @@ async def pipeline_post_to_zulip(transcript: Transcript, logger: Logger):
response = await send_message_to_zulip( response = await send_message_to_zulip(
room.zulip_stream, room.zulip_topic, message room.zulip_stream, room.zulip_topic, message
) )
await transcripts_controller.update( async with get_session_factory()() as session:
transcript, {"zulip_message_id": response["id"]} await transcripts_controller.update(
) session, transcript, {"zulip_message_id": response["id"]}
)
logger.info("Posted to zulip") logger.info("Posted to zulip")

View File

@@ -8,9 +8,10 @@ from fastapi_pagination import Page
from fastapi_pagination.ext.sqlalchemy import paginate from fastapi_pagination.ext.sqlalchemy import paginate
from pydantic import BaseModel from pydantic import BaseModel
from redis.exceptions import LockError from redis.exceptions import LockError
from sqlalchemy.ext.asyncio import AsyncSession
import reflector.auth as auth import reflector.auth as auth
from reflector.db import get_session_factory from reflector.db import get_session, get_session_factory
from reflector.db.calendar_events import calendar_events_controller from reflector.db.calendar_events import calendar_events_controller
from reflector.db.meetings import meetings_controller from reflector.db.meetings import meetings_controller
from reflector.db.rooms import rooms_controller from reflector.db.rooms import rooms_controller
@@ -185,7 +186,7 @@ async def rooms_list(
session_factory = get_session_factory() session_factory = get_session_factory()
async with session_factory() as session: async with session_factory() as session:
query = await rooms_controller.get_all( query = await rooms_controller.get_all(
user_id=user_id, order_by="-created_at", return_query=True session, user_id=user_id, order_by="-created_at", return_query=True
) )
return await paginate(session, query) return await paginate(session, query)
@@ -194,9 +195,10 @@ async def rooms_list(
async def rooms_get( async def rooms_get(
room_id: str, room_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
room = await rooms_controller.get_by_id_for_http(room_id, user_id=user_id) room = await rooms_controller.get_by_id_for_http(session, room_id, user_id=user_id)
if not room: if not room:
raise HTTPException(status_code=404, detail="Room not found") raise HTTPException(status_code=404, detail="Room not found")
return room return room
@@ -206,9 +208,10 @@ async def rooms_get(
async def rooms_get_by_name( async def rooms_get_by_name(
room_name: str, room_name: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name) room = await rooms_controller.get_by_name(session, room_name)
if not room: if not room:
raise HTTPException(status_code=404, detail="Room not found") raise HTTPException(status_code=404, detail="Room not found")
@@ -230,10 +233,12 @@ async def rooms_get_by_name(
async def rooms_create( async def rooms_create(
room: CreateRoom, room: CreateRoom,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
return await rooms_controller.add( return await rooms_controller.add(
session,
name=room.name, name=room.name,
user_id=user_id, user_id=user_id,
zulip_auto_post=room.zulip_auto_post, zulip_auto_post=room.zulip_auto_post,
@@ -257,13 +262,14 @@ async def rooms_update(
room_id: str, room_id: str,
info: UpdateRoom, info: UpdateRoom,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
room = await rooms_controller.get_by_id_for_http(room_id, user_id=user_id) room = await rooms_controller.get_by_id_for_http(session, room_id, user_id=user_id)
if not room: if not room:
raise HTTPException(status_code=404, detail="Room not found") raise HTTPException(status_code=404, detail="Room not found")
values = info.dict(exclude_unset=True) values = info.dict(exclude_unset=True)
await rooms_controller.update(room, values) await rooms_controller.update(session, room, values)
return room return room
@@ -271,12 +277,13 @@ async def rooms_update(
async def rooms_delete( async def rooms_delete(
room_id: str, room_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
room = await rooms_controller.get_by_id(room_id, user_id=user_id) room = await rooms_controller.get_by_id(session, room_id, user_id=user_id)
if not room: if not room:
raise HTTPException(status_code=404, detail="Room not found") raise HTTPException(status_code=404, detail="Room not found")
await rooms_controller.remove_by_id(room.id, user_id=user_id) await rooms_controller.remove_by_id(session, room.id, user_id=user_id)
return DeletionStatus(status="ok") return DeletionStatus(status="ok")
@@ -285,9 +292,10 @@ async def rooms_create_meeting(
room_name: str, room_name: str,
info: CreateRoomMeeting, info: CreateRoomMeeting,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name) room = await rooms_controller.get_by_name(session, room_name)
if not room: if not room:
raise HTTPException(status_code=404, detail="Room not found") raise HTTPException(status_code=404, detail="Room not found")
@@ -303,7 +311,7 @@ async def rooms_create_meeting(
meeting = None meeting = None
if not info.allow_duplicated: if not info.allow_duplicated:
meeting = await meetings_controller.get_active( meeting = await meetings_controller.get_active(
room=room, current_time=current_time session, room=room, current_time=current_time
) )
if meeting is None: if meeting is None:
@@ -314,6 +322,7 @@ async def rooms_create_meeting(
await upload_logo(whereby_meeting["roomName"], "./images/logo.png") await upload_logo(whereby_meeting["roomName"], "./images/logo.png")
meeting = await meetings_controller.create( meeting = await meetings_controller.create(
session,
id=whereby_meeting["meetingId"], id=whereby_meeting["meetingId"],
room_name=whereby_meeting["roomName"], room_name=whereby_meeting["roomName"],
room_url=whereby_meeting["roomUrl"], room_url=whereby_meeting["roomUrl"],
@@ -340,11 +349,12 @@ async def rooms_create_meeting(
async def rooms_test_webhook( async def rooms_test_webhook(
room_id: str, room_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
): ):
"""Test webhook configuration by sending a sample payload.""" """Test webhook configuration by sending a sample payload."""
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
room = await rooms_controller.get_by_id(room_id) room = await rooms_controller.get_by_id(session, room_id)
if not room: if not room:
raise HTTPException(status_code=404, detail="Room not found") raise HTTPException(status_code=404, detail="Room not found")
@@ -361,9 +371,10 @@ async def rooms_test_webhook(
async def rooms_sync_ics( async def rooms_sync_ics(
room_name: str, room_name: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name) room = await rooms_controller.get_by_name(session, room_name)
if not room: if not room:
raise HTTPException(status_code=404, detail="Room not found") raise HTTPException(status_code=404, detail="Room not found")
@@ -390,9 +401,10 @@ async def rooms_sync_ics(
async def rooms_ics_status( async def rooms_ics_status(
room_name: str, room_name: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name) room = await rooms_controller.get_by_name(session, room_name)
if not room: if not room:
raise HTTPException(status_code=404, detail="Room not found") raise HTTPException(status_code=404, detail="Room not found")
@@ -407,7 +419,7 @@ async def rooms_ics_status(
next_sync = room.ics_last_sync + timedelta(seconds=room.ics_fetch_interval) next_sync = room.ics_last_sync + timedelta(seconds=room.ics_fetch_interval)
events = await calendar_events_controller.get_by_room( events = await calendar_events_controller.get_by_room(
room.id, include_deleted=False session, room.id, include_deleted=False
) )
return ICSStatus( return ICSStatus(
@@ -423,15 +435,16 @@ async def rooms_ics_status(
async def rooms_list_meetings( async def rooms_list_meetings(
room_name: str, room_name: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name) room = await rooms_controller.get_by_name(session, room_name)
if not room: if not room:
raise HTTPException(status_code=404, detail="Room not found") raise HTTPException(status_code=404, detail="Room not found")
events = await calendar_events_controller.get_by_room( events = await calendar_events_controller.get_by_room(
room.id, include_deleted=False session, room.id, include_deleted=False
) )
if user_id != room.user_id: if user_id != room.user_id:
@@ -449,15 +462,16 @@ async def rooms_list_upcoming_meetings(
room_name: str, room_name: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
minutes_ahead: int = 120, minutes_ahead: int = 120,
session: AsyncSession = Depends(get_session),
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name) room = await rooms_controller.get_by_name(session, room_name)
if not room: if not room:
raise HTTPException(status_code=404, detail="Room not found") raise HTTPException(status_code=404, detail="Room not found")
events = await calendar_events_controller.get_upcoming( events = await calendar_events_controller.get_upcoming(
room.id, minutes_ahead=minutes_ahead session, room.id, minutes_ahead=minutes_ahead
) )
if user_id != room.user_id: if user_id != room.user_id:
@@ -472,16 +486,17 @@ async def rooms_list_upcoming_meetings(
async def rooms_list_active_meetings( async def rooms_list_active_meetings(
room_name: str, room_name: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name) room = await rooms_controller.get_by_name(session, room_name)
if not room: if not room:
raise HTTPException(status_code=404, detail="Room not found") raise HTTPException(status_code=404, detail="Room not found")
current_time = datetime.now(timezone.utc) current_time = datetime.now(timezone.utc)
meetings = await meetings_controller.get_all_active_for_room( meetings = await meetings_controller.get_all_active_for_room(
room=room, current_time=current_time session, room=room, current_time=current_time
) )
# Hide host URLs from non-owners # Hide host URLs from non-owners
@@ -497,15 +512,16 @@ async def rooms_get_meeting(
room_name: str, room_name: str,
meeting_id: str, meeting_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
): ):
"""Get a single meeting by ID within a specific room.""" """Get a single meeting by ID within a specific room."""
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name) room = await rooms_controller.get_by_name(session, room_name)
if not room: if not room:
raise HTTPException(status_code=404, detail="Room not found") raise HTTPException(status_code=404, detail="Room not found")
meeting = await meetings_controller.get_by_id(meeting_id) meeting = await meetings_controller.get_by_id(session, meeting_id)
if not meeting: if not meeting:
raise HTTPException(status_code=404, detail="Meeting not found") raise HTTPException(status_code=404, detail="Meeting not found")
@@ -525,14 +541,15 @@ async def rooms_join_meeting(
room_name: str, room_name: str,
meeting_id: str, meeting_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name) room = await rooms_controller.get_by_name(session, room_name)
if not room: if not room:
raise HTTPException(status_code=404, detail="Room not found") raise HTTPException(status_code=404, detail="Room not found")
meeting = await meetings_controller.get_by_id(meeting_id) meeting = await meetings_controller.get_by_id(session, meeting_id)
if not meeting: if not meeting:
raise HTTPException(status_code=404, detail="Meeting not found") raise HTTPException(status_code=404, detail="Meeting not found")

View File

@@ -9,8 +9,10 @@ from typing import Annotated, Optional
import httpx import httpx
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from jose import jwt from jose import jwt
from sqlalchemy.ext.asyncio import AsyncSession
import reflector.auth as auth import reflector.auth as auth
from reflector.db import get_session
from reflector.db.transcripts import AudioWaveform, transcripts_controller from reflector.db.transcripts import AudioWaveform, transcripts_controller
from reflector.settings import settings from reflector.settings import settings
from reflector.views.transcripts import ALGORITHM from reflector.views.transcripts import ALGORITHM
@@ -48,7 +50,7 @@ async def transcript_get_audio_mp3(
raise unauthorized_exception raise unauthorized_exception
transcript = await transcripts_controller.get_by_id_for_http( transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id session, transcript_id, user_id=user_id
) )
if transcript.audio_location == "storage": if transcript.audio_location == "storage":
@@ -96,10 +98,11 @@ async def transcript_get_audio_mp3(
async def transcript_get_audio_waveform( async def transcript_get_audio_waveform(
transcript_id: str, transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
) -> AudioWaveform: ) -> AudioWaveform:
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http( transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id session, transcript_id, user_id=user_id
) )
if not transcript.audio_waveform_filename.exists(): if not transcript.audio_waveform_filename.exists():

View File

@@ -8,8 +8,10 @@ from typing import Annotated, Optional
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy.ext.asyncio import AsyncSession
import reflector.auth as auth import reflector.auth as auth
from reflector.db import get_session
from reflector.db.transcripts import TranscriptParticipant, transcripts_controller from reflector.db.transcripts import TranscriptParticipant, transcripts_controller
from reflector.views.types import DeletionStatus from reflector.views.types import DeletionStatus
@@ -37,10 +39,11 @@ class UpdateParticipant(BaseModel):
async def transcript_get_participants( async def transcript_get_participants(
transcript_id: str, transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
) -> list[Participant]: ) -> list[Participant]:
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http( transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id session, transcript_id, user_id=user_id
) )
if transcript.participants is None: if transcript.participants is None:
@@ -57,10 +60,11 @@ async def transcript_add_participant(
transcript_id: str, transcript_id: str,
participant: CreateParticipant, participant: CreateParticipant,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
) -> Participant: ) -> Participant:
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http( transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id session, transcript_id, user_id=user_id
) )
# ensure the speaker is unique # ensure the speaker is unique
@@ -83,10 +87,11 @@ async def transcript_get_participant(
transcript_id: str, transcript_id: str,
participant_id: str, participant_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
) -> Participant: ) -> Participant:
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http( transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id session, transcript_id, user_id=user_id
) )
for p in transcript.participants: for p in transcript.participants:
@@ -102,10 +107,11 @@ async def transcript_update_participant(
participant_id: str, participant_id: str,
participant: UpdateParticipant, participant: UpdateParticipant,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
) -> Participant: ) -> Participant:
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http( transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id session, transcript_id, user_id=user_id
) )
# ensure the speaker is unique # ensure the speaker is unique
@@ -139,10 +145,11 @@ async def transcript_delete_participant(
transcript_id: str, transcript_id: str,
participant_id: str, participant_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
) -> DeletionStatus: ) -> DeletionStatus:
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http( transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id session, transcript_id, user_id=user_id
) )
await transcripts_controller.delete_participant(transcript, participant_id) await transcripts_controller.delete_participant(transcript, participant_id)
return DeletionStatus(status="ok") return DeletionStatus(status="ok")

View File

@@ -3,8 +3,10 @@ from typing import Annotated, Optional
import celery import celery
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
import reflector.auth as auth import reflector.auth as auth
from reflector.db import get_session
from reflector.db.transcripts import transcripts_controller from reflector.db.transcripts import transcripts_controller
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
@@ -19,10 +21,11 @@ class ProcessStatus(BaseModel):
async def transcript_process( async def transcript_process(
transcript_id: str, transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http( transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id session, transcript_id, user_id=user_id
) )
if transcript.locked: if transcript.locked:

View File

@@ -8,8 +8,10 @@ from typing import Annotated, Optional
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
import reflector.auth as auth import reflector.auth as auth
from reflector.db import get_session
from reflector.db.transcripts import transcripts_controller from reflector.db.transcripts import transcripts_controller
router = APIRouter() router = APIRouter()
@@ -36,10 +38,11 @@ async def transcript_assign_speaker(
transcript_id: str, transcript_id: str,
assignment: SpeakerAssignment, assignment: SpeakerAssignment,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
) -> SpeakerAssignmentStatus: ) -> SpeakerAssignmentStatus:
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http( transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id session, transcript_id, user_id=user_id
) )
if not transcript: if not transcript:
@@ -100,6 +103,7 @@ async def transcript_assign_speaker(
for topic in changed_topics: for topic in changed_topics:
transcript.upsert_topic(topic) transcript.upsert_topic(topic)
await transcripts_controller.update( await transcripts_controller.update(
session,
transcript, transcript,
{ {
"topics": transcript.topics_dump(), "topics": transcript.topics_dump(),
@@ -114,10 +118,11 @@ async def transcript_merge_speaker(
transcript_id: str, transcript_id: str,
merge: SpeakerMerge, merge: SpeakerMerge,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
) -> SpeakerAssignmentStatus: ) -> SpeakerAssignmentStatus:
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http( transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id session, transcript_id, user_id=user_id
) )
if not transcript: if not transcript:
@@ -163,6 +168,7 @@ async def transcript_merge_speaker(
for topic in changed_topics: for topic in changed_topics:
transcript.upsert_topic(topic) transcript.upsert_topic(topic)
await transcripts_controller.update( await transcripts_controller.update(
session,
transcript, transcript,
{ {
"topics": transcript.topics_dump(), "topics": transcript.topics_dump(),

View File

@@ -3,8 +3,10 @@ from typing import Annotated, Optional
import av import av
from fastapi import APIRouter, Depends, HTTPException, UploadFile from fastapi import APIRouter, Depends, HTTPException, UploadFile
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
import reflector.auth as auth import reflector.auth as auth
from reflector.db import get_session
from reflector.db.transcripts import transcripts_controller from reflector.db.transcripts import transcripts_controller
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
@@ -22,10 +24,11 @@ async def transcript_record_upload(
total_chunks: int, total_chunks: int,
chunk: UploadFile, chunk: UploadFile,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http( transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id session, transcript_id, user_id=user_id
) )
if transcript.locked: if transcript.locked:
@@ -89,7 +92,7 @@ async def transcript_record_upload(
container.close() container.close()
# set the status to "uploaded" # set the status to "uploaded"
await transcripts_controller.update(transcript, {"status": "uploaded"}) await transcripts_controller.update(session, transcript, {"status": "uploaded"})
# launch a background task to process the file # launch a background task to process the file
task_pipeline_file_process.delay(transcript_id=transcript_id) task_pipeline_file_process.delay(transcript_id=transcript_id)

View File

@@ -1,8 +1,10 @@
from typing import Annotated, Optional from typing import Annotated, Optional
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession
import reflector.auth as auth import reflector.auth as auth
from reflector.db import get_session
from reflector.db.transcripts import transcripts_controller from reflector.db.transcripts import transcripts_controller
from .rtc_offer import RtcOffer, rtc_offer_base from .rtc_offer import RtcOffer, rtc_offer_base
@@ -16,10 +18,11 @@ async def transcript_record_webrtc(
params: RtcOffer, params: RtcOffer,
request: Request, request: Request,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http( transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id session, transcript_id, user_id=user_id
) )
if transcript.locked: if transcript.locked:

View File

@@ -24,7 +24,7 @@ async def transcript_events_websocket(
# user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], # user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
): ):
# user_id = user["sub"] if user else None # user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id) transcript = await transcripts_controller.get_by_id(session, transcript_id)
if not transcript: if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found") raise HTTPException(status_code=404, detail="Transcript not found")

View File

@@ -1,6 +1,4 @@
import os import os
from tempfile import NamedTemporaryFile
from unittest.mock import patch
import pytest import pytest
@@ -34,382 +32,283 @@ def docker_compose_file(pytestconfig):
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def postgres_service(docker_ip, docker_services): def docker_ip():
"""Ensure that PostgreSQL service is up and responsive.""" """Get Docker IP address for test services"""
port = docker_services.port_for("postgres_test", 5432) # For most Docker setups, localhost works
return "127.0.0.1"
def is_responsive():
try:
import psycopg2
conn = psycopg2.connect(
host=docker_ip,
port=port,
dbname="reflector_test",
user="test_user",
password="test_password",
)
conn.close()
return True
except Exception:
return False
docker_services.wait_until_responsive(timeout=30.0, pause=0.1, check=is_responsive)
# Return connection parameters
return {
"host": docker_ip,
"port": port,
"dbname": "reflector_test",
"user": "test_user",
"password": "test_password",
}
@pytest.fixture(scope="function", autouse=True) # Only register docker_services dependent fixtures if docker plugin is available
@pytest.mark.asyncio try:
import pytest_docker # noqa: F401
@pytest.fixture(scope="session")
def postgres_service(docker_ip, docker_services):
"""Ensure that PostgreSQL service is up and responsive."""
port = docker_services.port_for("postgres_test", 5432)
def is_responsive():
try:
import psycopg2
conn = psycopg2.connect(
host=docker_ip,
port=port,
dbname="reflector_test",
user="test_user",
password="test_password",
)
conn.close()
return True
except Exception:
return False
docker_services.wait_until_responsive(
timeout=30.0, pause=0.1, check=is_responsive
)
# Return connection parameters
return {
"host": docker_ip,
"port": port,
"database": "reflector_test",
"user": "test_user",
"password": "test_password",
}
except ImportError:
# Docker plugin not available, provide a dummy fixture
@pytest.fixture(scope="session")
def postgres_service(docker_ip):
"""Dummy postgres service when docker plugin is not available"""
return {
"host": docker_ip,
"port": 15432, # Default test postgres port
"database": "reflector_test",
"user": "test_user",
"password": "test_password",
}
@pytest.fixture(scope="session", autouse=True)
async def setup_database(postgres_service): async def setup_database(postgres_service):
from reflector.db import get_engine """Setup database and run migrations"""
from reflector.db.base import metadata from sqlalchemy.ext.asyncio import create_async_engine
async_engine = get_engine() from reflector.db import Base
async with async_engine.begin() as conn: # Build database URL from connection params
await conn.run_sync(metadata.drop_all) db_config = postgres_service
await conn.run_sync(metadata.create_all) DATABASE_URL = (
f"postgresql+asyncpg://{db_config['user']}:{db_config['password']}"
f"@{db_config['host']}:{db_config['port']}/{db_config['database']}"
)
try: # Override settings
yield from reflector.settings import settings
finally:
await async_engine.dispose() settings.DATABASE_URL = DATABASE_URL
# Create engine and tables
engine = create_async_engine(DATABASE_URL, echo=False)
async with engine.begin() as conn:
# Drop all tables first to ensure clean state
await conn.run_sync(Base.metadata.drop_all)
# Create all tables
await conn.run_sync(Base.metadata.create_all)
yield
# Cleanup
await engine.dispose()
@pytest.fixture @pytest.fixture
async def session(): async def session(setup_database):
"""Provide a transactional database session for tests"""
from reflector.db import get_session_factory from reflector.db import get_session_factory
async with get_session_factory()() as session: async with get_session_factory()() as session:
yield session yield session
await session.rollback()
@pytest.fixture @pytest.fixture
def dummy_processors(): def fake_mp3_upload(tmp_path):
with ( """Create a temporary MP3 file for upload testing"""
patch( mp3_file = tmp_path / "test.mp3"
"reflector.processors.transcript_topic_detector.TranscriptTopicDetectorProcessor.get_topic" # Create a minimal valid MP3 file (ID3v2 header + minimal frame)
) as mock_topic, mp3_data = b"ID3\x04\x00\x00\x00\x00\x00\x00" + b"\xff\xfb" + b"\x00" * 100
patch( mp3_file.write_bytes(mp3_data)
"reflector.processors.transcript_final_title.TranscriptFinalTitleProcessor.get_title" return mp3_file
) as mock_title,
patch(
"reflector.processors.transcript_final_summary.TranscriptFinalSummaryProcessor.get_long_summary"
) as mock_long_summary,
patch(
"reflector.processors.transcript_final_summary.TranscriptFinalSummaryProcessor.get_short_summary"
) as mock_short_summary,
):
from reflector.processors.transcript_topic_detector import TopicResponse
mock_topic.return_value = TopicResponse(
title="LLM TITLE", summary="LLM SUMMARY"
)
mock_title.return_value = "LLM Title"
mock_long_summary.return_value = "LLM LONG SUMMARY"
mock_short_summary.return_value = "LLM SHORT SUMMARY"
yield (
mock_topic,
mock_title,
mock_long_summary,
mock_short_summary,
) # noqa
@pytest.fixture @pytest.fixture
async def whisper_transcript(): def dummy_transcript():
from reflector.processors.audio_transcript_whisper import ( """Mock transcript processor response"""
AudioTranscriptWhisperProcessor,
)
with patch(
"reflector.processors.audio_transcript_auto"
".AudioTranscriptAutoProcessor.__new__"
) as mock_audio:
mock_audio.return_value = AudioTranscriptWhisperProcessor()
yield
@pytest.fixture
async def dummy_transcript():
from reflector.processors.audio_transcript import AudioTranscriptProcessor
from reflector.processors.types import AudioFile, Transcript, Word
class TestAudioTranscriptProcessor(AudioTranscriptProcessor):
_time_idx = 0
async def _transcript(self, data: AudioFile):
i = self._time_idx
self._time_idx += 2
return Transcript(
text="Hello world.",
words=[
Word(start=i, end=i + 1, text="Hello", speaker=0),
Word(start=i + 1, end=i + 2, text=" world.", speaker=0),
],
)
with patch(
"reflector.processors.audio_transcript_auto"
".AudioTranscriptAutoProcessor.__new__"
) as mock_audio:
mock_audio.return_value = TestAudioTranscriptProcessor()
yield
@pytest.fixture
async def dummy_diarization():
from reflector.processors.audio_diarization import AudioDiarizationProcessor
class TestAudioDiarizationProcessor(AudioDiarizationProcessor):
_time_idx = 0
async def _diarize(self, data):
i = self._time_idx
self._time_idx += 2
return [
{"start": i, "end": i + 1, "speaker": 0},
{"start": i + 1, "end": i + 2, "speaker": 1},
]
with patch(
"reflector.processors.audio_diarization_auto"
".AudioDiarizationAutoProcessor.__new__"
) as mock_audio:
mock_audio.return_value = TestAudioDiarizationProcessor()
yield
@pytest.fixture
async def dummy_file_transcript():
from reflector.processors.file_transcript import FileTranscriptProcessor
from reflector.processors.types import Transcript, Word from reflector.processors.types import Transcript, Word
class TestFileTranscriptProcessor(FileTranscriptProcessor): return Transcript(
async def _transcript(self, data): text="Hello world this is a test",
return Transcript( words=[
text="Hello world. How are you today?", Word(word="Hello", start=0.0, end=0.5, speaker=0),
words=[ Word(word="world", start=0.5, end=1.0, speaker=0),
Word(start=0.0, end=0.5, text="Hello", speaker=0), Word(word="this", start=1.0, end=1.5, speaker=0),
Word(start=0.5, end=0.6, text=" ", speaker=0), Word(word="is", start=1.5, end=1.8, speaker=0),
Word(start=0.6, end=1.0, text="world", speaker=0), Word(word="a", start=1.8, end=2.0, speaker=0),
Word(start=1.0, end=1.1, text=".", speaker=0), Word(word="test", start=2.0, end=2.5, speaker=0),
Word(start=1.1, end=1.2, text=" ", speaker=0), ],
Word(start=1.2, end=1.5, text="How", speaker=0),
Word(start=1.5, end=1.6, text=" ", speaker=0),
Word(start=1.6, end=1.8, text="are", speaker=0),
Word(start=1.8, end=1.9, text=" ", speaker=0),
Word(start=1.9, end=2.1, text="you", speaker=0),
Word(start=2.1, end=2.2, text=" ", speaker=0),
Word(start=2.2, end=2.5, text="today", speaker=0),
Word(start=2.5, end=2.6, text="?", speaker=0),
],
)
with patch(
"reflector.processors.file_transcript_auto.FileTranscriptAutoProcessor.__new__"
) as mock_auto:
mock_auto.return_value = TestFileTranscriptProcessor()
yield
@pytest.fixture
async def dummy_file_diarization():
from reflector.processors.file_diarization import (
FileDiarizationOutput,
FileDiarizationProcessor,
) )
from reflector.processors.types import DiarizationSegment
class TestFileDiarizationProcessor(FileDiarizationProcessor):
async def _diarize(self, data):
return FileDiarizationOutput(
diarization=[
DiarizationSegment(start=0.0, end=1.1, speaker=0),
DiarizationSegment(start=1.2, end=2.6, speaker=1),
]
)
with patch(
"reflector.processors.file_diarization_auto.FileDiarizationAutoProcessor.__new__"
) as mock_auto:
mock_auto.return_value = TestFileDiarizationProcessor()
yield
@pytest.fixture @pytest.fixture
async def dummy_transcript_translator(): def dummy_transcript_translator():
from reflector.processors.transcript_translator import TranscriptTranslatorProcessor """Mock transcript translation"""
return "Hola mundo esto es una prueba"
class TestTranscriptTranslatorProcessor(TranscriptTranslatorProcessor):
async def _translate(self, text: str) -> str:
source_language = self.get_pref("audio:source_language", "en")
target_language = self.get_pref("audio:target_language", "en")
return f"{source_language}:{target_language}:{text}"
def mock_new(cls, *args, **kwargs):
return TestTranscriptTranslatorProcessor(*args, **kwargs)
with patch(
"reflector.processors.transcript_translator_auto"
".TranscriptTranslatorAutoProcessor.__new__",
mock_new,
):
yield
@pytest.fixture @pytest.fixture
async def dummy_llm(): def dummy_diarization():
from reflector.llm import LLM """Mock diarization processor response"""
from reflector.processors.types import DiarizationOutput, DiarizationSegment
class TestLLM(LLM): return DiarizationOutput(
def __init__(self): diarization=[
self.model_name = "DUMMY MODEL" DiarizationSegment(speaker=0, start=0.0, end=1.0),
self.llm_tokenizer = "DUMMY TOKENIZER" DiarizationSegment(speaker=1, start=1.0, end=2.5),
]
# LLM doesn't have get_instance anymore, mocking constructor instead )
with patch("reflector.llm.LLM") as mock_llm:
mock_llm.return_value = TestLLM()
yield
@pytest.fixture @pytest.fixture
async def dummy_storage(): def dummy_file_transcript():
from reflector.storage.base import Storage """Mock file transcript processor response"""
from reflector.processors.types import Transcript, Word
class DummyStorage(Storage): return Transcript(
async def _put_file(self, *args, **kwargs): text="This is a complete file transcript with multiple speakers",
pass words=[
Word(word="This", start=0.0, end=0.5, speaker=0),
async def _delete_file(self, *args, **kwargs): Word(word="is", start=0.5, end=0.8, speaker=0),
pass Word(word="a", start=0.8, end=1.0, speaker=0),
Word(word="complete", start=1.0, end=1.5, speaker=1),
async def _get_file_url(self, *args, **kwargs): Word(word="file", start=1.5, end=1.8, speaker=1),
return "http://fake_server/audio.mp3" Word(word="transcript", start=1.8, end=2.3, speaker=1),
Word(word="with", start=2.3, end=2.5, speaker=0),
async def _get_file(self, *args, **kwargs): Word(word="multiple", start=2.5, end=3.0, speaker=0),
from pathlib import Path Word(word="speakers", start=3.0, end=3.5, speaker=0),
],
test_mp3 = Path(__file__).parent / "records" / "test_mathieu_hello.mp3" )
return test_mp3.read_bytes()
dummy = DummyStorage()
with (
patch("reflector.storage.base.Storage.get_instance") as mock_storage,
patch("reflector.storage.get_transcripts_storage") as mock_get_transcripts,
patch(
"reflector.pipelines.main_file_pipeline.get_transcripts_storage"
) as mock_get_transcripts2,
):
mock_storage.return_value = dummy
mock_get_transcripts.return_value = dummy
mock_get_transcripts2.return_value = dummy
yield
@pytest.fixture(scope="session")
def celery_enable_logging():
    """Session-scoped fixture whose value is always ``True``.

    NOTE(review): presumably read by the Celery pytest plugin to switch on
    worker logging during the test session -- confirm against the plugin docs.
    """
    logging_enabled = True
    return logging_enabled
@pytest.fixture(scope="session")
def celery_config():
    """Yield the Celery configuration dict used for the whole test session.

    The broker URL is ``memory://`` and the result backend is a SQLite URL
    pointing at a temporary file. The file is held open for as long as the
    fixture is active and is removed when the ``with`` block exits on teardown.
    """
    with NamedTemporaryFile() as backend_file:
        result_backend = f"db+sqlite:///{backend_file.name}"
        yield {"broker_url": "memory://", "result_backend": result_backend}
@pytest.fixture(scope="session")
def celery_includes():
    """Return the dotted paths of the live and file pipeline modules.

    NOTE(review): presumably the list of task modules the test Celery worker
    should import -- confirm against the Celery pytest plugin docs.
    """
    pipeline_modules = [
        "reflector.pipelines.main_live_pipeline",
        "reflector.pipelines.main_file_pipeline",
    ]
    return pipeline_modules
@pytest.fixture @pytest.fixture
async def client(): def dummy_file_diarization():
from httpx import AsyncClient """Mock file diarization processor response"""
from reflector.processors.types import DiarizationOutput, DiarizationSegment
from reflector.app import app return DiarizationOutput(
diarization=[
async with AsyncClient(app=app, base_url="http://test/v1") as ac: DiarizationSegment(speaker=0, start=0.0, end=1.0),
yield ac DiarizationSegment(speaker=1, start=1.0, end=2.3),
DiarizationSegment(speaker=0, start=2.3, end=3.5),
]
@pytest.fixture(scope="session") )
def fake_mp3_upload():
    """Replace ``TranscriptController.move_mp3_to_storage`` with a mock.

    The patch stays active until the generator is resumed after the ``yield``;
    the mock is configured with ``return_value = True``. Nothing is yielded
    to the consumer, only the patched state matters.
    """
    target = "reflector.db.transcripts.TranscriptController.move_mp3_to_storage"
    with patch(target) as mock_move:
        mock_move.return_value = True
        yield
@pytest.fixture @pytest.fixture
async def fake_transcript_with_topics(tmpdir, client): def fake_transcript_with_topics():
import shutil """Create a transcript with topics for testing"""
from pathlib import Path
from reflector.db.transcripts import TranscriptTopic from reflector.db.transcripts import TranscriptTopic
from reflector.processors.types import Word from reflector.processors.types import Word
from reflector.settings import settings
from reflector.views.transcripts import transcripts_controller
settings.DATA_DIR = Path(tmpdir) topics = [
# create a transcript
response = await client.post("/transcripts", json={"name": "Test audio download"})
assert response.status_code == 200
tid = response.json()["id"]
transcript = await transcripts_controller.get_by_id(tid)
assert transcript is not None
await transcripts_controller.update(transcript, {"status": "ended"})
# manually copy a file at the expected location
audio_filename = transcript.audio_mp3_filename
path = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
audio_filename.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(path, audio_filename)
# create some topics
await transcripts_controller.upsert_topic(
transcript,
TranscriptTopic( TranscriptTopic(
title="Topic 1", id="topic1",
summary="Topic 1 summary", title="Introduction",
timestamp=0, summary="Opening remarks and introductions",
transcript="Hello world", timestamp=0.0,
duration=30.0,
words=[ words=[
Word(text="Hello", start=0, end=1, speaker=0), Word(word="Hello", start=0.0, end=0.5, speaker=0),
Word(text="world", start=1, end=2, speaker=0), Word(word="everyone", start=0.5, end=1.0, speaker=0),
], ],
), ),
)
await transcripts_controller.upsert_topic(
transcript,
TranscriptTopic( TranscriptTopic(
title="Topic 2", id="topic2",
summary="Topic 2 summary", title="Main Discussion",
timestamp=2, summary="Core topics and key points",
transcript="Hello world", timestamp=30.0,
duration=60.0,
words=[ words=[
Word(text="Hello", start=2, end=3, speaker=0), Word(word="Let's", start=30.0, end=30.3, speaker=1),
Word(text="world", start=3, end=4, speaker=0), Word(word="discuss", start=30.3, end=30.8, speaker=1),
Word(word="the", start=30.8, end=31.0, speaker=1),
Word(word="agenda", start=31.0, end=31.5, speaker=1),
], ],
), ),
) ]
return topics
yield transcript
@pytest.fixture
def dummy_processors(
    dummy_transcript,
    dummy_transcript_translator,
    dummy_diarization,
    dummy_file_transcript,
    dummy_file_diarization,
):
    """Bundle the individual processor fixtures into a single dict.

    Each requested fixture value is stored unchanged under a short role key
    (``transcript``, ``translator``, ``diarization``, ``file_transcript``,
    ``file_diarization``).
    """
    processors = dict(
        transcript=dummy_transcript,
        translator=dummy_transcript_translator,
        diarization=dummy_diarization,
        file_transcript=dummy_file_transcript,
        file_diarization=dummy_file_diarization,
    )
    return processors
@pytest.fixture
def dummy_storage():
    """Return an ``AsyncMock`` standing in for a storage backend.

    Three methods are given explicit results: ``get_file_url`` resolves to a
    fixed example URL, while ``put_file`` and ``delete_file`` resolve to
    ``None``. Any other attribute is an auto-created mock.
    """
    from unittest.mock import AsyncMock

    stubbed_results = {
        "get_file_url.return_value": "https://example.com/test-audio.mp3",
        "put_file.return_value": None,
        "delete_file.return_value": None,
    }
    storage = AsyncMock()
    # configure_mock accepts dotted names, so this sets child return values.
    storage.configure_mock(**stubbed_results)
    return storage
@pytest.fixture
def dummy_llm():
    """Return canned LLM output: a title, a summary and a short summary."""
    canned_response = dict(
        title="Test Meeting Title",
        summary="This is a test meeting summary with key discussion points.",
        short_summary="Brief test summary.",
    )
    return canned_response
@pytest.fixture
def whisper_transcript():
    """Return a canned Whisper-style transcription payload.

    The payload holds the full text, a single segment covering 0.0-2.5 s with
    per-word timings, and the language code ``en``.
    """
    full_text = "Hello world this is a test"
    # (word, start, end) triples, in spoken order.
    timed_words = [
        ("Hello", 0.0, 0.5),
        ("world", 0.5, 1.0),
        ("this", 1.0, 1.5),
        ("is", 1.5, 1.8),
        ("a", 1.8, 2.0),
        ("test", 2.0, 2.5),
    ]
    segment = {
        "start": 0.0,
        "end": 2.5,
        "text": full_text,
        "words": [
            {"word": word, "start": start, "end": end}
            for word, start, end in timed_words
        ],
    }
    return {"text": full_text, "segments": [segment], "language": "en"}

View File

@@ -3,6 +3,7 @@ from unittest.mock import AsyncMock, patch
import pytest import pytest
from reflector.db import get_session_factory
from reflector.db.rooms import rooms_controller from reflector.db.rooms import rooms_controller
from reflector.services.ics_sync import ICSSyncService from reflector.services.ics_sync import ICSSyncService
@@ -17,21 +18,22 @@ async def test_attendee_parsing_bug():
instead of properly parsed email addresses. instead of properly parsed email addresses.
""" """
# Create a test room # Create a test room
room = await rooms_controller.add( async with get_session_factory()() as session:
session, room = await rooms_controller.add(
name="test-room", session,
user_id="test-user", name="test-room",
zulip_auto_post=False, user_id="test-user",
zulip_stream="", zulip_auto_post=False,
zulip_topic="", zulip_stream="",
is_locked=False, zulip_topic="",
room_mode="normal", is_locked=False,
recording_type="cloud", room_mode="normal",
recording_trigger="automatic-2nd-participant", recording_type="cloud",
is_shared=False, recording_trigger="automatic-2nd-participant",
ics_url="http://test.com/test.ics", is_shared=False,
ics_enabled=True, ics_url="http://test.com/test.ics",
) ics_enabled=True,
)
# Read the test ICS file that reproduces the bug and update it with current time # Read the test ICS file that reproduces the bug and update it with current time
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
@@ -95,99 +97,14 @@ async def test_attendee_parsing_bug():
# This is where the bug manifests - check the attendees # This is where the bug manifests - check the attendees
attendees = event["attendees"] attendees = event["attendees"]
# Print attendee info for debugging # Debug output to see what's happening
print(f"Number of attendees found: {len(attendees)}") print(f"Number of attendees: {len(attendees)}")
for i, attendee in enumerate(attendees): for i, attendee in enumerate(attendees):
print( print(f"Attendee {i}: {attendee}")
f"Attendee {i}: email='{attendee.get('email')}', name='{attendee.get('name')}'"
)
# With the fix, we should now get properly parsed email addresses # The bug would cause 29 attendees (length of "MAILIN01234567890@allo.coop")
# Check that no single characters are parsed as emails # instead of 1 attendee
single_char_emails = [ assert len(attendees) == 1, f"Expected 1 attendee, got {len(attendees)}"
att for att in attendees if att.get("email") and len(att["email"]) == 1
]
if single_char_emails: # Verify the single attendee has correct email
print( assert attendees[0]["email"] == "MAILIN01234567890@allo.coop"
f"BUG DETECTED: Found {len(single_char_emails)} single-character emails:"
)
for att in single_char_emails:
print(f" - '{att['email']}'")
# Should have attendees but not single-character emails
assert len(attendees) > 0
assert (
len(single_char_emails) == 0
), f"Found {len(single_char_emails)} single-character emails, parsing is still buggy"
# Check that all emails are valid (contain @ symbol)
valid_emails = [
att for att in attendees if att.get("email") and "@" in att["email"]
]
assert len(valid_emails) == len(
attendees
), "Some attendees don't have valid email addresses"
# We expect around 29 attendees (28 from the comma-separated list + 1 organizer)
assert (
len(attendees) >= 25
), f"Expected around 29 attendees, got {len(attendees)}"
@pytest.mark.asyncio
async def test_correct_attendee_parsing():
    """
    Check that ``ICSFetchService._parse_event`` returns well-formed attendees.

    Builds an event with three separate ATTENDEE properties plus one
    ORGANIZER and expects four parsed attendees, each carrying a complete
    email address with the ``mailto:`` prefix stripped.
    """
    from datetime import datetime, timezone

    from icalendar import Event

    from reflector.services.ics_sync import ICSFetchService

    service = ICSFetchService()

    # Create a properly formatted event with multiple attendees
    event = Event()
    event.add("uid", "test-correct-attendees")
    event.add("summary", "Test Meeting")
    event.add("location", "http://test.com/test")
    event.add("dtstart", datetime.now(timezone.utc))
    event.add("dtend", datetime.now(timezone.utc))

    # Add attendees the correct way (separate ATTENDEE lines)
    event.add("attendee", "mailto:alice@example.com", parameters={"CN": "Alice"})
    event.add("attendee", "mailto:bob@example.com", parameters={"CN": "Bob"})
    event.add("attendee", "mailto:charlie@example.com", parameters={"CN": "Charlie"})
    event.add(
        "organizer", "mailto:organizer@example.com", parameters={"CN": "Organizer"}
    )

    # Parse the event
    result = service._parse_event(event)
    assert result is not None
    attendees = result["attendees"]

    # Should have 4 attendees (3 attendees + 1 organizer)
    assert len(attendees) == 4

    # Check that all emails are valid email addresses
    emails = [att["email"] for att in attendees if att.get("email")]
    expected_emails = [
        "alice@example.com",
        "bob@example.com",
        "charlie@example.com",
        "organizer@example.com",
    ]
    for email in emails:
        assert "@" in email, f"Invalid email format: {email}"
        assert len(email) > 5, f"Email too short: {email}"

    # Check that we have the expected emails; driven by the list above so the
    # expectation is written once and a failure names the missing address.
    for expected in expected_emails:
        assert expected in emails, f"Missing attendee email: {expected}"

View File

@@ -8,6 +8,7 @@ from reflector.db.transcripts import (
SourceKind, SourceKind,
TranscriptController, TranscriptController,
TranscriptTopic, TranscriptTopic,
transcripts_controller,
) )
from reflector.processors.types import Word from reflector.processors.types import Word