fix: Complete SQLAlchemy 2.0 migration - add session parameters to all controller calls

- Add session parameter to all view functions and controller calls
- Fix pipeline files to use get_session_factory() for background tasks
- Update PipelineMainBase and PipelineMainFile to handle sessions properly
- Add missing on_* methods to PipelineMainFile class
- Fix test fixtures to handle docker services availability
- Add docker_ip fixture for test database connections
- Import fixes for transcripts_controller in tests

All controller calls now properly use sessions as first parameter per SQLAlchemy 2.0 async patterns.
This commit is contained in:
2025-09-18 13:08:19 -06:00
parent 45d1608950
commit d21b65e4e8
13 changed files with 593 additions and 550 deletions

View File

@@ -8,18 +8,22 @@ Uses parallel processing for transcription, diarization, and waveform generation
import asyncio
import uuid
from contextlib import asynccontextmanager
from pathlib import Path
import av
import structlog
from celery import chain, shared_task
from sqlalchemy.ext.asyncio import AsyncSession
from reflector.asynctask import asynctask
from reflector.db import get_session_factory
from reflector.db.rooms import rooms_controller
from reflector.db.transcripts import (
SourceKind,
Transcript,
TranscriptStatus,
TranscriptTopic,
transcripts_controller,
)
from reflector.logger import logger
@@ -83,6 +87,32 @@ class PipelineMainFile(PipelineMainBase):
self.logger = logger.bind(transcript_id=self.transcript_id)
self.empty_pipeline = EmptyPipeline(logger=self.logger)
async def get_transcript(self, session: AsyncSession = None) -> Transcript:
"""Get transcript with session"""
if session:
result = await transcripts_controller.get_by_id(session, self.transcript_id)
else:
async with get_session_factory()() as session:
result = await transcripts_controller.get_by_id(
session, self.transcript_id
)
if not result:
raise Exception("Transcript not found")
return result
@asynccontextmanager
async def lock_transaction(self):
# This lock is to prevent multiple processor starting adding
# into event array at the same time
async with asyncio.Lock():
yield
@asynccontextmanager
async def transaction(self):
async with self.lock_transaction():
async with get_session_factory()() as session:
yield session
def _handle_gather_exceptions(self, results: list, operation: str) -> None:
"""Handle exceptions from asyncio.gather with return_exceptions=True"""
for i, result in enumerate(results):
@@ -97,17 +127,23 @@ class PipelineMainFile(PipelineMainBase):
@broadcast_to_sockets
async def set_status(self, transcript_id: str, status: TranscriptStatus):
async with self.lock_transaction():
return await transcripts_controller.set_status(transcript_id, status)
async with get_session_factory()() as session:
return await transcripts_controller.set_status(
session, transcript_id, status
)
async def process(self, file_path: Path):
"""Main entry point for file processing"""
self.logger.info(f"Starting file pipeline for {file_path}")
transcript = await self.get_transcript()
async with get_session_factory()() as session:
transcript = await transcripts_controller.get_by_id(
session, self.transcript_id
)
# Clear transcript as we're going to regenerate everything
async with self.transaction():
# Clear transcript as we're going to regenerate everything
await transcripts_controller.update(
session,
transcript,
{
"events": [],
@@ -131,7 +167,8 @@ class PipelineMainFile(PipelineMainBase):
self.logger.info("File pipeline complete")
await transcripts_controller.set_status(transcript.id, "ended")
async with get_session_factory()() as session:
await transcripts_controller.set_status(session, transcript.id, "ended")
async def extract_and_write_audio(
self, file_path: Path, transcript: Transcript
@@ -308,7 +345,10 @@ class PipelineMainFile(PipelineMainBase):
async def generate_waveform(self, audio_path: Path):
"""Generate and save waveform"""
transcript = await self.get_transcript()
async with get_session_factory()() as session:
transcript = await transcripts_controller.get_by_id(
session, self.transcript_id
)
processor = AudioWaveformProcessor(
audio_path=audio_path,
@@ -367,7 +407,10 @@ class PipelineMainFile(PipelineMainBase):
self.logger.warning("No topics for summary generation")
return
transcript = await self.get_transcript()
async with get_session_factory()() as session:
transcript = await transcripts_controller.get_by_id(
session, self.transcript_id
)
processor = TranscriptFinalSummaryProcessor(
transcript=transcript,
callback=self.on_long_summary,
@@ -380,37 +423,144 @@ class PipelineMainFile(PipelineMainBase):
await processor.flush()
async def on_topic(self, topic: TitleSummary):
"""Handle topic event - save to database"""
async with get_session_factory()() as session:
transcript = await transcripts_controller.get_by_id(
session, self.transcript_id
)
topic_obj = TranscriptTopic(
title=topic.title,
summary=topic.summary,
timestamp=topic.timestamp,
duration=topic.duration,
)
await transcripts_controller.upsert_topic(session, transcript, topic_obj)
await transcripts_controller.append_event(
session,
transcript=transcript,
event="TOPIC",
data=topic_obj,
)
async def on_title(self, data):
"""Handle title event"""
async with get_session_factory()() as session:
transcript = await transcripts_controller.get_by_id(
session, self.transcript_id
)
if not transcript.title:
await transcripts_controller.update(
session,
transcript,
{"title": data.title},
)
await transcripts_controller.append_event(
session,
transcript=transcript,
event="FINAL_TITLE",
data={"title": data.title},
)
async def on_long_summary(self, data):
"""Handle long summary event"""
async with get_session_factory()() as session:
transcript = await transcripts_controller.get_by_id(
session, self.transcript_id
)
await transcripts_controller.update(
session,
transcript,
{"long_summary": data.long_summary},
)
await transcripts_controller.append_event(
session,
transcript=transcript,
event="FINAL_LONG_SUMMARY",
data={"long_summary": data.long_summary},
)
async def on_short_summary(self, data):
"""Handle short summary event"""
async with get_session_factory()() as session:
transcript = await transcripts_controller.get_by_id(
session, self.transcript_id
)
await transcripts_controller.update(
session,
transcript,
{"short_summary": data.short_summary},
)
await transcripts_controller.append_event(
session,
transcript=transcript,
event="FINAL_SHORT_SUMMARY",
data={"short_summary": data.short_summary},
)
async def on_duration(self, duration):
"""Handle duration event"""
async with get_session_factory()() as session:
transcript = await transcripts_controller.get_by_id(
session, self.transcript_id
)
await transcripts_controller.update(
session,
transcript,
{"duration": duration},
)
await transcripts_controller.append_event(
session,
transcript=transcript,
event="DURATION",
data={"duration": duration},
)
async def on_waveform(self, waveform):
"""Handle waveform event"""
async with get_session_factory()() as session:
transcript = await transcripts_controller.get_by_id(
session, self.transcript_id
)
await transcripts_controller.append_event(
session,
transcript=transcript,
event="WAVEFORM",
data={"waveform": waveform},
)
@shared_task
@asynctask
async def task_send_webhook_if_needed(*, transcript_id: str):
"""Send webhook if this is a room recording with webhook configured"""
transcript = await transcripts_controller.get_by_id(transcript_id)
if not transcript:
return
async with get_session_factory()() as session:
transcript = await transcripts_controller.get_by_id(session, transcript_id)
if not transcript:
return
if transcript.source_kind == SourceKind.ROOM and transcript.room_id:
room = await rooms_controller.get_by_id(transcript.room_id)
if room and room.webhook_url:
logger.info(
"Dispatching webhook",
transcript_id=transcript_id,
room_id=room.id,
webhook_url=room.webhook_url,
)
send_transcript_webhook.delay(
transcript_id, room.id, event_id=uuid.uuid4().hex
)
if transcript.source_kind == SourceKind.ROOM and transcript.room_id:
room = await rooms_controller.get_by_id(session, transcript.room_id)
if room and room.webhook_url:
logger.info(
"Dispatching webhook",
transcript_id=transcript_id,
room_id=room.id,
webhook_url=room.webhook_url,
)
send_transcript_webhook.delay(
transcript_id, room.id, event_id=uuid.uuid4().hex
)
@shared_task
@asynctask
async def task_pipeline_file_process(*, transcript_id: str):
"""Celery task for file pipeline processing"""
transcript = await transcripts_controller.get_by_id(transcript_id)
if not transcript:
raise Exception(f"Transcript {transcript_id} not found")
async with get_session_factory()() as session:
transcript = await transcripts_controller.get_by_id(session, transcript_id)
if not transcript:
raise Exception(f"Transcript {transcript_id} not found")
pipeline = PipelineMainFile(transcript_id=transcript_id)
try:

View File

@@ -20,9 +20,11 @@ import av
import boto3
from celery import chord, current_task, group, shared_task
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from structlog import BoundLogger as Logger
from reflector.asynctask import asynctask
from reflector.db import get_session_factory
from reflector.db.meetings import meeting_consent_controller, meetings_controller
from reflector.db.recordings import recordings_controller
from reflector.db.rooms import rooms_controller
@@ -96,9 +98,10 @@ def get_transcript(func):
@functools.wraps(func)
async def wrapper(**kwargs):
transcript_id = kwargs.pop("transcript_id")
transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
async with get_session_factory()() as session:
transcript = await transcripts_controller.get_by_id(session, transcript_id)
if not transcript:
raise Exception("Transcript {transcript_id} not found")
raise Exception(f"Transcript {transcript_id} not found")
# Enhanced logger with Celery task context
tlogger = logger.bind(transcript_id=transcript.id)
@@ -139,11 +142,15 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
self._ws_manager = get_ws_manager()
return self._ws_manager
async def get_transcript(self) -> Transcript:
async def get_transcript(self, session: AsyncSession = None) -> Transcript:
# fetch the transcript
result = await transcripts_controller.get_by_id(
transcript_id=self.transcript_id
)
if session:
result = await transcripts_controller.get_by_id(session, self.transcript_id)
else:
async with get_session_factory()() as session:
result = await transcripts_controller.get_by_id(
session, self.transcript_id
)
if not result:
raise Exception("Transcript not found")
return result
@@ -175,8 +182,8 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
@asynccontextmanager
async def transaction(self):
async with self.lock_transaction():
async with transcripts_controller.transaction():
yield
async with get_session_factory()() as session:
yield session
@broadcast_to_sockets
async def on_status(self, status):
@@ -207,13 +214,17 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
# when the status of the pipeline changes, update the transcript
async with self._lock:
return await transcripts_controller.set_status(self.transcript_id, status)
async with get_session_factory()() as session:
return await transcripts_controller.set_status(
session, self.transcript_id, status
)
@broadcast_to_sockets
async def on_transcript(self, data):
async with self.transaction():
transcript = await self.get_transcript()
async with self.transaction() as session:
transcript = await self.get_transcript(session)
return await transcripts_controller.append_event(
session,
transcript=transcript,
event="TRANSCRIPT",
data=TranscriptText(text=data.text, translation=data.translation),
@@ -230,10 +241,11 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
)
if isinstance(data, TitleSummaryWithIdProcessorType):
topic.id = data.id
async with self.transaction():
transcript = await self.get_transcript()
await transcripts_controller.upsert_topic(transcript, topic)
async with self.transaction() as session:
transcript = await self.get_transcript(session)
await transcripts_controller.upsert_topic(session, transcript, topic)
return await transcripts_controller.append_event(
session,
transcript=transcript,
event="TOPIC",
data=topic,
@@ -242,16 +254,18 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
@broadcast_to_sockets
async def on_title(self, data):
final_title = TranscriptFinalTitle(title=data.title)
async with self.transaction():
transcript = await self.get_transcript()
async with self.transaction() as session:
transcript = await self.get_transcript(session)
if not transcript.title:
await transcripts_controller.update(
session,
transcript,
{
"title": final_title.title,
},
)
return await transcripts_controller.append_event(
session,
transcript=transcript,
event="FINAL_TITLE",
data=final_title,
@@ -260,15 +274,17 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
@broadcast_to_sockets
async def on_long_summary(self, data):
final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
async with self.transaction():
transcript = await self.get_transcript()
async with self.transaction() as session:
transcript = await self.get_transcript(session)
await transcripts_controller.update(
session,
transcript,
{
"long_summary": final_long_summary.long_summary,
},
)
return await transcripts_controller.append_event(
session,
transcript=transcript,
event="FINAL_LONG_SUMMARY",
data=final_long_summary,
@@ -279,15 +295,17 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
final_short_summary = TranscriptFinalShortSummary(
short_summary=data.short_summary
)
async with self.transaction():
transcript = await self.get_transcript()
async with self.transaction() as session:
transcript = await self.get_transcript(session)
await transcripts_controller.update(
session,
transcript,
{
"short_summary": final_short_summary.short_summary,
},
)
return await transcripts_controller.append_event(
session,
transcript=transcript,
event="FINAL_SHORT_SUMMARY",
data=final_short_summary,
@@ -295,29 +313,30 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
@broadcast_to_sockets
async def on_duration(self, data):
async with self.transaction():
async with self.transaction() as session:
duration = TranscriptDuration(duration=data)
transcript = await self.get_transcript()
transcript = await self.get_transcript(session)
await transcripts_controller.update(
session,
transcript,
{
"duration": duration.duration,
},
)
return await transcripts_controller.append_event(
transcript=transcript, event="DURATION", data=duration
session, transcript=transcript, event="DURATION", data=duration
)
@broadcast_to_sockets
async def on_waveform(self, data):
async with self.transaction():
async with self.transaction() as session:
waveform = TranscriptWaveform(waveform=data)
transcript = await self.get_transcript()
transcript = await self.get_transcript(session)
return await transcripts_controller.append_event(
transcript=transcript, event="WAVEFORM", data=waveform
session, transcript=transcript, event="WAVEFORM", data=waveform
)
@@ -535,7 +554,8 @@ async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
return
# Upload to external storage and delete the file
await transcripts_controller.move_mp3_to_storage(transcript)
async with get_session_factory()() as session:
await transcripts_controller.move_mp3_to_storage(session, transcript)
logger.info("Upload mp3 done")
@@ -572,13 +592,20 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
recording = None
try:
if transcript.recording_id:
recording = await recordings_controller.get_by_id(transcript.recording_id)
if recording and recording.meeting_id:
meeting = await meetings_controller.get_by_id(recording.meeting_id)
if meeting:
consent_denied = await meeting_consent_controller.has_any_denial(
meeting.id
async with get_session_factory()() as session:
recording = await recordings_controller.get_by_id(
session, transcript.recording_id
)
if recording and recording.meeting_id:
meeting = await meetings_controller.get_by_id(
session, recording.meeting_id
)
if meeting:
consent_denied = (
await meeting_consent_controller.has_any_denial(
session, meeting.id
)
)
except Exception as e:
logger.error(f"Failed to get fetch consent: {e}", exc_info=e)
consent_denied = True
@@ -606,7 +633,10 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
logger.error(f"Failed to delete Whereby recording: {e}", exc_info=e)
# non-transactional, files marked for deletion not actually deleted is possible
await transcripts_controller.update(transcript, {"audio_deleted": True})
async with get_session_factory()() as session:
await transcripts_controller.update(
session, transcript, {"audio_deleted": True}
)
# 2. Delete processed audio from transcript storage S3 bucket
if transcript.audio_location == "storage":
storage = get_transcripts_storage()
@@ -638,21 +668,24 @@ async def pipeline_post_to_zulip(transcript: Transcript, logger: Logger):
logger.info("Transcript has no recording")
return
recording = await recordings_controller.get_by_id(transcript.recording_id)
if not recording:
logger.info("Recording not found")
return
async with get_session_factory()() as session:
recording = await recordings_controller.get_by_id(
session, transcript.recording_id
)
if not recording:
logger.info("Recording not found")
return
if not recording.meeting_id:
logger.info("Recording has no meeting")
return
if not recording.meeting_id:
logger.info("Recording has no meeting")
return
meeting = await meetings_controller.get_by_id(recording.meeting_id)
if not meeting:
logger.info("No meeting found for this recording")
return
meeting = await meetings_controller.get_by_id(session, recording.meeting_id)
if not meeting:
logger.info("No meeting found for this recording")
return
room = await rooms_controller.get_by_id(meeting.room_id)
room = await rooms_controller.get_by_id(session, meeting.room_id)
if not room:
logger.error(f"Missing room for a meeting {meeting.id}")
return
@@ -677,9 +710,10 @@ async def pipeline_post_to_zulip(transcript: Transcript, logger: Logger):
response = await send_message_to_zulip(
room.zulip_stream, room.zulip_topic, message
)
await transcripts_controller.update(
transcript, {"zulip_message_id": response["id"]}
)
async with get_session_factory()() as session:
await transcripts_controller.update(
session, transcript, {"zulip_message_id": response["id"]}
)
logger.info("Posted to zulip")

View File

@@ -8,9 +8,10 @@ from fastapi_pagination import Page
from fastapi_pagination.ext.sqlalchemy import paginate
from pydantic import BaseModel
from redis.exceptions import LockError
from sqlalchemy.ext.asyncio import AsyncSession
import reflector.auth as auth
from reflector.db import get_session_factory
from reflector.db import get_session, get_session_factory
from reflector.db.calendar_events import calendar_events_controller
from reflector.db.meetings import meetings_controller
from reflector.db.rooms import rooms_controller
@@ -185,7 +186,7 @@ async def rooms_list(
session_factory = get_session_factory()
async with session_factory() as session:
query = await rooms_controller.get_all(
user_id=user_id, order_by="-created_at", return_query=True
session, user_id=user_id, order_by="-created_at", return_query=True
)
return await paginate(session, query)
@@ -194,9 +195,10 @@ async def rooms_list(
async def rooms_get(
room_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
user_id = user["sub"] if user else None
room = await rooms_controller.get_by_id_for_http(room_id, user_id=user_id)
room = await rooms_controller.get_by_id_for_http(session, room_id, user_id=user_id)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
return room
@@ -206,9 +208,10 @@ async def rooms_get(
async def rooms_get_by_name(
room_name: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name)
room = await rooms_controller.get_by_name(session, room_name)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
@@ -230,10 +233,12 @@ async def rooms_get_by_name(
async def rooms_create(
room: CreateRoom,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
user_id = user["sub"] if user else None
return await rooms_controller.add(
session,
name=room.name,
user_id=user_id,
zulip_auto_post=room.zulip_auto_post,
@@ -257,13 +262,14 @@ async def rooms_update(
room_id: str,
info: UpdateRoom,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
user_id = user["sub"] if user else None
room = await rooms_controller.get_by_id_for_http(room_id, user_id=user_id)
room = await rooms_controller.get_by_id_for_http(session, room_id, user_id=user_id)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
values = info.dict(exclude_unset=True)
await rooms_controller.update(room, values)
await rooms_controller.update(session, room, values)
return room
@@ -271,12 +277,13 @@ async def rooms_update(
async def rooms_delete(
room_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
user_id = user["sub"] if user else None
room = await rooms_controller.get_by_id(room_id, user_id=user_id)
room = await rooms_controller.get_by_id(session, room_id, user_id=user_id)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
await rooms_controller.remove_by_id(room.id, user_id=user_id)
await rooms_controller.remove_by_id(session, room.id, user_id=user_id)
return DeletionStatus(status="ok")
@@ -285,9 +292,10 @@ async def rooms_create_meeting(
room_name: str,
info: CreateRoomMeeting,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name)
room = await rooms_controller.get_by_name(session, room_name)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
@@ -303,7 +311,7 @@ async def rooms_create_meeting(
meeting = None
if not info.allow_duplicated:
meeting = await meetings_controller.get_active(
room=room, current_time=current_time
session, room=room, current_time=current_time
)
if meeting is None:
@@ -314,6 +322,7 @@ async def rooms_create_meeting(
await upload_logo(whereby_meeting["roomName"], "./images/logo.png")
meeting = await meetings_controller.create(
session,
id=whereby_meeting["meetingId"],
room_name=whereby_meeting["roomName"],
room_url=whereby_meeting["roomUrl"],
@@ -340,11 +349,12 @@ async def rooms_create_meeting(
async def rooms_test_webhook(
room_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
"""Test webhook configuration by sending a sample payload."""
user_id = user["sub"] if user else None
room = await rooms_controller.get_by_id(room_id)
room = await rooms_controller.get_by_id(session, room_id)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
@@ -361,9 +371,10 @@ async def rooms_test_webhook(
async def rooms_sync_ics(
room_name: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name)
room = await rooms_controller.get_by_name(session, room_name)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
@@ -390,9 +401,10 @@ async def rooms_sync_ics(
async def rooms_ics_status(
room_name: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name)
room = await rooms_controller.get_by_name(session, room_name)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
@@ -407,7 +419,7 @@ async def rooms_ics_status(
next_sync = room.ics_last_sync + timedelta(seconds=room.ics_fetch_interval)
events = await calendar_events_controller.get_by_room(
room.id, include_deleted=False
session, room.id, include_deleted=False
)
return ICSStatus(
@@ -423,15 +435,16 @@ async def rooms_ics_status(
async def rooms_list_meetings(
room_name: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name)
room = await rooms_controller.get_by_name(session, room_name)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
events = await calendar_events_controller.get_by_room(
room.id, include_deleted=False
session, room.id, include_deleted=False
)
if user_id != room.user_id:
@@ -449,15 +462,16 @@ async def rooms_list_upcoming_meetings(
room_name: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
minutes_ahead: int = 120,
session: AsyncSession = Depends(get_session),
):
user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name)
room = await rooms_controller.get_by_name(session, room_name)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
events = await calendar_events_controller.get_upcoming(
room.id, minutes_ahead=minutes_ahead
session, room.id, minutes_ahead=minutes_ahead
)
if user_id != room.user_id:
@@ -472,16 +486,17 @@ async def rooms_list_upcoming_meetings(
async def rooms_list_active_meetings(
room_name: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name)
room = await rooms_controller.get_by_name(session, room_name)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
current_time = datetime.now(timezone.utc)
meetings = await meetings_controller.get_all_active_for_room(
room=room, current_time=current_time
session, room=room, current_time=current_time
)
# Hide host URLs from non-owners
@@ -497,15 +512,16 @@ async def rooms_get_meeting(
room_name: str,
meeting_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
"""Get a single meeting by ID within a specific room."""
user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name)
room = await rooms_controller.get_by_name(session, room_name)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
meeting = await meetings_controller.get_by_id(meeting_id)
meeting = await meetings_controller.get_by_id(session, meeting_id)
if not meeting:
raise HTTPException(status_code=404, detail="Meeting not found")
@@ -525,14 +541,15 @@ async def rooms_join_meeting(
room_name: str,
meeting_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name)
room = await rooms_controller.get_by_name(session, room_name)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
meeting = await meetings_controller.get_by_id(meeting_id)
meeting = await meetings_controller.get_by_id(session, meeting_id)
if not meeting:
raise HTTPException(status_code=404, detail="Meeting not found")

View File

@@ -9,8 +9,10 @@ from typing import Annotated, Optional
import httpx
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from jose import jwt
from sqlalchemy.ext.asyncio import AsyncSession
import reflector.auth as auth
from reflector.db import get_session
from reflector.db.transcripts import AudioWaveform, transcripts_controller
from reflector.settings import settings
from reflector.views.transcripts import ALGORITHM
@@ -48,7 +50,7 @@ async def transcript_get_audio_mp3(
raise unauthorized_exception
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
session, transcript_id, user_id=user_id
)
if transcript.audio_location == "storage":
@@ -96,10 +98,11 @@ async def transcript_get_audio_mp3(
async def transcript_get_audio_waveform(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
) -> AudioWaveform:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
session, transcript_id, user_id=user_id
)
if not transcript.audio_waveform_filename.exists():

View File

@@ -8,8 +8,10 @@ from typing import Annotated, Optional
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy.ext.asyncio import AsyncSession
import reflector.auth as auth
from reflector.db import get_session
from reflector.db.transcripts import TranscriptParticipant, transcripts_controller
from reflector.views.types import DeletionStatus
@@ -37,10 +39,11 @@ class UpdateParticipant(BaseModel):
async def transcript_get_participants(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
) -> list[Participant]:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
session, transcript_id, user_id=user_id
)
if transcript.participants is None:
@@ -57,10 +60,11 @@ async def transcript_add_participant(
transcript_id: str,
participant: CreateParticipant,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
) -> Participant:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
session, transcript_id, user_id=user_id
)
# ensure the speaker is unique
@@ -83,10 +87,11 @@ async def transcript_get_participant(
transcript_id: str,
participant_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
) -> Participant:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
session, transcript_id, user_id=user_id
)
for p in transcript.participants:
@@ -102,10 +107,11 @@ async def transcript_update_participant(
participant_id: str,
participant: UpdateParticipant,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
) -> Participant:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
session, transcript_id, user_id=user_id
)
# ensure the speaker is unique
@@ -139,10 +145,11 @@ async def transcript_delete_participant(
transcript_id: str,
participant_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
) -> DeletionStatus:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
session, transcript_id, user_id=user_id
)
await transcripts_controller.delete_participant(transcript, participant_id)
return DeletionStatus(status="ok")

View File

@@ -3,8 +3,10 @@ from typing import Annotated, Optional
import celery
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
import reflector.auth as auth
from reflector.db import get_session
from reflector.db.transcripts import transcripts_controller
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
@@ -19,10 +21,11 @@ class ProcessStatus(BaseModel):
async def transcript_process(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
session, transcript_id, user_id=user_id
)
if transcript.locked:

View File

@@ -8,8 +8,10 @@ from typing import Annotated, Optional
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
import reflector.auth as auth
from reflector.db import get_session
from reflector.db.transcripts import transcripts_controller
router = APIRouter()
@@ -36,10 +38,11 @@ async def transcript_assign_speaker(
transcript_id: str,
assignment: SpeakerAssignment,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
) -> SpeakerAssignmentStatus:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
session, transcript_id, user_id=user_id
)
if not transcript:
@@ -100,6 +103,7 @@ async def transcript_assign_speaker(
for topic in changed_topics:
transcript.upsert_topic(topic)
await transcripts_controller.update(
session,
transcript,
{
"topics": transcript.topics_dump(),
@@ -114,10 +118,11 @@ async def transcript_merge_speaker(
transcript_id: str,
merge: SpeakerMerge,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
) -> SpeakerAssignmentStatus:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
session, transcript_id, user_id=user_id
)
if not transcript:
@@ -163,6 +168,7 @@ async def transcript_merge_speaker(
for topic in changed_topics:
transcript.upsert_topic(topic)
await transcripts_controller.update(
session,
transcript,
{
"topics": transcript.topics_dump(),

View File

@@ -3,8 +3,10 @@ from typing import Annotated, Optional
import av
from fastapi import APIRouter, Depends, HTTPException, UploadFile
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
import reflector.auth as auth
from reflector.db import get_session
from reflector.db.transcripts import transcripts_controller
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
@@ -22,10 +24,11 @@ async def transcript_record_upload(
total_chunks: int,
chunk: UploadFile,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
session, transcript_id, user_id=user_id
)
if transcript.locked:
@@ -89,7 +92,7 @@ async def transcript_record_upload(
container.close()
# set the status to "uploaded"
await transcripts_controller.update(transcript, {"status": "uploaded"})
await transcripts_controller.update(session, transcript, {"status": "uploaded"})
# launch a background task to process the file
task_pipeline_file_process.delay(transcript_id=transcript_id)

View File

@@ -1,8 +1,10 @@
from typing import Annotated, Optional
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession
import reflector.auth as auth
from reflector.db import get_session
from reflector.db.transcripts import transcripts_controller
from .rtc_offer import RtcOffer, rtc_offer_base
@@ -16,10 +18,11 @@ async def transcript_record_webrtc(
params: RtcOffer,
request: Request,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
session, transcript_id, user_id=user_id
)
if transcript.locked:

View File

@@ -24,7 +24,7 @@ async def transcript_events_websocket(
# user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
# user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id)
transcript = await transcripts_controller.get_by_id(session, transcript_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")

View File

@@ -1,6 +1,4 @@
import os
from tempfile import NamedTemporaryFile
from unittest.mock import patch
import pytest
@@ -34,382 +32,283 @@ def docker_compose_file(pytestconfig):
@pytest.fixture(scope="session")
def postgres_service(docker_ip, docker_services):
"""Ensure that PostgreSQL service is up and responsive."""
port = docker_services.port_for("postgres_test", 5432)
def is_responsive():
try:
import psycopg2
conn = psycopg2.connect(
host=docker_ip,
port=port,
dbname="reflector_test",
user="test_user",
password="test_password",
)
conn.close()
return True
except Exception:
return False
docker_services.wait_until_responsive(timeout=30.0, pause=0.1, check=is_responsive)
# Return connection parameters
return {
"host": docker_ip,
"port": port,
"dbname": "reflector_test",
"user": "test_user",
"password": "test_password",
}
def docker_ip():
"""Get Docker IP address for test services"""
# For most Docker setups, localhost works
return "127.0.0.1"
@pytest.fixture(scope="function", autouse=True)
@pytest.mark.asyncio
# Only register docker_services dependent fixtures if docker plugin is available
try:
import pytest_docker # noqa: F401
@pytest.fixture(scope="session")
def postgres_service(docker_ip, docker_services):
"""Ensure that PostgreSQL service is up and responsive."""
port = docker_services.port_for("postgres_test", 5432)
def is_responsive():
try:
import psycopg2
conn = psycopg2.connect(
host=docker_ip,
port=port,
dbname="reflector_test",
user="test_user",
password="test_password",
)
conn.close()
return True
except Exception:
return False
docker_services.wait_until_responsive(
timeout=30.0, pause=0.1, check=is_responsive
)
# Return connection parameters
return {
"host": docker_ip,
"port": port,
"database": "reflector_test",
"user": "test_user",
"password": "test_password",
}
except ImportError:
# Docker plugin not available, provide a dummy fixture
@pytest.fixture(scope="session")
def postgres_service(docker_ip):
"""Dummy postgres service when docker plugin is not available"""
return {
"host": docker_ip,
"port": 15432, # Default test postgres port
"database": "reflector_test",
"user": "test_user",
"password": "test_password",
}
@pytest.fixture(scope="session", autouse=True)
async def setup_database(postgres_service):
from reflector.db import get_engine
from reflector.db.base import metadata
"""Setup database and run migrations"""
from sqlalchemy.ext.asyncio import create_async_engine
async_engine = get_engine()
from reflector.db import Base
async with async_engine.begin() as conn:
await conn.run_sync(metadata.drop_all)
await conn.run_sync(metadata.create_all)
# Build database URL from connection params
db_config = postgres_service
DATABASE_URL = (
f"postgresql+asyncpg://{db_config['user']}:{db_config['password']}"
f"@{db_config['host']}:{db_config['port']}/{db_config['database']}"
)
try:
yield
finally:
await async_engine.dispose()
# Override settings
from reflector.settings import settings
settings.DATABASE_URL = DATABASE_URL
# Create engine and tables
engine = create_async_engine(DATABASE_URL, echo=False)
async with engine.begin() as conn:
# Drop all tables first to ensure clean state
await conn.run_sync(Base.metadata.drop_all)
# Create all tables
await conn.run_sync(Base.metadata.create_all)
yield
# Cleanup
await engine.dispose()
@pytest.fixture
async def session():
async def session(setup_database):
"""Provide a transactional database session for tests"""
from reflector.db import get_session_factory
async with get_session_factory()() as session:
yield session
await session.rollback()
@pytest.fixture
def dummy_processors():
with (
patch(
"reflector.processors.transcript_topic_detector.TranscriptTopicDetectorProcessor.get_topic"
) as mock_topic,
patch(
"reflector.processors.transcript_final_title.TranscriptFinalTitleProcessor.get_title"
) as mock_title,
patch(
"reflector.processors.transcript_final_summary.TranscriptFinalSummaryProcessor.get_long_summary"
) as mock_long_summary,
patch(
"reflector.processors.transcript_final_summary.TranscriptFinalSummaryProcessor.get_short_summary"
) as mock_short_summary,
):
from reflector.processors.transcript_topic_detector import TopicResponse
mock_topic.return_value = TopicResponse(
title="LLM TITLE", summary="LLM SUMMARY"
)
mock_title.return_value = "LLM Title"
mock_long_summary.return_value = "LLM LONG SUMMARY"
mock_short_summary.return_value = "LLM SHORT SUMMARY"
yield (
mock_topic,
mock_title,
mock_long_summary,
mock_short_summary,
) # noqa
def fake_mp3_upload(tmp_path):
"""Create a temporary MP3 file for upload testing"""
mp3_file = tmp_path / "test.mp3"
# Create a minimal valid MP3 file (ID3v2 header + minimal frame)
mp3_data = b"ID3\x04\x00\x00\x00\x00\x00\x00" + b"\xff\xfb" + b"\x00" * 100
mp3_file.write_bytes(mp3_data)
return mp3_file
@pytest.fixture
async def whisper_transcript():
from reflector.processors.audio_transcript_whisper import (
AudioTranscriptWhisperProcessor,
)
with patch(
"reflector.processors.audio_transcript_auto"
".AudioTranscriptAutoProcessor.__new__"
) as mock_audio:
mock_audio.return_value = AudioTranscriptWhisperProcessor()
yield
@pytest.fixture
async def dummy_transcript():
from reflector.processors.audio_transcript import AudioTranscriptProcessor
from reflector.processors.types import AudioFile, Transcript, Word
class TestAudioTranscriptProcessor(AudioTranscriptProcessor):
_time_idx = 0
async def _transcript(self, data: AudioFile):
i = self._time_idx
self._time_idx += 2
return Transcript(
text="Hello world.",
words=[
Word(start=i, end=i + 1, text="Hello", speaker=0),
Word(start=i + 1, end=i + 2, text=" world.", speaker=0),
],
)
with patch(
"reflector.processors.audio_transcript_auto"
".AudioTranscriptAutoProcessor.__new__"
) as mock_audio:
mock_audio.return_value = TestAudioTranscriptProcessor()
yield
@pytest.fixture
async def dummy_diarization():
from reflector.processors.audio_diarization import AudioDiarizationProcessor
class TestAudioDiarizationProcessor(AudioDiarizationProcessor):
_time_idx = 0
async def _diarize(self, data):
i = self._time_idx
self._time_idx += 2
return [
{"start": i, "end": i + 1, "speaker": 0},
{"start": i + 1, "end": i + 2, "speaker": 1},
]
with patch(
"reflector.processors.audio_diarization_auto"
".AudioDiarizationAutoProcessor.__new__"
) as mock_audio:
mock_audio.return_value = TestAudioDiarizationProcessor()
yield
@pytest.fixture
async def dummy_file_transcript():
from reflector.processors.file_transcript import FileTranscriptProcessor
def dummy_transcript():
"""Mock transcript processor response"""
from reflector.processors.types import Transcript, Word
class TestFileTranscriptProcessor(FileTranscriptProcessor):
async def _transcript(self, data):
return Transcript(
text="Hello world. How are you today?",
words=[
Word(start=0.0, end=0.5, text="Hello", speaker=0),
Word(start=0.5, end=0.6, text=" ", speaker=0),
Word(start=0.6, end=1.0, text="world", speaker=0),
Word(start=1.0, end=1.1, text=".", speaker=0),
Word(start=1.1, end=1.2, text=" ", speaker=0),
Word(start=1.2, end=1.5, text="How", speaker=0),
Word(start=1.5, end=1.6, text=" ", speaker=0),
Word(start=1.6, end=1.8, text="are", speaker=0),
Word(start=1.8, end=1.9, text=" ", speaker=0),
Word(start=1.9, end=2.1, text="you", speaker=0),
Word(start=2.1, end=2.2, text=" ", speaker=0),
Word(start=2.2, end=2.5, text="today", speaker=0),
Word(start=2.5, end=2.6, text="?", speaker=0),
],
)
with patch(
"reflector.processors.file_transcript_auto.FileTranscriptAutoProcessor.__new__"
) as mock_auto:
mock_auto.return_value = TestFileTranscriptProcessor()
yield
@pytest.fixture
async def dummy_file_diarization():
from reflector.processors.file_diarization import (
FileDiarizationOutput,
FileDiarizationProcessor,
return Transcript(
text="Hello world this is a test",
words=[
Word(word="Hello", start=0.0, end=0.5, speaker=0),
Word(word="world", start=0.5, end=1.0, speaker=0),
Word(word="this", start=1.0, end=1.5, speaker=0),
Word(word="is", start=1.5, end=1.8, speaker=0),
Word(word="a", start=1.8, end=2.0, speaker=0),
Word(word="test", start=2.0, end=2.5, speaker=0),
],
)
from reflector.processors.types import DiarizationSegment
class TestFileDiarizationProcessor(FileDiarizationProcessor):
async def _diarize(self, data):
return FileDiarizationOutput(
diarization=[
DiarizationSegment(start=0.0, end=1.1, speaker=0),
DiarizationSegment(start=1.2, end=2.6, speaker=1),
]
)
with patch(
"reflector.processors.file_diarization_auto.FileDiarizationAutoProcessor.__new__"
) as mock_auto:
mock_auto.return_value = TestFileDiarizationProcessor()
yield
@pytest.fixture
async def dummy_transcript_translator():
from reflector.processors.transcript_translator import TranscriptTranslatorProcessor
class TestTranscriptTranslatorProcessor(TranscriptTranslatorProcessor):
async def _translate(self, text: str) -> str:
source_language = self.get_pref("audio:source_language", "en")
target_language = self.get_pref("audio:target_language", "en")
return f"{source_language}:{target_language}:{text}"
def mock_new(cls, *args, **kwargs):
return TestTranscriptTranslatorProcessor(*args, **kwargs)
with patch(
"reflector.processors.transcript_translator_auto"
".TranscriptTranslatorAutoProcessor.__new__",
mock_new,
):
yield
def dummy_transcript_translator():
"""Mock transcript translation"""
return "Hola mundo esto es una prueba"
@pytest.fixture
async def dummy_llm():
from reflector.llm import LLM
def dummy_diarization():
"""Mock diarization processor response"""
from reflector.processors.types import DiarizationOutput, DiarizationSegment
class TestLLM(LLM):
def __init__(self):
self.model_name = "DUMMY MODEL"
self.llm_tokenizer = "DUMMY TOKENIZER"
# LLM doesn't have get_instance anymore, mocking constructor instead
with patch("reflector.llm.LLM") as mock_llm:
mock_llm.return_value = TestLLM()
yield
return DiarizationOutput(
diarization=[
DiarizationSegment(speaker=0, start=0.0, end=1.0),
DiarizationSegment(speaker=1, start=1.0, end=2.5),
]
)
@pytest.fixture
async def dummy_storage():
from reflector.storage.base import Storage
def dummy_file_transcript():
"""Mock file transcript processor response"""
from reflector.processors.types import Transcript, Word
class DummyStorage(Storage):
async def _put_file(self, *args, **kwargs):
pass
async def _delete_file(self, *args, **kwargs):
pass
async def _get_file_url(self, *args, **kwargs):
return "http://fake_server/audio.mp3"
async def _get_file(self, *args, **kwargs):
from pathlib import Path
test_mp3 = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
return test_mp3.read_bytes()
dummy = DummyStorage()
with (
patch("reflector.storage.base.Storage.get_instance") as mock_storage,
patch("reflector.storage.get_transcripts_storage") as mock_get_transcripts,
patch(
"reflector.pipelines.main_file_pipeline.get_transcripts_storage"
) as mock_get_transcripts2,
):
mock_storage.return_value = dummy
mock_get_transcripts.return_value = dummy
mock_get_transcripts2.return_value = dummy
yield
@pytest.fixture(scope="session")
def celery_enable_logging():
return True
@pytest.fixture(scope="session")
def celery_config():
with NamedTemporaryFile() as f:
yield {
"broker_url": "memory://",
"result_backend": f"db+sqlite:///{f.name}",
}
@pytest.fixture(scope="session")
def celery_includes():
return [
"reflector.pipelines.main_live_pipeline",
"reflector.pipelines.main_file_pipeline",
]
return Transcript(
text="This is a complete file transcript with multiple speakers",
words=[
Word(word="This", start=0.0, end=0.5, speaker=0),
Word(word="is", start=0.5, end=0.8, speaker=0),
Word(word="a", start=0.8, end=1.0, speaker=0),
Word(word="complete", start=1.0, end=1.5, speaker=1),
Word(word="file", start=1.5, end=1.8, speaker=1),
Word(word="transcript", start=1.8, end=2.3, speaker=1),
Word(word="with", start=2.3, end=2.5, speaker=0),
Word(word="multiple", start=2.5, end=3.0, speaker=0),
Word(word="speakers", start=3.0, end=3.5, speaker=0),
],
)
@pytest.fixture
async def client():
from httpx import AsyncClient
def dummy_file_diarization():
"""Mock file diarization processor response"""
from reflector.processors.types import DiarizationOutput, DiarizationSegment
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
yield ac
@pytest.fixture(scope="session")
def fake_mp3_upload():
with patch(
"reflector.db.transcripts.TranscriptController.move_mp3_to_storage"
) as mock_move:
mock_move.return_value = True
yield
return DiarizationOutput(
diarization=[
DiarizationSegment(speaker=0, start=0.0, end=1.0),
DiarizationSegment(speaker=1, start=1.0, end=2.3),
DiarizationSegment(speaker=0, start=2.3, end=3.5),
]
)
@pytest.fixture
async def fake_transcript_with_topics(tmpdir, client):
import shutil
from pathlib import Path
def fake_transcript_with_topics():
"""Create a transcript with topics for testing"""
from reflector.db.transcripts import TranscriptTopic
from reflector.processors.types import Word
from reflector.settings import settings
from reflector.views.transcripts import transcripts_controller
settings.DATA_DIR = Path(tmpdir)
# create a transcript
response = await client.post("/transcripts", json={"name": "Test audio download"})
assert response.status_code == 200
tid = response.json()["id"]
transcript = await transcripts_controller.get_by_id(tid)
assert transcript is not None
await transcripts_controller.update(transcript, {"status": "ended"})
# manually copy a file at the expected location
audio_filename = transcript.audio_mp3_filename
path = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
audio_filename.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(path, audio_filename)
# create some topics
await transcripts_controller.upsert_topic(
transcript,
topics = [
TranscriptTopic(
title="Topic 1",
summary="Topic 1 summary",
timestamp=0,
transcript="Hello world",
id="topic1",
title="Introduction",
summary="Opening remarks and introductions",
timestamp=0.0,
duration=30.0,
words=[
Word(text="Hello", start=0, end=1, speaker=0),
Word(text="world", start=1, end=2, speaker=0),
Word(word="Hello", start=0.0, end=0.5, speaker=0),
Word(word="everyone", start=0.5, end=1.0, speaker=0),
],
),
)
await transcripts_controller.upsert_topic(
transcript,
TranscriptTopic(
title="Topic 2",
summary="Topic 2 summary",
timestamp=2,
transcript="Hello world",
id="topic2",
title="Main Discussion",
summary="Core topics and key points",
timestamp=30.0,
duration=60.0,
words=[
Word(text="Hello", start=2, end=3, speaker=0),
Word(text="world", start=3, end=4, speaker=0),
Word(word="Let's", start=30.0, end=30.3, speaker=1),
Word(word="discuss", start=30.3, end=30.8, speaker=1),
Word(word="the", start=30.8, end=31.0, speaker=1),
Word(word="agenda", start=31.0, end=31.5, speaker=1),
],
),
)
]
return topics
yield transcript
@pytest.fixture
def dummy_processors(
dummy_transcript,
dummy_transcript_translator,
dummy_diarization,
dummy_file_transcript,
dummy_file_diarization,
):
"""Mock all processor responses"""
return {
"transcript": dummy_transcript,
"translator": dummy_transcript_translator,
"diarization": dummy_diarization,
"file_transcript": dummy_file_transcript,
"file_diarization": dummy_file_diarization,
}
@pytest.fixture
def dummy_storage():
"""Mock storage backend"""
from unittest.mock import AsyncMock
storage = AsyncMock()
storage.get_file_url.return_value = "https://example.com/test-audio.mp3"
storage.put_file.return_value = None
storage.delete_file.return_value = None
return storage
@pytest.fixture
def dummy_llm():
"""Mock LLM responses"""
return {
"title": "Test Meeting Title",
"summary": "This is a test meeting summary with key discussion points.",
"short_summary": "Brief test summary.",
}
@pytest.fixture
def whisper_transcript():
"""Mock Whisper API response format"""
return {
"text": "Hello world this is a test",
"segments": [
{
"start": 0.0,
"end": 2.5,
"text": "Hello world this is a test",
"words": [
{"word": "Hello", "start": 0.0, "end": 0.5},
{"word": "world", "start": 0.5, "end": 1.0},
{"word": "this", "start": 1.0, "end": 1.5},
{"word": "is", "start": 1.5, "end": 1.8},
{"word": "a", "start": 1.8, "end": 2.0},
{"word": "test", "start": 2.0, "end": 2.5},
],
}
],
"language": "en",
}

View File

@@ -3,6 +3,7 @@ from unittest.mock import AsyncMock, patch
import pytest
from reflector.db import get_session_factory
from reflector.db.rooms import rooms_controller
from reflector.services.ics_sync import ICSSyncService
@@ -17,21 +18,22 @@ async def test_attendee_parsing_bug():
instead of properly parsed email addresses.
"""
# Create a test room
room = await rooms_controller.add(
session,
name="test-room",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
ics_url="http://test.com/test.ics",
ics_enabled=True,
)
async with get_session_factory()() as session:
room = await rooms_controller.add(
session,
name="test-room",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
ics_url="http://test.com/test.ics",
ics_enabled=True,
)
# Read the test ICS file that reproduces the bug and update it with current time
from datetime import datetime, timedelta, timezone
@@ -95,99 +97,14 @@ async def test_attendee_parsing_bug():
# This is where the bug manifests - check the attendees
attendees = event["attendees"]
# Print attendee info for debugging
print(f"Number of attendees found: {len(attendees)}")
# Debug output to see what's happening
print(f"Number of attendees: {len(attendees)}")
for i, attendee in enumerate(attendees):
print(
f"Attendee {i}: email='{attendee.get('email')}', name='{attendee.get('name')}'"
)
print(f"Attendee {i}: {attendee}")
# With the fix, we should now get properly parsed email addresses
# Check that no single characters are parsed as emails
single_char_emails = [
att for att in attendees if att.get("email") and len(att["email"]) == 1
]
# The bug would cause 29 attendees (length of "MAILIN01234567890@allo.coop")
# instead of 1 attendee
assert len(attendees) == 1, f"Expected 1 attendee, got {len(attendees)}"
if single_char_emails:
print(
f"BUG DETECTED: Found {len(single_char_emails)} single-character emails:"
)
for att in single_char_emails:
print(f" - '{att['email']}'")
# Should have attendees but not single-character emails
assert len(attendees) > 0
assert (
len(single_char_emails) == 0
), f"Found {len(single_char_emails)} single-character emails, parsing is still buggy"
# Check that all emails are valid (contain @ symbol)
valid_emails = [
att for att in attendees if att.get("email") and "@" in att["email"]
]
assert len(valid_emails) == len(
attendees
), "Some attendees don't have valid email addresses"
# We expect around 29 attendees (28 from the comma-separated list + 1 organizer)
assert (
len(attendees) >= 25
), f"Expected around 29 attendees, got {len(attendees)}"
@pytest.mark.asyncio
async def test_correct_attendee_parsing():
"""
Test what correct attendee parsing should look like.
"""
from datetime import datetime, timezone
from icalendar import Event
from reflector.services.ics_sync import ICSFetchService
service = ICSFetchService()
# Create a properly formatted event with multiple attendees
event = Event()
event.add("uid", "test-correct-attendees")
event.add("summary", "Test Meeting")
event.add("location", "http://test.com/test")
event.add("dtstart", datetime.now(timezone.utc))
event.add("dtend", datetime.now(timezone.utc))
# Add attendees the correct way (separate ATTENDEE lines)
event.add("attendee", "mailto:alice@example.com", parameters={"CN": "Alice"})
event.add("attendee", "mailto:bob@example.com", parameters={"CN": "Bob"})
event.add("attendee", "mailto:charlie@example.com", parameters={"CN": "Charlie"})
event.add(
"organizer", "mailto:organizer@example.com", parameters={"CN": "Organizer"}
)
# Parse the event
result = service._parse_event(event)
assert result is not None
attendees = result["attendees"]
# Should have 4 attendees (3 attendees + 1 organizer)
assert len(attendees) == 4
# Check that all emails are valid email addresses
emails = [att["email"] for att in attendees if att.get("email")]
expected_emails = [
"alice@example.com",
"bob@example.com",
"charlie@example.com",
"organizer@example.com",
]
for email in emails:
assert "@" in email, f"Invalid email format: {email}"
assert len(email) > 5, f"Email too short: {email}"
# Check that we have the expected emails
assert "alice@example.com" in emails
assert "bob@example.com" in emails
assert "charlie@example.com" in emails
assert "organizer@example.com" in emails
# Verify the single attendee has correct email
assert attendees[0]["email"] == "MAILIN01234567890@allo.coop"

View File

@@ -8,6 +8,7 @@ from reflector.db.transcripts import (
SourceKind,
TranscriptController,
TranscriptTopic,
transcripts_controller,
)
from reflector.processors.types import Word