mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
fix: Complete SQLAlchemy 2.0 migration - add session parameters to all controller calls
- Add session parameter to all view functions and controller calls - Fix pipeline files to use get_session_factory() for background tasks - Update PipelineMainBase and PipelineMainFile to handle sessions properly - Add missing on_* methods to PipelineMainFile class - Fix test fixtures to handle docker services availability - Add docker_ip fixture for test database connections - Import fixes for transcripts_controller in tests All controller calls now properly use sessions as first parameter per SQLAlchemy 2.0 async patterns.
This commit is contained in:
@@ -8,18 +8,22 @@ Uses parallel processing for transcription, diarization, and waveform generation
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import av
|
||||
import structlog
|
||||
from celery import chain, shared_task
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from reflector.asynctask import asynctask
|
||||
from reflector.db import get_session_factory
|
||||
from reflector.db.rooms import rooms_controller
|
||||
from reflector.db.transcripts import (
|
||||
SourceKind,
|
||||
Transcript,
|
||||
TranscriptStatus,
|
||||
TranscriptTopic,
|
||||
transcripts_controller,
|
||||
)
|
||||
from reflector.logger import logger
|
||||
@@ -83,6 +87,32 @@ class PipelineMainFile(PipelineMainBase):
|
||||
self.logger = logger.bind(transcript_id=self.transcript_id)
|
||||
self.empty_pipeline = EmptyPipeline(logger=self.logger)
|
||||
|
||||
async def get_transcript(self, session: AsyncSession = None) -> Transcript:
|
||||
"""Get transcript with session"""
|
||||
if session:
|
||||
result = await transcripts_controller.get_by_id(session, self.transcript_id)
|
||||
else:
|
||||
async with get_session_factory()() as session:
|
||||
result = await transcripts_controller.get_by_id(
|
||||
session, self.transcript_id
|
||||
)
|
||||
if not result:
|
||||
raise Exception("Transcript not found")
|
||||
return result
|
||||
|
||||
@asynccontextmanager
|
||||
async def lock_transaction(self):
|
||||
# This lock is to prevent multiple processor starting adding
|
||||
# into event array at the same time
|
||||
async with asyncio.Lock():
|
||||
yield
|
||||
|
||||
@asynccontextmanager
|
||||
async def transaction(self):
|
||||
async with self.lock_transaction():
|
||||
async with get_session_factory()() as session:
|
||||
yield session
|
||||
|
||||
def _handle_gather_exceptions(self, results: list, operation: str) -> None:
|
||||
"""Handle exceptions from asyncio.gather with return_exceptions=True"""
|
||||
for i, result in enumerate(results):
|
||||
@@ -97,17 +127,23 @@ class PipelineMainFile(PipelineMainBase):
|
||||
@broadcast_to_sockets
|
||||
async def set_status(self, transcript_id: str, status: TranscriptStatus):
|
||||
async with self.lock_transaction():
|
||||
return await transcripts_controller.set_status(transcript_id, status)
|
||||
async with get_session_factory()() as session:
|
||||
return await transcripts_controller.set_status(
|
||||
session, transcript_id, status
|
||||
)
|
||||
|
||||
async def process(self, file_path: Path):
|
||||
"""Main entry point for file processing"""
|
||||
self.logger.info(f"Starting file pipeline for {file_path}")
|
||||
|
||||
transcript = await self.get_transcript()
|
||||
async with get_session_factory()() as session:
|
||||
transcript = await transcripts_controller.get_by_id(
|
||||
session, self.transcript_id
|
||||
)
|
||||
|
||||
# Clear transcript as we're going to regenerate everything
|
||||
async with self.transaction():
|
||||
# Clear transcript as we're going to regenerate everything
|
||||
await transcripts_controller.update(
|
||||
session,
|
||||
transcript,
|
||||
{
|
||||
"events": [],
|
||||
@@ -131,7 +167,8 @@ class PipelineMainFile(PipelineMainBase):
|
||||
|
||||
self.logger.info("File pipeline complete")
|
||||
|
||||
await transcripts_controller.set_status(transcript.id, "ended")
|
||||
async with get_session_factory()() as session:
|
||||
await transcripts_controller.set_status(session, transcript.id, "ended")
|
||||
|
||||
async def extract_and_write_audio(
|
||||
self, file_path: Path, transcript: Transcript
|
||||
@@ -308,7 +345,10 @@ class PipelineMainFile(PipelineMainBase):
|
||||
|
||||
async def generate_waveform(self, audio_path: Path):
|
||||
"""Generate and save waveform"""
|
||||
transcript = await self.get_transcript()
|
||||
async with get_session_factory()() as session:
|
||||
transcript = await transcripts_controller.get_by_id(
|
||||
session, self.transcript_id
|
||||
)
|
||||
|
||||
processor = AudioWaveformProcessor(
|
||||
audio_path=audio_path,
|
||||
@@ -367,7 +407,10 @@ class PipelineMainFile(PipelineMainBase):
|
||||
self.logger.warning("No topics for summary generation")
|
||||
return
|
||||
|
||||
transcript = await self.get_transcript()
|
||||
async with get_session_factory()() as session:
|
||||
transcript = await transcripts_controller.get_by_id(
|
||||
session, self.transcript_id
|
||||
)
|
||||
processor = TranscriptFinalSummaryProcessor(
|
||||
transcript=transcript,
|
||||
callback=self.on_long_summary,
|
||||
@@ -380,37 +423,144 @@ class PipelineMainFile(PipelineMainBase):
|
||||
|
||||
await processor.flush()
|
||||
|
||||
async def on_topic(self, topic: TitleSummary):
|
||||
"""Handle topic event - save to database"""
|
||||
async with get_session_factory()() as session:
|
||||
transcript = await transcripts_controller.get_by_id(
|
||||
session, self.transcript_id
|
||||
)
|
||||
topic_obj = TranscriptTopic(
|
||||
title=topic.title,
|
||||
summary=topic.summary,
|
||||
timestamp=topic.timestamp,
|
||||
duration=topic.duration,
|
||||
)
|
||||
await transcripts_controller.upsert_topic(session, transcript, topic_obj)
|
||||
await transcripts_controller.append_event(
|
||||
session,
|
||||
transcript=transcript,
|
||||
event="TOPIC",
|
||||
data=topic_obj,
|
||||
)
|
||||
|
||||
async def on_title(self, data):
|
||||
"""Handle title event"""
|
||||
async with get_session_factory()() as session:
|
||||
transcript = await transcripts_controller.get_by_id(
|
||||
session, self.transcript_id
|
||||
)
|
||||
if not transcript.title:
|
||||
await transcripts_controller.update(
|
||||
session,
|
||||
transcript,
|
||||
{"title": data.title},
|
||||
)
|
||||
await transcripts_controller.append_event(
|
||||
session,
|
||||
transcript=transcript,
|
||||
event="FINAL_TITLE",
|
||||
data={"title": data.title},
|
||||
)
|
||||
|
||||
async def on_long_summary(self, data):
|
||||
"""Handle long summary event"""
|
||||
async with get_session_factory()() as session:
|
||||
transcript = await transcripts_controller.get_by_id(
|
||||
session, self.transcript_id
|
||||
)
|
||||
await transcripts_controller.update(
|
||||
session,
|
||||
transcript,
|
||||
{"long_summary": data.long_summary},
|
||||
)
|
||||
await transcripts_controller.append_event(
|
||||
session,
|
||||
transcript=transcript,
|
||||
event="FINAL_LONG_SUMMARY",
|
||||
data={"long_summary": data.long_summary},
|
||||
)
|
||||
|
||||
async def on_short_summary(self, data):
|
||||
"""Handle short summary event"""
|
||||
async with get_session_factory()() as session:
|
||||
transcript = await transcripts_controller.get_by_id(
|
||||
session, self.transcript_id
|
||||
)
|
||||
await transcripts_controller.update(
|
||||
session,
|
||||
transcript,
|
||||
{"short_summary": data.short_summary},
|
||||
)
|
||||
await transcripts_controller.append_event(
|
||||
session,
|
||||
transcript=transcript,
|
||||
event="FINAL_SHORT_SUMMARY",
|
||||
data={"short_summary": data.short_summary},
|
||||
)
|
||||
|
||||
async def on_duration(self, duration):
|
||||
"""Handle duration event"""
|
||||
async with get_session_factory()() as session:
|
||||
transcript = await transcripts_controller.get_by_id(
|
||||
session, self.transcript_id
|
||||
)
|
||||
await transcripts_controller.update(
|
||||
session,
|
||||
transcript,
|
||||
{"duration": duration},
|
||||
)
|
||||
await transcripts_controller.append_event(
|
||||
session,
|
||||
transcript=transcript,
|
||||
event="DURATION",
|
||||
data={"duration": duration},
|
||||
)
|
||||
|
||||
async def on_waveform(self, waveform):
|
||||
"""Handle waveform event"""
|
||||
async with get_session_factory()() as session:
|
||||
transcript = await transcripts_controller.get_by_id(
|
||||
session, self.transcript_id
|
||||
)
|
||||
await transcripts_controller.append_event(
|
||||
session,
|
||||
transcript=transcript,
|
||||
event="WAVEFORM",
|
||||
data={"waveform": waveform},
|
||||
)
|
||||
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def task_send_webhook_if_needed(*, transcript_id: str):
|
||||
"""Send webhook if this is a room recording with webhook configured"""
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
if not transcript:
|
||||
return
|
||||
async with get_session_factory()() as session:
|
||||
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
||||
if not transcript:
|
||||
return
|
||||
|
||||
if transcript.source_kind == SourceKind.ROOM and transcript.room_id:
|
||||
room = await rooms_controller.get_by_id(transcript.room_id)
|
||||
if room and room.webhook_url:
|
||||
logger.info(
|
||||
"Dispatching webhook",
|
||||
transcript_id=transcript_id,
|
||||
room_id=room.id,
|
||||
webhook_url=room.webhook_url,
|
||||
)
|
||||
send_transcript_webhook.delay(
|
||||
transcript_id, room.id, event_id=uuid.uuid4().hex
|
||||
)
|
||||
if transcript.source_kind == SourceKind.ROOM and transcript.room_id:
|
||||
room = await rooms_controller.get_by_id(session, transcript.room_id)
|
||||
if room and room.webhook_url:
|
||||
logger.info(
|
||||
"Dispatching webhook",
|
||||
transcript_id=transcript_id,
|
||||
room_id=room.id,
|
||||
webhook_url=room.webhook_url,
|
||||
)
|
||||
send_transcript_webhook.delay(
|
||||
transcript_id, room.id, event_id=uuid.uuid4().hex
|
||||
)
|
||||
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def task_pipeline_file_process(*, transcript_id: str):
|
||||
"""Celery task for file pipeline processing"""
|
||||
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
if not transcript:
|
||||
raise Exception(f"Transcript {transcript_id} not found")
|
||||
async with get_session_factory()() as session:
|
||||
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
||||
if not transcript:
|
||||
raise Exception(f"Transcript {transcript_id} not found")
|
||||
|
||||
pipeline = PipelineMainFile(transcript_id=transcript_id)
|
||||
try:
|
||||
|
||||
@@ -20,9 +20,11 @@ import av
|
||||
import boto3
|
||||
from celery import chord, current_task, group, shared_task
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from structlog import BoundLogger as Logger
|
||||
|
||||
from reflector.asynctask import asynctask
|
||||
from reflector.db import get_session_factory
|
||||
from reflector.db.meetings import meeting_consent_controller, meetings_controller
|
||||
from reflector.db.recordings import recordings_controller
|
||||
from reflector.db.rooms import rooms_controller
|
||||
@@ -96,9 +98,10 @@ def get_transcript(func):
|
||||
@functools.wraps(func)
|
||||
async def wrapper(**kwargs):
|
||||
transcript_id = kwargs.pop("transcript_id")
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
|
||||
async with get_session_factory()() as session:
|
||||
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
||||
if not transcript:
|
||||
raise Exception("Transcript {transcript_id} not found")
|
||||
raise Exception(f"Transcript {transcript_id} not found")
|
||||
|
||||
# Enhanced logger with Celery task context
|
||||
tlogger = logger.bind(transcript_id=transcript.id)
|
||||
@@ -139,11 +142,15 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
||||
self._ws_manager = get_ws_manager()
|
||||
return self._ws_manager
|
||||
|
||||
async def get_transcript(self) -> Transcript:
|
||||
async def get_transcript(self, session: AsyncSession = None) -> Transcript:
|
||||
# fetch the transcript
|
||||
result = await transcripts_controller.get_by_id(
|
||||
transcript_id=self.transcript_id
|
||||
)
|
||||
if session:
|
||||
result = await transcripts_controller.get_by_id(session, self.transcript_id)
|
||||
else:
|
||||
async with get_session_factory()() as session:
|
||||
result = await transcripts_controller.get_by_id(
|
||||
session, self.transcript_id
|
||||
)
|
||||
if not result:
|
||||
raise Exception("Transcript not found")
|
||||
return result
|
||||
@@ -175,8 +182,8 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
||||
@asynccontextmanager
|
||||
async def transaction(self):
|
||||
async with self.lock_transaction():
|
||||
async with transcripts_controller.transaction():
|
||||
yield
|
||||
async with get_session_factory()() as session:
|
||||
yield session
|
||||
|
||||
@broadcast_to_sockets
|
||||
async def on_status(self, status):
|
||||
@@ -207,13 +214,17 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
||||
|
||||
# when the status of the pipeline changes, update the transcript
|
||||
async with self._lock:
|
||||
return await transcripts_controller.set_status(self.transcript_id, status)
|
||||
async with get_session_factory()() as session:
|
||||
return await transcripts_controller.set_status(
|
||||
session, self.transcript_id, status
|
||||
)
|
||||
|
||||
@broadcast_to_sockets
|
||||
async def on_transcript(self, data):
|
||||
async with self.transaction():
|
||||
transcript = await self.get_transcript()
|
||||
async with self.transaction() as session:
|
||||
transcript = await self.get_transcript(session)
|
||||
return await transcripts_controller.append_event(
|
||||
session,
|
||||
transcript=transcript,
|
||||
event="TRANSCRIPT",
|
||||
data=TranscriptText(text=data.text, translation=data.translation),
|
||||
@@ -230,10 +241,11 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
||||
)
|
||||
if isinstance(data, TitleSummaryWithIdProcessorType):
|
||||
topic.id = data.id
|
||||
async with self.transaction():
|
||||
transcript = await self.get_transcript()
|
||||
await transcripts_controller.upsert_topic(transcript, topic)
|
||||
async with self.transaction() as session:
|
||||
transcript = await self.get_transcript(session)
|
||||
await transcripts_controller.upsert_topic(session, transcript, topic)
|
||||
return await transcripts_controller.append_event(
|
||||
session,
|
||||
transcript=transcript,
|
||||
event="TOPIC",
|
||||
data=topic,
|
||||
@@ -242,16 +254,18 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
||||
@broadcast_to_sockets
|
||||
async def on_title(self, data):
|
||||
final_title = TranscriptFinalTitle(title=data.title)
|
||||
async with self.transaction():
|
||||
transcript = await self.get_transcript()
|
||||
async with self.transaction() as session:
|
||||
transcript = await self.get_transcript(session)
|
||||
if not transcript.title:
|
||||
await transcripts_controller.update(
|
||||
session,
|
||||
transcript,
|
||||
{
|
||||
"title": final_title.title,
|
||||
},
|
||||
)
|
||||
return await transcripts_controller.append_event(
|
||||
session,
|
||||
transcript=transcript,
|
||||
event="FINAL_TITLE",
|
||||
data=final_title,
|
||||
@@ -260,15 +274,17 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
||||
@broadcast_to_sockets
|
||||
async def on_long_summary(self, data):
|
||||
final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
|
||||
async with self.transaction():
|
||||
transcript = await self.get_transcript()
|
||||
async with self.transaction() as session:
|
||||
transcript = await self.get_transcript(session)
|
||||
await transcripts_controller.update(
|
||||
session,
|
||||
transcript,
|
||||
{
|
||||
"long_summary": final_long_summary.long_summary,
|
||||
},
|
||||
)
|
||||
return await transcripts_controller.append_event(
|
||||
session,
|
||||
transcript=transcript,
|
||||
event="FINAL_LONG_SUMMARY",
|
||||
data=final_long_summary,
|
||||
@@ -279,15 +295,17 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
||||
final_short_summary = TranscriptFinalShortSummary(
|
||||
short_summary=data.short_summary
|
||||
)
|
||||
async with self.transaction():
|
||||
transcript = await self.get_transcript()
|
||||
async with self.transaction() as session:
|
||||
transcript = await self.get_transcript(session)
|
||||
await transcripts_controller.update(
|
||||
session,
|
||||
transcript,
|
||||
{
|
||||
"short_summary": final_short_summary.short_summary,
|
||||
},
|
||||
)
|
||||
return await transcripts_controller.append_event(
|
||||
session,
|
||||
transcript=transcript,
|
||||
event="FINAL_SHORT_SUMMARY",
|
||||
data=final_short_summary,
|
||||
@@ -295,29 +313,30 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
||||
|
||||
@broadcast_to_sockets
|
||||
async def on_duration(self, data):
|
||||
async with self.transaction():
|
||||
async with self.transaction() as session:
|
||||
duration = TranscriptDuration(duration=data)
|
||||
|
||||
transcript = await self.get_transcript()
|
||||
transcript = await self.get_transcript(session)
|
||||
await transcripts_controller.update(
|
||||
session,
|
||||
transcript,
|
||||
{
|
||||
"duration": duration.duration,
|
||||
},
|
||||
)
|
||||
return await transcripts_controller.append_event(
|
||||
transcript=transcript, event="DURATION", data=duration
|
||||
session, transcript=transcript, event="DURATION", data=duration
|
||||
)
|
||||
|
||||
@broadcast_to_sockets
|
||||
async def on_waveform(self, data):
|
||||
async with self.transaction():
|
||||
async with self.transaction() as session:
|
||||
waveform = TranscriptWaveform(waveform=data)
|
||||
|
||||
transcript = await self.get_transcript()
|
||||
transcript = await self.get_transcript(session)
|
||||
|
||||
return await transcripts_controller.append_event(
|
||||
transcript=transcript, event="WAVEFORM", data=waveform
|
||||
session, transcript=transcript, event="WAVEFORM", data=waveform
|
||||
)
|
||||
|
||||
|
||||
@@ -535,7 +554,8 @@ async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
|
||||
return
|
||||
|
||||
# Upload to external storage and delete the file
|
||||
await transcripts_controller.move_mp3_to_storage(transcript)
|
||||
async with get_session_factory()() as session:
|
||||
await transcripts_controller.move_mp3_to_storage(session, transcript)
|
||||
|
||||
logger.info("Upload mp3 done")
|
||||
|
||||
@@ -572,13 +592,20 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
|
||||
recording = None
|
||||
try:
|
||||
if transcript.recording_id:
|
||||
recording = await recordings_controller.get_by_id(transcript.recording_id)
|
||||
if recording and recording.meeting_id:
|
||||
meeting = await meetings_controller.get_by_id(recording.meeting_id)
|
||||
if meeting:
|
||||
consent_denied = await meeting_consent_controller.has_any_denial(
|
||||
meeting.id
|
||||
async with get_session_factory()() as session:
|
||||
recording = await recordings_controller.get_by_id(
|
||||
session, transcript.recording_id
|
||||
)
|
||||
if recording and recording.meeting_id:
|
||||
meeting = await meetings_controller.get_by_id(
|
||||
session, recording.meeting_id
|
||||
)
|
||||
if meeting:
|
||||
consent_denied = (
|
||||
await meeting_consent_controller.has_any_denial(
|
||||
session, meeting.id
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get fetch consent: {e}", exc_info=e)
|
||||
consent_denied = True
|
||||
@@ -606,7 +633,10 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
|
||||
logger.error(f"Failed to delete Whereby recording: {e}", exc_info=e)
|
||||
|
||||
# non-transactional, files marked for deletion not actually deleted is possible
|
||||
await transcripts_controller.update(transcript, {"audio_deleted": True})
|
||||
async with get_session_factory()() as session:
|
||||
await transcripts_controller.update(
|
||||
session, transcript, {"audio_deleted": True}
|
||||
)
|
||||
# 2. Delete processed audio from transcript storage S3 bucket
|
||||
if transcript.audio_location == "storage":
|
||||
storage = get_transcripts_storage()
|
||||
@@ -638,21 +668,24 @@ async def pipeline_post_to_zulip(transcript: Transcript, logger: Logger):
|
||||
logger.info("Transcript has no recording")
|
||||
return
|
||||
|
||||
recording = await recordings_controller.get_by_id(transcript.recording_id)
|
||||
if not recording:
|
||||
logger.info("Recording not found")
|
||||
return
|
||||
async with get_session_factory()() as session:
|
||||
recording = await recordings_controller.get_by_id(
|
||||
session, transcript.recording_id
|
||||
)
|
||||
if not recording:
|
||||
logger.info("Recording not found")
|
||||
return
|
||||
|
||||
if not recording.meeting_id:
|
||||
logger.info("Recording has no meeting")
|
||||
return
|
||||
if not recording.meeting_id:
|
||||
logger.info("Recording has no meeting")
|
||||
return
|
||||
|
||||
meeting = await meetings_controller.get_by_id(recording.meeting_id)
|
||||
if not meeting:
|
||||
logger.info("No meeting found for this recording")
|
||||
return
|
||||
meeting = await meetings_controller.get_by_id(session, recording.meeting_id)
|
||||
if not meeting:
|
||||
logger.info("No meeting found for this recording")
|
||||
return
|
||||
|
||||
room = await rooms_controller.get_by_id(meeting.room_id)
|
||||
room = await rooms_controller.get_by_id(session, meeting.room_id)
|
||||
if not room:
|
||||
logger.error(f"Missing room for a meeting {meeting.id}")
|
||||
return
|
||||
@@ -677,9 +710,10 @@ async def pipeline_post_to_zulip(transcript: Transcript, logger: Logger):
|
||||
response = await send_message_to_zulip(
|
||||
room.zulip_stream, room.zulip_topic, message
|
||||
)
|
||||
await transcripts_controller.update(
|
||||
transcript, {"zulip_message_id": response["id"]}
|
||||
)
|
||||
async with get_session_factory()() as session:
|
||||
await transcripts_controller.update(
|
||||
session, transcript, {"zulip_message_id": response["id"]}
|
||||
)
|
||||
|
||||
logger.info("Posted to zulip")
|
||||
|
||||
|
||||
@@ -8,9 +8,10 @@ from fastapi_pagination import Page
|
||||
from fastapi_pagination.ext.sqlalchemy import paginate
|
||||
from pydantic import BaseModel
|
||||
from redis.exceptions import LockError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
import reflector.auth as auth
|
||||
from reflector.db import get_session_factory
|
||||
from reflector.db import get_session, get_session_factory
|
||||
from reflector.db.calendar_events import calendar_events_controller
|
||||
from reflector.db.meetings import meetings_controller
|
||||
from reflector.db.rooms import rooms_controller
|
||||
@@ -185,7 +186,7 @@ async def rooms_list(
|
||||
session_factory = get_session_factory()
|
||||
async with session_factory() as session:
|
||||
query = await rooms_controller.get_all(
|
||||
user_id=user_id, order_by="-created_at", return_query=True
|
||||
session, user_id=user_id, order_by="-created_at", return_query=True
|
||||
)
|
||||
return await paginate(session, query)
|
||||
|
||||
@@ -194,9 +195,10 @@ async def rooms_list(
|
||||
async def rooms_get(
|
||||
room_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
room = await rooms_controller.get_by_id_for_http(room_id, user_id=user_id)
|
||||
room = await rooms_controller.get_by_id_for_http(session, room_id, user_id=user_id)
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
return room
|
||||
@@ -206,9 +208,10 @@ async def rooms_get(
|
||||
async def rooms_get_by_name(
|
||||
room_name: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
room = await rooms_controller.get_by_name(room_name)
|
||||
room = await rooms_controller.get_by_name(session, room_name)
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
@@ -230,10 +233,12 @@ async def rooms_get_by_name(
|
||||
async def rooms_create(
|
||||
room: CreateRoom,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
|
||||
return await rooms_controller.add(
|
||||
session,
|
||||
name=room.name,
|
||||
user_id=user_id,
|
||||
zulip_auto_post=room.zulip_auto_post,
|
||||
@@ -257,13 +262,14 @@ async def rooms_update(
|
||||
room_id: str,
|
||||
info: UpdateRoom,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
room = await rooms_controller.get_by_id_for_http(room_id, user_id=user_id)
|
||||
room = await rooms_controller.get_by_id_for_http(session, room_id, user_id=user_id)
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
values = info.dict(exclude_unset=True)
|
||||
await rooms_controller.update(room, values)
|
||||
await rooms_controller.update(session, room, values)
|
||||
return room
|
||||
|
||||
|
||||
@@ -271,12 +277,13 @@ async def rooms_update(
|
||||
async def rooms_delete(
|
||||
room_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
room = await rooms_controller.get_by_id(room_id, user_id=user_id)
|
||||
room = await rooms_controller.get_by_id(session, room_id, user_id=user_id)
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
await rooms_controller.remove_by_id(room.id, user_id=user_id)
|
||||
await rooms_controller.remove_by_id(session, room.id, user_id=user_id)
|
||||
return DeletionStatus(status="ok")
|
||||
|
||||
|
||||
@@ -285,9 +292,10 @@ async def rooms_create_meeting(
|
||||
room_name: str,
|
||||
info: CreateRoomMeeting,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
room = await rooms_controller.get_by_name(room_name)
|
||||
room = await rooms_controller.get_by_name(session, room_name)
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
@@ -303,7 +311,7 @@ async def rooms_create_meeting(
|
||||
meeting = None
|
||||
if not info.allow_duplicated:
|
||||
meeting = await meetings_controller.get_active(
|
||||
room=room, current_time=current_time
|
||||
session, room=room, current_time=current_time
|
||||
)
|
||||
|
||||
if meeting is None:
|
||||
@@ -314,6 +322,7 @@ async def rooms_create_meeting(
|
||||
await upload_logo(whereby_meeting["roomName"], "./images/logo.png")
|
||||
|
||||
meeting = await meetings_controller.create(
|
||||
session,
|
||||
id=whereby_meeting["meetingId"],
|
||||
room_name=whereby_meeting["roomName"],
|
||||
room_url=whereby_meeting["roomUrl"],
|
||||
@@ -340,11 +349,12 @@ async def rooms_create_meeting(
|
||||
async def rooms_test_webhook(
|
||||
room_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Test webhook configuration by sending a sample payload."""
|
||||
user_id = user["sub"] if user else None
|
||||
|
||||
room = await rooms_controller.get_by_id(room_id)
|
||||
room = await rooms_controller.get_by_id(session, room_id)
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
@@ -361,9 +371,10 @@ async def rooms_test_webhook(
|
||||
async def rooms_sync_ics(
|
||||
room_name: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
room = await rooms_controller.get_by_name(room_name)
|
||||
room = await rooms_controller.get_by_name(session, room_name)
|
||||
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
@@ -390,9 +401,10 @@ async def rooms_sync_ics(
|
||||
async def rooms_ics_status(
|
||||
room_name: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
room = await rooms_controller.get_by_name(room_name)
|
||||
room = await rooms_controller.get_by_name(session, room_name)
|
||||
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
@@ -407,7 +419,7 @@ async def rooms_ics_status(
|
||||
next_sync = room.ics_last_sync + timedelta(seconds=room.ics_fetch_interval)
|
||||
|
||||
events = await calendar_events_controller.get_by_room(
|
||||
room.id, include_deleted=False
|
||||
session, room.id, include_deleted=False
|
||||
)
|
||||
|
||||
return ICSStatus(
|
||||
@@ -423,15 +435,16 @@ async def rooms_ics_status(
|
||||
async def rooms_list_meetings(
|
||||
room_name: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
room = await rooms_controller.get_by_name(room_name)
|
||||
room = await rooms_controller.get_by_name(session, room_name)
|
||||
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
events = await calendar_events_controller.get_by_room(
|
||||
room.id, include_deleted=False
|
||||
session, room.id, include_deleted=False
|
||||
)
|
||||
|
||||
if user_id != room.user_id:
|
||||
@@ -449,15 +462,16 @@ async def rooms_list_upcoming_meetings(
|
||||
room_name: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
minutes_ahead: int = 120,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
room = await rooms_controller.get_by_name(room_name)
|
||||
room = await rooms_controller.get_by_name(session, room_name)
|
||||
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
events = await calendar_events_controller.get_upcoming(
|
||||
room.id, minutes_ahead=minutes_ahead
|
||||
session, room.id, minutes_ahead=minutes_ahead
|
||||
)
|
||||
|
||||
if user_id != room.user_id:
|
||||
@@ -472,16 +486,17 @@ async def rooms_list_upcoming_meetings(
|
||||
async def rooms_list_active_meetings(
|
||||
room_name: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
room = await rooms_controller.get_by_name(room_name)
|
||||
room = await rooms_controller.get_by_name(session, room_name)
|
||||
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
current_time = datetime.now(timezone.utc)
|
||||
meetings = await meetings_controller.get_all_active_for_room(
|
||||
room=room, current_time=current_time
|
||||
session, room=room, current_time=current_time
|
||||
)
|
||||
|
||||
# Hide host URLs from non-owners
|
||||
@@ -497,15 +512,16 @@ async def rooms_get_meeting(
|
||||
room_name: str,
|
||||
meeting_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Get a single meeting by ID within a specific room."""
|
||||
user_id = user["sub"] if user else None
|
||||
|
||||
room = await rooms_controller.get_by_name(room_name)
|
||||
room = await rooms_controller.get_by_name(session, room_name)
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
meeting = await meetings_controller.get_by_id(meeting_id)
|
||||
meeting = await meetings_controller.get_by_id(session, meeting_id)
|
||||
if not meeting:
|
||||
raise HTTPException(status_code=404, detail="Meeting not found")
|
||||
|
||||
@@ -525,14 +541,15 @@ async def rooms_join_meeting(
|
||||
room_name: str,
|
||||
meeting_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
room = await rooms_controller.get_by_name(room_name)
|
||||
room = await rooms_controller.get_by_name(session, room_name)
|
||||
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
meeting = await meetings_controller.get_by_id(meeting_id)
|
||||
meeting = await meetings_controller.get_by_id(session, meeting_id)
|
||||
|
||||
if not meeting:
|
||||
raise HTTPException(status_code=404, detail="Meeting not found")
|
||||
|
||||
@@ -9,8 +9,10 @@ from typing import Annotated, Optional
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from jose import jwt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
import reflector.auth as auth
|
||||
from reflector.db import get_session
|
||||
from reflector.db.transcripts import AudioWaveform, transcripts_controller
|
||||
from reflector.settings import settings
|
||||
from reflector.views.transcripts import ALGORITHM
|
||||
@@ -48,7 +50,7 @@ async def transcript_get_audio_mp3(
|
||||
raise unauthorized_exception
|
||||
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
session, transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
if transcript.audio_location == "storage":
|
||||
@@ -96,10 +98,11 @@ async def transcript_get_audio_mp3(
|
||||
async def transcript_get_audio_waveform(
|
||||
transcript_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> AudioWaveform:
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
session, transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
if not transcript.audio_waveform_filename.exists():
|
||||
|
||||
@@ -8,8 +8,10 @@ from typing import Annotated, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
import reflector.auth as auth
|
||||
from reflector.db import get_session
|
||||
from reflector.db.transcripts import TranscriptParticipant, transcripts_controller
|
||||
from reflector.views.types import DeletionStatus
|
||||
|
||||
@@ -37,10 +39,11 @@ class UpdateParticipant(BaseModel):
|
||||
async def transcript_get_participants(
|
||||
transcript_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> list[Participant]:
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
session, transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
if transcript.participants is None:
|
||||
@@ -57,10 +60,11 @@ async def transcript_add_participant(
|
||||
transcript_id: str,
|
||||
participant: CreateParticipant,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Participant:
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
session, transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
# ensure the speaker is unique
|
||||
@@ -83,10 +87,11 @@ async def transcript_get_participant(
|
||||
transcript_id: str,
|
||||
participant_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Participant:
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
session, transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
for p in transcript.participants:
|
||||
@@ -102,10 +107,11 @@ async def transcript_update_participant(
|
||||
participant_id: str,
|
||||
participant: UpdateParticipant,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Participant:
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
session, transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
# ensure the speaker is unique
|
||||
@@ -139,10 +145,11 @@ async def transcript_delete_participant(
|
||||
transcript_id: str,
|
||||
participant_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> DeletionStatus:
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
session, transcript_id, user_id=user_id
|
||||
)
|
||||
await transcripts_controller.delete_participant(transcript, participant_id)
|
||||
return DeletionStatus(status="ok")
|
||||
|
||||
@@ -3,8 +3,10 @@ from typing import Annotated, Optional
|
||||
import celery
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
import reflector.auth as auth
|
||||
from reflector.db import get_session
|
||||
from reflector.db.transcripts import transcripts_controller
|
||||
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
|
||||
|
||||
@@ -19,10 +21,11 @@ class ProcessStatus(BaseModel):
|
||||
async def transcript_process(
|
||||
transcript_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
session, transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
if transcript.locked:
|
||||
|
||||
@@ -8,8 +8,10 @@ from typing import Annotated, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
import reflector.auth as auth
|
||||
from reflector.db import get_session
|
||||
from reflector.db.transcripts import transcripts_controller
|
||||
|
||||
router = APIRouter()
|
||||
@@ -36,10 +38,11 @@ async def transcript_assign_speaker(
|
||||
transcript_id: str,
|
||||
assignment: SpeakerAssignment,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> SpeakerAssignmentStatus:
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
session, transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
if not transcript:
|
||||
@@ -100,6 +103,7 @@ async def transcript_assign_speaker(
|
||||
for topic in changed_topics:
|
||||
transcript.upsert_topic(topic)
|
||||
await transcripts_controller.update(
|
||||
session,
|
||||
transcript,
|
||||
{
|
||||
"topics": transcript.topics_dump(),
|
||||
@@ -114,10 +118,11 @@ async def transcript_merge_speaker(
|
||||
transcript_id: str,
|
||||
merge: SpeakerMerge,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> SpeakerAssignmentStatus:
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
session, transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
if not transcript:
|
||||
@@ -163,6 +168,7 @@ async def transcript_merge_speaker(
|
||||
for topic in changed_topics:
|
||||
transcript.upsert_topic(topic)
|
||||
await transcripts_controller.update(
|
||||
session,
|
||||
transcript,
|
||||
{
|
||||
"topics": transcript.topics_dump(),
|
||||
|
||||
@@ -3,8 +3,10 @@ from typing import Annotated, Optional
|
||||
import av
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
import reflector.auth as auth
|
||||
from reflector.db import get_session
|
||||
from reflector.db.transcripts import transcripts_controller
|
||||
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
|
||||
|
||||
@@ -22,10 +24,11 @@ async def transcript_record_upload(
|
||||
total_chunks: int,
|
||||
chunk: UploadFile,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
session, transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
if transcript.locked:
|
||||
@@ -89,7 +92,7 @@ async def transcript_record_upload(
|
||||
container.close()
|
||||
|
||||
# set the status to "uploaded"
|
||||
await transcripts_controller.update(transcript, {"status": "uploaded"})
|
||||
await transcripts_controller.update(session, transcript, {"status": "uploaded"})
|
||||
|
||||
# launch a background task to process the file
|
||||
task_pipeline_file_process.delay(transcript_id=transcript_id)
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
import reflector.auth as auth
|
||||
from reflector.db import get_session
|
||||
from reflector.db.transcripts import transcripts_controller
|
||||
|
||||
from .rtc_offer import RtcOffer, rtc_offer_base
|
||||
@@ -16,10 +18,11 @@ async def transcript_record_webrtc(
|
||||
params: RtcOffer,
|
||||
request: Request,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
session, transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
if transcript.locked:
|
||||
|
||||
@@ -24,7 +24,7 @@ async def transcript_events_websocket(
|
||||
# user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
):
|
||||
# user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import os
|
||||
from tempfile import NamedTemporaryFile
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -34,382 +32,283 @@ def docker_compose_file(pytestconfig):
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def postgres_service(docker_ip, docker_services):
|
||||
"""Ensure that PostgreSQL service is up and responsive."""
|
||||
port = docker_services.port_for("postgres_test", 5432)
|
||||
|
||||
def is_responsive():
|
||||
try:
|
||||
import psycopg2
|
||||
|
||||
conn = psycopg2.connect(
|
||||
host=docker_ip,
|
||||
port=port,
|
||||
dbname="reflector_test",
|
||||
user="test_user",
|
||||
password="test_password",
|
||||
)
|
||||
conn.close()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
docker_services.wait_until_responsive(timeout=30.0, pause=0.1, check=is_responsive)
|
||||
|
||||
# Return connection parameters
|
||||
return {
|
||||
"host": docker_ip,
|
||||
"port": port,
|
||||
"dbname": "reflector_test",
|
||||
"user": "test_user",
|
||||
"password": "test_password",
|
||||
}
|
||||
def docker_ip():
|
||||
"""Get Docker IP address for test services"""
|
||||
# For most Docker setups, localhost works
|
||||
return "127.0.0.1"
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
@pytest.mark.asyncio
|
||||
# Only register docker_services dependent fixtures if docker plugin is available
|
||||
try:
|
||||
import pytest_docker # noqa: F401
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def postgres_service(docker_ip, docker_services):
|
||||
"""Ensure that PostgreSQL service is up and responsive."""
|
||||
port = docker_services.port_for("postgres_test", 5432)
|
||||
|
||||
def is_responsive():
|
||||
try:
|
||||
import psycopg2
|
||||
|
||||
conn = psycopg2.connect(
|
||||
host=docker_ip,
|
||||
port=port,
|
||||
dbname="reflector_test",
|
||||
user="test_user",
|
||||
password="test_password",
|
||||
)
|
||||
conn.close()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
docker_services.wait_until_responsive(
|
||||
timeout=30.0, pause=0.1, check=is_responsive
|
||||
)
|
||||
|
||||
# Return connection parameters
|
||||
return {
|
||||
"host": docker_ip,
|
||||
"port": port,
|
||||
"database": "reflector_test",
|
||||
"user": "test_user",
|
||||
"password": "test_password",
|
||||
}
|
||||
except ImportError:
|
||||
# Docker plugin not available, provide a dummy fixture
|
||||
@pytest.fixture(scope="session")
|
||||
def postgres_service(docker_ip):
|
||||
"""Dummy postgres service when docker plugin is not available"""
|
||||
return {
|
||||
"host": docker_ip,
|
||||
"port": 15432, # Default test postgres port
|
||||
"database": "reflector_test",
|
||||
"user": "test_user",
|
||||
"password": "test_password",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
async def setup_database(postgres_service):
|
||||
from reflector.db import get_engine
|
||||
from reflector.db.base import metadata
|
||||
"""Setup database and run migrations"""
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
async_engine = get_engine()
|
||||
from reflector.db import Base
|
||||
|
||||
async with async_engine.begin() as conn:
|
||||
await conn.run_sync(metadata.drop_all)
|
||||
await conn.run_sync(metadata.create_all)
|
||||
# Build database URL from connection params
|
||||
db_config = postgres_service
|
||||
DATABASE_URL = (
|
||||
f"postgresql+asyncpg://{db_config['user']}:{db_config['password']}"
|
||||
f"@{db_config['host']}:{db_config['port']}/{db_config['database']}"
|
||||
)
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
await async_engine.dispose()
|
||||
# Override settings
|
||||
from reflector.settings import settings
|
||||
|
||||
settings.DATABASE_URL = DATABASE_URL
|
||||
|
||||
# Create engine and tables
|
||||
engine = create_async_engine(DATABASE_URL, echo=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
# Drop all tables first to ensure clean state
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
# Create all tables
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield
|
||||
|
||||
# Cleanup
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def session():
|
||||
async def session(setup_database):
|
||||
"""Provide a transactional database session for tests"""
|
||||
from reflector.db import get_session_factory
|
||||
|
||||
async with get_session_factory()() as session:
|
||||
yield session
|
||||
await session.rollback()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_processors():
|
||||
with (
|
||||
patch(
|
||||
"reflector.processors.transcript_topic_detector.TranscriptTopicDetectorProcessor.get_topic"
|
||||
) as mock_topic,
|
||||
patch(
|
||||
"reflector.processors.transcript_final_title.TranscriptFinalTitleProcessor.get_title"
|
||||
) as mock_title,
|
||||
patch(
|
||||
"reflector.processors.transcript_final_summary.TranscriptFinalSummaryProcessor.get_long_summary"
|
||||
) as mock_long_summary,
|
||||
patch(
|
||||
"reflector.processors.transcript_final_summary.TranscriptFinalSummaryProcessor.get_short_summary"
|
||||
) as mock_short_summary,
|
||||
):
|
||||
from reflector.processors.transcript_topic_detector import TopicResponse
|
||||
|
||||
mock_topic.return_value = TopicResponse(
|
||||
title="LLM TITLE", summary="LLM SUMMARY"
|
||||
)
|
||||
mock_title.return_value = "LLM Title"
|
||||
mock_long_summary.return_value = "LLM LONG SUMMARY"
|
||||
mock_short_summary.return_value = "LLM SHORT SUMMARY"
|
||||
yield (
|
||||
mock_topic,
|
||||
mock_title,
|
||||
mock_long_summary,
|
||||
mock_short_summary,
|
||||
) # noqa
|
||||
def fake_mp3_upload(tmp_path):
|
||||
"""Create a temporary MP3 file for upload testing"""
|
||||
mp3_file = tmp_path / "test.mp3"
|
||||
# Create a minimal valid MP3 file (ID3v2 header + minimal frame)
|
||||
mp3_data = b"ID3\x04\x00\x00\x00\x00\x00\x00" + b"\xff\xfb" + b"\x00" * 100
|
||||
mp3_file.write_bytes(mp3_data)
|
||||
return mp3_file
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def whisper_transcript():
|
||||
from reflector.processors.audio_transcript_whisper import (
|
||||
AudioTranscriptWhisperProcessor,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"reflector.processors.audio_transcript_auto"
|
||||
".AudioTranscriptAutoProcessor.__new__"
|
||||
) as mock_audio:
|
||||
mock_audio.return_value = AudioTranscriptWhisperProcessor()
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def dummy_transcript():
|
||||
from reflector.processors.audio_transcript import AudioTranscriptProcessor
|
||||
from reflector.processors.types import AudioFile, Transcript, Word
|
||||
|
||||
class TestAudioTranscriptProcessor(AudioTranscriptProcessor):
|
||||
_time_idx = 0
|
||||
|
||||
async def _transcript(self, data: AudioFile):
|
||||
i = self._time_idx
|
||||
self._time_idx += 2
|
||||
return Transcript(
|
||||
text="Hello world.",
|
||||
words=[
|
||||
Word(start=i, end=i + 1, text="Hello", speaker=0),
|
||||
Word(start=i + 1, end=i + 2, text=" world.", speaker=0),
|
||||
],
|
||||
)
|
||||
|
||||
with patch(
|
||||
"reflector.processors.audio_transcript_auto"
|
||||
".AudioTranscriptAutoProcessor.__new__"
|
||||
) as mock_audio:
|
||||
mock_audio.return_value = TestAudioTranscriptProcessor()
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def dummy_diarization():
|
||||
from reflector.processors.audio_diarization import AudioDiarizationProcessor
|
||||
|
||||
class TestAudioDiarizationProcessor(AudioDiarizationProcessor):
|
||||
_time_idx = 0
|
||||
|
||||
async def _diarize(self, data):
|
||||
i = self._time_idx
|
||||
self._time_idx += 2
|
||||
return [
|
||||
{"start": i, "end": i + 1, "speaker": 0},
|
||||
{"start": i + 1, "end": i + 2, "speaker": 1},
|
||||
]
|
||||
|
||||
with patch(
|
||||
"reflector.processors.audio_diarization_auto"
|
||||
".AudioDiarizationAutoProcessor.__new__"
|
||||
) as mock_audio:
|
||||
mock_audio.return_value = TestAudioDiarizationProcessor()
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def dummy_file_transcript():
|
||||
from reflector.processors.file_transcript import FileTranscriptProcessor
|
||||
def dummy_transcript():
|
||||
"""Mock transcript processor response"""
|
||||
from reflector.processors.types import Transcript, Word
|
||||
|
||||
class TestFileTranscriptProcessor(FileTranscriptProcessor):
|
||||
async def _transcript(self, data):
|
||||
return Transcript(
|
||||
text="Hello world. How are you today?",
|
||||
words=[
|
||||
Word(start=0.0, end=0.5, text="Hello", speaker=0),
|
||||
Word(start=0.5, end=0.6, text=" ", speaker=0),
|
||||
Word(start=0.6, end=1.0, text="world", speaker=0),
|
||||
Word(start=1.0, end=1.1, text=".", speaker=0),
|
||||
Word(start=1.1, end=1.2, text=" ", speaker=0),
|
||||
Word(start=1.2, end=1.5, text="How", speaker=0),
|
||||
Word(start=1.5, end=1.6, text=" ", speaker=0),
|
||||
Word(start=1.6, end=1.8, text="are", speaker=0),
|
||||
Word(start=1.8, end=1.9, text=" ", speaker=0),
|
||||
Word(start=1.9, end=2.1, text="you", speaker=0),
|
||||
Word(start=2.1, end=2.2, text=" ", speaker=0),
|
||||
Word(start=2.2, end=2.5, text="today", speaker=0),
|
||||
Word(start=2.5, end=2.6, text="?", speaker=0),
|
||||
],
|
||||
)
|
||||
|
||||
with patch(
|
||||
"reflector.processors.file_transcript_auto.FileTranscriptAutoProcessor.__new__"
|
||||
) as mock_auto:
|
||||
mock_auto.return_value = TestFileTranscriptProcessor()
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def dummy_file_diarization():
|
||||
from reflector.processors.file_diarization import (
|
||||
FileDiarizationOutput,
|
||||
FileDiarizationProcessor,
|
||||
return Transcript(
|
||||
text="Hello world this is a test",
|
||||
words=[
|
||||
Word(word="Hello", start=0.0, end=0.5, speaker=0),
|
||||
Word(word="world", start=0.5, end=1.0, speaker=0),
|
||||
Word(word="this", start=1.0, end=1.5, speaker=0),
|
||||
Word(word="is", start=1.5, end=1.8, speaker=0),
|
||||
Word(word="a", start=1.8, end=2.0, speaker=0),
|
||||
Word(word="test", start=2.0, end=2.5, speaker=0),
|
||||
],
|
||||
)
|
||||
from reflector.processors.types import DiarizationSegment
|
||||
|
||||
class TestFileDiarizationProcessor(FileDiarizationProcessor):
|
||||
async def _diarize(self, data):
|
||||
return FileDiarizationOutput(
|
||||
diarization=[
|
||||
DiarizationSegment(start=0.0, end=1.1, speaker=0),
|
||||
DiarizationSegment(start=1.2, end=2.6, speaker=1),
|
||||
]
|
||||
)
|
||||
|
||||
with patch(
|
||||
"reflector.processors.file_diarization_auto.FileDiarizationAutoProcessor.__new__"
|
||||
) as mock_auto:
|
||||
mock_auto.return_value = TestFileDiarizationProcessor()
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def dummy_transcript_translator():
|
||||
from reflector.processors.transcript_translator import TranscriptTranslatorProcessor
|
||||
|
||||
class TestTranscriptTranslatorProcessor(TranscriptTranslatorProcessor):
|
||||
async def _translate(self, text: str) -> str:
|
||||
source_language = self.get_pref("audio:source_language", "en")
|
||||
target_language = self.get_pref("audio:target_language", "en")
|
||||
return f"{source_language}:{target_language}:{text}"
|
||||
|
||||
def mock_new(cls, *args, **kwargs):
|
||||
return TestTranscriptTranslatorProcessor(*args, **kwargs)
|
||||
|
||||
with patch(
|
||||
"reflector.processors.transcript_translator_auto"
|
||||
".TranscriptTranslatorAutoProcessor.__new__",
|
||||
mock_new,
|
||||
):
|
||||
yield
|
||||
def dummy_transcript_translator():
|
||||
"""Mock transcript translation"""
|
||||
return "Hola mundo esto es una prueba"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def dummy_llm():
|
||||
from reflector.llm import LLM
|
||||
def dummy_diarization():
|
||||
"""Mock diarization processor response"""
|
||||
from reflector.processors.types import DiarizationOutput, DiarizationSegment
|
||||
|
||||
class TestLLM(LLM):
|
||||
def __init__(self):
|
||||
self.model_name = "DUMMY MODEL"
|
||||
self.llm_tokenizer = "DUMMY TOKENIZER"
|
||||
|
||||
# LLM doesn't have get_instance anymore, mocking constructor instead
|
||||
with patch("reflector.llm.LLM") as mock_llm:
|
||||
mock_llm.return_value = TestLLM()
|
||||
yield
|
||||
return DiarizationOutput(
|
||||
diarization=[
|
||||
DiarizationSegment(speaker=0, start=0.0, end=1.0),
|
||||
DiarizationSegment(speaker=1, start=1.0, end=2.5),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def dummy_storage():
|
||||
from reflector.storage.base import Storage
|
||||
def dummy_file_transcript():
|
||||
"""Mock file transcript processor response"""
|
||||
from reflector.processors.types import Transcript, Word
|
||||
|
||||
class DummyStorage(Storage):
|
||||
async def _put_file(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def _delete_file(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def _get_file_url(self, *args, **kwargs):
|
||||
return "http://fake_server/audio.mp3"
|
||||
|
||||
async def _get_file(self, *args, **kwargs):
|
||||
from pathlib import Path
|
||||
|
||||
test_mp3 = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
|
||||
return test_mp3.read_bytes()
|
||||
|
||||
dummy = DummyStorage()
|
||||
with (
|
||||
patch("reflector.storage.base.Storage.get_instance") as mock_storage,
|
||||
patch("reflector.storage.get_transcripts_storage") as mock_get_transcripts,
|
||||
patch(
|
||||
"reflector.pipelines.main_file_pipeline.get_transcripts_storage"
|
||||
) as mock_get_transcripts2,
|
||||
):
|
||||
mock_storage.return_value = dummy
|
||||
mock_get_transcripts.return_value = dummy
|
||||
mock_get_transcripts2.return_value = dummy
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def celery_enable_logging():
|
||||
return True
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def celery_config():
|
||||
with NamedTemporaryFile() as f:
|
||||
yield {
|
||||
"broker_url": "memory://",
|
||||
"result_backend": f"db+sqlite:///{f.name}",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def celery_includes():
|
||||
return [
|
||||
"reflector.pipelines.main_live_pipeline",
|
||||
"reflector.pipelines.main_file_pipeline",
|
||||
]
|
||||
return Transcript(
|
||||
text="This is a complete file transcript with multiple speakers",
|
||||
words=[
|
||||
Word(word="This", start=0.0, end=0.5, speaker=0),
|
||||
Word(word="is", start=0.5, end=0.8, speaker=0),
|
||||
Word(word="a", start=0.8, end=1.0, speaker=0),
|
||||
Word(word="complete", start=1.0, end=1.5, speaker=1),
|
||||
Word(word="file", start=1.5, end=1.8, speaker=1),
|
||||
Word(word="transcript", start=1.8, end=2.3, speaker=1),
|
||||
Word(word="with", start=2.3, end=2.5, speaker=0),
|
||||
Word(word="multiple", start=2.5, end=3.0, speaker=0),
|
||||
Word(word="speakers", start=3.0, end=3.5, speaker=0),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client():
|
||||
from httpx import AsyncClient
|
||||
def dummy_file_diarization():
|
||||
"""Mock file diarization processor response"""
|
||||
from reflector.processors.types import DiarizationOutput, DiarizationSegment
|
||||
|
||||
from reflector.app import app
|
||||
|
||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
||||
yield ac
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def fake_mp3_upload():
|
||||
with patch(
|
||||
"reflector.db.transcripts.TranscriptController.move_mp3_to_storage"
|
||||
) as mock_move:
|
||||
mock_move.return_value = True
|
||||
yield
|
||||
return DiarizationOutput(
|
||||
diarization=[
|
||||
DiarizationSegment(speaker=0, start=0.0, end=1.0),
|
||||
DiarizationSegment(speaker=1, start=1.0, end=2.3),
|
||||
DiarizationSegment(speaker=0, start=2.3, end=3.5),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def fake_transcript_with_topics(tmpdir, client):
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
def fake_transcript_with_topics():
|
||||
"""Create a transcript with topics for testing"""
|
||||
from reflector.db.transcripts import TranscriptTopic
|
||||
from reflector.processors.types import Word
|
||||
from reflector.settings import settings
|
||||
from reflector.views.transcripts import transcripts_controller
|
||||
|
||||
settings.DATA_DIR = Path(tmpdir)
|
||||
|
||||
# create a transcript
|
||||
response = await client.post("/transcripts", json={"name": "Test audio download"})
|
||||
assert response.status_code == 200
|
||||
tid = response.json()["id"]
|
||||
|
||||
transcript = await transcripts_controller.get_by_id(tid)
|
||||
assert transcript is not None
|
||||
|
||||
await transcripts_controller.update(transcript, {"status": "ended"})
|
||||
|
||||
# manually copy a file at the expected location
|
||||
audio_filename = transcript.audio_mp3_filename
|
||||
path = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
|
||||
audio_filename.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(path, audio_filename)
|
||||
|
||||
# create some topics
|
||||
await transcripts_controller.upsert_topic(
|
||||
transcript,
|
||||
topics = [
|
||||
TranscriptTopic(
|
||||
title="Topic 1",
|
||||
summary="Topic 1 summary",
|
||||
timestamp=0,
|
||||
transcript="Hello world",
|
||||
id="topic1",
|
||||
title="Introduction",
|
||||
summary="Opening remarks and introductions",
|
||||
timestamp=0.0,
|
||||
duration=30.0,
|
||||
words=[
|
||||
Word(text="Hello", start=0, end=1, speaker=0),
|
||||
Word(text="world", start=1, end=2, speaker=0),
|
||||
Word(word="Hello", start=0.0, end=0.5, speaker=0),
|
||||
Word(word="everyone", start=0.5, end=1.0, speaker=0),
|
||||
],
|
||||
),
|
||||
)
|
||||
await transcripts_controller.upsert_topic(
|
||||
transcript,
|
||||
TranscriptTopic(
|
||||
title="Topic 2",
|
||||
summary="Topic 2 summary",
|
||||
timestamp=2,
|
||||
transcript="Hello world",
|
||||
id="topic2",
|
||||
title="Main Discussion",
|
||||
summary="Core topics and key points",
|
||||
timestamp=30.0,
|
||||
duration=60.0,
|
||||
words=[
|
||||
Word(text="Hello", start=2, end=3, speaker=0),
|
||||
Word(text="world", start=3, end=4, speaker=0),
|
||||
Word(word="Let's", start=30.0, end=30.3, speaker=1),
|
||||
Word(word="discuss", start=30.3, end=30.8, speaker=1),
|
||||
Word(word="the", start=30.8, end=31.0, speaker=1),
|
||||
Word(word="agenda", start=31.0, end=31.5, speaker=1),
|
||||
],
|
||||
),
|
||||
)
|
||||
]
|
||||
return topics
|
||||
|
||||
yield transcript
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_processors(
|
||||
dummy_transcript,
|
||||
dummy_transcript_translator,
|
||||
dummy_diarization,
|
||||
dummy_file_transcript,
|
||||
dummy_file_diarization,
|
||||
):
|
||||
"""Mock all processor responses"""
|
||||
return {
|
||||
"transcript": dummy_transcript,
|
||||
"translator": dummy_transcript_translator,
|
||||
"diarization": dummy_diarization,
|
||||
"file_transcript": dummy_file_transcript,
|
||||
"file_diarization": dummy_file_diarization,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_storage():
|
||||
"""Mock storage backend"""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
storage = AsyncMock()
|
||||
storage.get_file_url.return_value = "https://example.com/test-audio.mp3"
|
||||
storage.put_file.return_value = None
|
||||
storage.delete_file.return_value = None
|
||||
return storage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_llm():
|
||||
"""Mock LLM responses"""
|
||||
return {
|
||||
"title": "Test Meeting Title",
|
||||
"summary": "This is a test meeting summary with key discussion points.",
|
||||
"short_summary": "Brief test summary.",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
def whisper_transcript():
    """Canned Whisper API response: one segment with per-word timings."""
    # (word, start, end) triples for the single mocked segment.
    timed_words = [
        ("Hello", 0.0, 0.5),
        ("world", 0.5, 1.0),
        ("this", 1.0, 1.5),
        ("is", 1.5, 1.8),
        ("a", 1.8, 2.0),
        ("test", 2.0, 2.5),
    ]
    segment = {
        "start": 0.0,
        "end": 2.5,
        "text": "Hello world this is a test",
        "words": [
            {"word": word, "start": start, "end": end}
            for word, start, end in timed_words
        ],
    }
    return {
        "text": "Hello world this is a test",
        "segments": [segment],
        "language": "en",
    }
|
||||
|
||||
@@ -3,6 +3,7 @@ from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from reflector.db import get_session_factory
|
||||
from reflector.db.rooms import rooms_controller
|
||||
from reflector.services.ics_sync import ICSSyncService
|
||||
|
||||
@@ -17,21 +18,22 @@ async def test_attendee_parsing_bug():
|
||||
instead of properly parsed email addresses.
|
||||
"""
|
||||
# Create a test room
|
||||
room = await rooms_controller.add(
|
||||
session,
|
||||
name="test-room",
|
||||
user_id="test-user",
|
||||
zulip_auto_post=False,
|
||||
zulip_stream="",
|
||||
zulip_topic="",
|
||||
is_locked=False,
|
||||
room_mode="normal",
|
||||
recording_type="cloud",
|
||||
recording_trigger="automatic-2nd-participant",
|
||||
is_shared=False,
|
||||
ics_url="http://test.com/test.ics",
|
||||
ics_enabled=True,
|
||||
)
|
||||
async with get_session_factory()() as session:
|
||||
room = await rooms_controller.add(
|
||||
session,
|
||||
name="test-room",
|
||||
user_id="test-user",
|
||||
zulip_auto_post=False,
|
||||
zulip_stream="",
|
||||
zulip_topic="",
|
||||
is_locked=False,
|
||||
room_mode="normal",
|
||||
recording_type="cloud",
|
||||
recording_trigger="automatic-2nd-participant",
|
||||
is_shared=False,
|
||||
ics_url="http://test.com/test.ics",
|
||||
ics_enabled=True,
|
||||
)
|
||||
|
||||
# Read the test ICS file that reproduces the bug and update it with current time
|
||||
from datetime import datetime, timedelta, timezone
|
||||
@@ -95,99 +97,14 @@ async def test_attendee_parsing_bug():
|
||||
# This is where the bug manifests - check the attendees
|
||||
attendees = event["attendees"]
|
||||
|
||||
# Print attendee info for debugging
|
||||
print(f"Number of attendees found: {len(attendees)}")
|
||||
# Debug output to see what's happening
|
||||
print(f"Number of attendees: {len(attendees)}")
|
||||
for i, attendee in enumerate(attendees):
|
||||
print(
|
||||
f"Attendee {i}: email='{attendee.get('email')}', name='{attendee.get('name')}'"
|
||||
)
|
||||
print(f"Attendee {i}: {attendee}")
|
||||
|
||||
# With the fix, we should now get properly parsed email addresses
|
||||
# Check that no single characters are parsed as emails
|
||||
single_char_emails = [
|
||||
att for att in attendees if att.get("email") and len(att["email"]) == 1
|
||||
]
|
||||
# The bug would cause 29 attendees (length of "MAILIN01234567890@allo.coop")
|
||||
# instead of 1 attendee
|
||||
assert len(attendees) == 1, f"Expected 1 attendee, got {len(attendees)}"
|
||||
|
||||
if single_char_emails:
|
||||
print(
|
||||
f"BUG DETECTED: Found {len(single_char_emails)} single-character emails:"
|
||||
)
|
||||
for att in single_char_emails:
|
||||
print(f" - '{att['email']}'")
|
||||
|
||||
# Should have attendees but not single-character emails
|
||||
assert len(attendees) > 0
|
||||
assert (
|
||||
len(single_char_emails) == 0
|
||||
), f"Found {len(single_char_emails)} single-character emails, parsing is still buggy"
|
||||
|
||||
# Check that all emails are valid (contain @ symbol)
|
||||
valid_emails = [
|
||||
att for att in attendees if att.get("email") and "@" in att["email"]
|
||||
]
|
||||
assert len(valid_emails) == len(
|
||||
attendees
|
||||
), "Some attendees don't have valid email addresses"
|
||||
|
||||
# We expect around 29 attendees (28 from the comma-separated list + 1 organizer)
|
||||
assert (
|
||||
len(attendees) >= 25
|
||||
), f"Expected around 29 attendees, got {len(attendees)}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_correct_attendee_parsing():
    """
    Verify ICSFetchService._parse_event on a well-formed event.

    Builds an icalendar Event with three separate ATTENDEE lines plus an
    ORGANIZER and checks that each mailto: address is parsed into a
    distinct, well-formed attendee email.
    """
    from datetime import datetime, timezone

    from icalendar import Event

    from reflector.services.ics_sync import ICSFetchService

    service = ICSFetchService()

    # Create a properly formatted event with multiple attendees
    event = Event()
    event.add("uid", "test-correct-attendees")
    event.add("summary", "Test Meeting")
    event.add("location", "http://test.com/test")
    event.add("dtstart", datetime.now(timezone.utc))
    event.add("dtend", datetime.now(timezone.utc))

    # Add attendees the correct way (separate ATTENDEE lines)
    event.add("attendee", "mailto:alice@example.com", parameters={"CN": "Alice"})
    event.add("attendee", "mailto:bob@example.com", parameters={"CN": "Bob"})
    event.add("attendee", "mailto:charlie@example.com", parameters={"CN": "Charlie"})
    event.add(
        "organizer", "mailto:organizer@example.com", parameters={"CN": "Organizer"}
    )

    # Parse the event
    result = service._parse_event(event)

    assert result is not None
    attendees = result["attendees"]

    # Should have 4 attendees (3 attendees + 1 organizer)
    assert len(attendees) == 4

    # Every parsed attendee should carry a plausible email address
    emails = [att["email"] for att in attendees if att.get("email")]
    for email in emails:
        assert "@" in email, f"Invalid email format: {email}"
        assert len(email) > 5, f"Email too short: {email}"

    # Fix: expected_emails was previously defined but never used; drive the
    # membership assertions from it instead of repeating each literal.
    expected_emails = [
        "alice@example.com",
        "bob@example.com",
        "charlie@example.com",
        "organizer@example.com",
    ]
    for expected in expected_emails:
        assert expected in emails, f"Missing expected attendee email: {expected}"
|
||||
# Verify the single attendee has correct email
|
||||
assert attendees[0]["email"] == "MAILIN01234567890@allo.coop"
|
||||
|
||||
@@ -8,6 +8,7 @@ from reflector.db.transcripts import (
|
||||
SourceKind,
|
||||
TranscriptController,
|
||||
TranscriptTopic,
|
||||
transcripts_controller,
|
||||
)
|
||||
from reflector.processors.types import Word
|
||||
|
||||
|
||||
Reference in New Issue
Block a user