mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-21 04:39:06 +00:00
fix: Complete SQLAlchemy 2.0 migration - add session parameters to all controller calls
- Add session parameter to all view functions and controller calls - Fix pipeline files to use get_session_factory() for background tasks - Update PipelineMainBase and PipelineMainFile to handle sessions properly - Add missing on_* methods to PipelineMainFile class - Fix test fixtures to handle docker services availability - Add docker_ip fixture for test database connections - Import fixes for transcripts_controller in tests All controller calls now properly use sessions as first parameter per SQLAlchemy 2.0 async patterns.
This commit is contained in:
@@ -8,18 +8,22 @@ Uses parallel processing for transcription, diarization, and waveform generation
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import uuid
|
import uuid
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import av
|
import av
|
||||||
import structlog
|
import structlog
|
||||||
from celery import chain, shared_task
|
from celery import chain, shared_task
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from reflector.asynctask import asynctask
|
from reflector.asynctask import asynctask
|
||||||
|
from reflector.db import get_session_factory
|
||||||
from reflector.db.rooms import rooms_controller
|
from reflector.db.rooms import rooms_controller
|
||||||
from reflector.db.transcripts import (
|
from reflector.db.transcripts import (
|
||||||
SourceKind,
|
SourceKind,
|
||||||
Transcript,
|
Transcript,
|
||||||
TranscriptStatus,
|
TranscriptStatus,
|
||||||
|
TranscriptTopic,
|
||||||
transcripts_controller,
|
transcripts_controller,
|
||||||
)
|
)
|
||||||
from reflector.logger import logger
|
from reflector.logger import logger
|
||||||
@@ -83,6 +87,32 @@ class PipelineMainFile(PipelineMainBase):
|
|||||||
self.logger = logger.bind(transcript_id=self.transcript_id)
|
self.logger = logger.bind(transcript_id=self.transcript_id)
|
||||||
self.empty_pipeline = EmptyPipeline(logger=self.logger)
|
self.empty_pipeline = EmptyPipeline(logger=self.logger)
|
||||||
|
|
||||||
|
async def get_transcript(self, session: AsyncSession = None) -> Transcript:
|
||||||
|
"""Get transcript with session"""
|
||||||
|
if session:
|
||||||
|
result = await transcripts_controller.get_by_id(session, self.transcript_id)
|
||||||
|
else:
|
||||||
|
async with get_session_factory()() as session:
|
||||||
|
result = await transcripts_controller.get_by_id(
|
||||||
|
session, self.transcript_id
|
||||||
|
)
|
||||||
|
if not result:
|
||||||
|
raise Exception("Transcript not found")
|
||||||
|
return result
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lock_transaction(self):
|
||||||
|
# This lock is to prevent multiple processor starting adding
|
||||||
|
# into event array at the same time
|
||||||
|
async with asyncio.Lock():
|
||||||
|
yield
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def transaction(self):
|
||||||
|
async with self.lock_transaction():
|
||||||
|
async with get_session_factory()() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
def _handle_gather_exceptions(self, results: list, operation: str) -> None:
|
def _handle_gather_exceptions(self, results: list, operation: str) -> None:
|
||||||
"""Handle exceptions from asyncio.gather with return_exceptions=True"""
|
"""Handle exceptions from asyncio.gather with return_exceptions=True"""
|
||||||
for i, result in enumerate(results):
|
for i, result in enumerate(results):
|
||||||
@@ -97,17 +127,23 @@ class PipelineMainFile(PipelineMainBase):
|
|||||||
@broadcast_to_sockets
|
@broadcast_to_sockets
|
||||||
async def set_status(self, transcript_id: str, status: TranscriptStatus):
|
async def set_status(self, transcript_id: str, status: TranscriptStatus):
|
||||||
async with self.lock_transaction():
|
async with self.lock_transaction():
|
||||||
return await transcripts_controller.set_status(transcript_id, status)
|
async with get_session_factory()() as session:
|
||||||
|
return await transcripts_controller.set_status(
|
||||||
|
session, transcript_id, status
|
||||||
|
)
|
||||||
|
|
||||||
async def process(self, file_path: Path):
|
async def process(self, file_path: Path):
|
||||||
"""Main entry point for file processing"""
|
"""Main entry point for file processing"""
|
||||||
self.logger.info(f"Starting file pipeline for {file_path}")
|
self.logger.info(f"Starting file pipeline for {file_path}")
|
||||||
|
|
||||||
transcript = await self.get_transcript()
|
async with get_session_factory()() as session:
|
||||||
|
transcript = await transcripts_controller.get_by_id(
|
||||||
|
session, self.transcript_id
|
||||||
|
)
|
||||||
|
|
||||||
# Clear transcript as we're going to regenerate everything
|
# Clear transcript as we're going to regenerate everything
|
||||||
async with self.transaction():
|
|
||||||
await transcripts_controller.update(
|
await transcripts_controller.update(
|
||||||
|
session,
|
||||||
transcript,
|
transcript,
|
||||||
{
|
{
|
||||||
"events": [],
|
"events": [],
|
||||||
@@ -131,7 +167,8 @@ class PipelineMainFile(PipelineMainBase):
|
|||||||
|
|
||||||
self.logger.info("File pipeline complete")
|
self.logger.info("File pipeline complete")
|
||||||
|
|
||||||
await transcripts_controller.set_status(transcript.id, "ended")
|
async with get_session_factory()() as session:
|
||||||
|
await transcripts_controller.set_status(session, transcript.id, "ended")
|
||||||
|
|
||||||
async def extract_and_write_audio(
|
async def extract_and_write_audio(
|
||||||
self, file_path: Path, transcript: Transcript
|
self, file_path: Path, transcript: Transcript
|
||||||
@@ -308,7 +345,10 @@ class PipelineMainFile(PipelineMainBase):
|
|||||||
|
|
||||||
async def generate_waveform(self, audio_path: Path):
|
async def generate_waveform(self, audio_path: Path):
|
||||||
"""Generate and save waveform"""
|
"""Generate and save waveform"""
|
||||||
transcript = await self.get_transcript()
|
async with get_session_factory()() as session:
|
||||||
|
transcript = await transcripts_controller.get_by_id(
|
||||||
|
session, self.transcript_id
|
||||||
|
)
|
||||||
|
|
||||||
processor = AudioWaveformProcessor(
|
processor = AudioWaveformProcessor(
|
||||||
audio_path=audio_path,
|
audio_path=audio_path,
|
||||||
@@ -367,7 +407,10 @@ class PipelineMainFile(PipelineMainBase):
|
|||||||
self.logger.warning("No topics for summary generation")
|
self.logger.warning("No topics for summary generation")
|
||||||
return
|
return
|
||||||
|
|
||||||
transcript = await self.get_transcript()
|
async with get_session_factory()() as session:
|
||||||
|
transcript = await transcripts_controller.get_by_id(
|
||||||
|
session, self.transcript_id
|
||||||
|
)
|
||||||
processor = TranscriptFinalSummaryProcessor(
|
processor = TranscriptFinalSummaryProcessor(
|
||||||
transcript=transcript,
|
transcript=transcript,
|
||||||
callback=self.on_long_summary,
|
callback=self.on_long_summary,
|
||||||
@@ -380,37 +423,144 @@ class PipelineMainFile(PipelineMainBase):
|
|||||||
|
|
||||||
await processor.flush()
|
await processor.flush()
|
||||||
|
|
||||||
|
async def on_topic(self, topic: TitleSummary):
|
||||||
|
"""Handle topic event - save to database"""
|
||||||
|
async with get_session_factory()() as session:
|
||||||
|
transcript = await transcripts_controller.get_by_id(
|
||||||
|
session, self.transcript_id
|
||||||
|
)
|
||||||
|
topic_obj = TranscriptTopic(
|
||||||
|
title=topic.title,
|
||||||
|
summary=topic.summary,
|
||||||
|
timestamp=topic.timestamp,
|
||||||
|
duration=topic.duration,
|
||||||
|
)
|
||||||
|
await transcripts_controller.upsert_topic(session, transcript, topic_obj)
|
||||||
|
await transcripts_controller.append_event(
|
||||||
|
session,
|
||||||
|
transcript=transcript,
|
||||||
|
event="TOPIC",
|
||||||
|
data=topic_obj,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_title(self, data):
|
||||||
|
"""Handle title event"""
|
||||||
|
async with get_session_factory()() as session:
|
||||||
|
transcript = await transcripts_controller.get_by_id(
|
||||||
|
session, self.transcript_id
|
||||||
|
)
|
||||||
|
if not transcript.title:
|
||||||
|
await transcripts_controller.update(
|
||||||
|
session,
|
||||||
|
transcript,
|
||||||
|
{"title": data.title},
|
||||||
|
)
|
||||||
|
await transcripts_controller.append_event(
|
||||||
|
session,
|
||||||
|
transcript=transcript,
|
||||||
|
event="FINAL_TITLE",
|
||||||
|
data={"title": data.title},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_long_summary(self, data):
|
||||||
|
"""Handle long summary event"""
|
||||||
|
async with get_session_factory()() as session:
|
||||||
|
transcript = await transcripts_controller.get_by_id(
|
||||||
|
session, self.transcript_id
|
||||||
|
)
|
||||||
|
await transcripts_controller.update(
|
||||||
|
session,
|
||||||
|
transcript,
|
||||||
|
{"long_summary": data.long_summary},
|
||||||
|
)
|
||||||
|
await transcripts_controller.append_event(
|
||||||
|
session,
|
||||||
|
transcript=transcript,
|
||||||
|
event="FINAL_LONG_SUMMARY",
|
||||||
|
data={"long_summary": data.long_summary},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_short_summary(self, data):
|
||||||
|
"""Handle short summary event"""
|
||||||
|
async with get_session_factory()() as session:
|
||||||
|
transcript = await transcripts_controller.get_by_id(
|
||||||
|
session, self.transcript_id
|
||||||
|
)
|
||||||
|
await transcripts_controller.update(
|
||||||
|
session,
|
||||||
|
transcript,
|
||||||
|
{"short_summary": data.short_summary},
|
||||||
|
)
|
||||||
|
await transcripts_controller.append_event(
|
||||||
|
session,
|
||||||
|
transcript=transcript,
|
||||||
|
event="FINAL_SHORT_SUMMARY",
|
||||||
|
data={"short_summary": data.short_summary},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_duration(self, duration):
|
||||||
|
"""Handle duration event"""
|
||||||
|
async with get_session_factory()() as session:
|
||||||
|
transcript = await transcripts_controller.get_by_id(
|
||||||
|
session, self.transcript_id
|
||||||
|
)
|
||||||
|
await transcripts_controller.update(
|
||||||
|
session,
|
||||||
|
transcript,
|
||||||
|
{"duration": duration},
|
||||||
|
)
|
||||||
|
await transcripts_controller.append_event(
|
||||||
|
session,
|
||||||
|
transcript=transcript,
|
||||||
|
event="DURATION",
|
||||||
|
data={"duration": duration},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_waveform(self, waveform):
|
||||||
|
"""Handle waveform event"""
|
||||||
|
async with get_session_factory()() as session:
|
||||||
|
transcript = await transcripts_controller.get_by_id(
|
||||||
|
session, self.transcript_id
|
||||||
|
)
|
||||||
|
await transcripts_controller.append_event(
|
||||||
|
session,
|
||||||
|
transcript=transcript,
|
||||||
|
event="WAVEFORM",
|
||||||
|
data={"waveform": waveform},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
@asynctask
|
@asynctask
|
||||||
async def task_send_webhook_if_needed(*, transcript_id: str):
|
async def task_send_webhook_if_needed(*, transcript_id: str):
|
||||||
"""Send webhook if this is a room recording with webhook configured"""
|
"""Send webhook if this is a room recording with webhook configured"""
|
||||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
async with get_session_factory()() as session:
|
||||||
if not transcript:
|
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
||||||
return
|
if not transcript:
|
||||||
|
return
|
||||||
|
|
||||||
if transcript.source_kind == SourceKind.ROOM and transcript.room_id:
|
if transcript.source_kind == SourceKind.ROOM and transcript.room_id:
|
||||||
room = await rooms_controller.get_by_id(transcript.room_id)
|
room = await rooms_controller.get_by_id(session, transcript.room_id)
|
||||||
if room and room.webhook_url:
|
if room and room.webhook_url:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Dispatching webhook",
|
"Dispatching webhook",
|
||||||
transcript_id=transcript_id,
|
transcript_id=transcript_id,
|
||||||
room_id=room.id,
|
room_id=room.id,
|
||||||
webhook_url=room.webhook_url,
|
webhook_url=room.webhook_url,
|
||||||
)
|
)
|
||||||
send_transcript_webhook.delay(
|
send_transcript_webhook.delay(
|
||||||
transcript_id, room.id, event_id=uuid.uuid4().hex
|
transcript_id, room.id, event_id=uuid.uuid4().hex
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
@asynctask
|
@asynctask
|
||||||
async def task_pipeline_file_process(*, transcript_id: str):
|
async def task_pipeline_file_process(*, transcript_id: str):
|
||||||
"""Celery task for file pipeline processing"""
|
"""Celery task for file pipeline processing"""
|
||||||
|
async with get_session_factory()() as session:
|
||||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
||||||
if not transcript:
|
if not transcript:
|
||||||
raise Exception(f"Transcript {transcript_id} not found")
|
raise Exception(f"Transcript {transcript_id} not found")
|
||||||
|
|
||||||
pipeline = PipelineMainFile(transcript_id=transcript_id)
|
pipeline = PipelineMainFile(transcript_id=transcript_id)
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -20,9 +20,11 @@ import av
|
|||||||
import boto3
|
import boto3
|
||||||
from celery import chord, current_task, group, shared_task
|
from celery import chord, current_task, group, shared_task
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from structlog import BoundLogger as Logger
|
from structlog import BoundLogger as Logger
|
||||||
|
|
||||||
from reflector.asynctask import asynctask
|
from reflector.asynctask import asynctask
|
||||||
|
from reflector.db import get_session_factory
|
||||||
from reflector.db.meetings import meeting_consent_controller, meetings_controller
|
from reflector.db.meetings import meeting_consent_controller, meetings_controller
|
||||||
from reflector.db.recordings import recordings_controller
|
from reflector.db.recordings import recordings_controller
|
||||||
from reflector.db.rooms import rooms_controller
|
from reflector.db.rooms import rooms_controller
|
||||||
@@ -96,9 +98,10 @@ def get_transcript(func):
|
|||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
async def wrapper(**kwargs):
|
async def wrapper(**kwargs):
|
||||||
transcript_id = kwargs.pop("transcript_id")
|
transcript_id = kwargs.pop("transcript_id")
|
||||||
transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
|
async with get_session_factory()() as session:
|
||||||
|
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
||||||
if not transcript:
|
if not transcript:
|
||||||
raise Exception("Transcript {transcript_id} not found")
|
raise Exception(f"Transcript {transcript_id} not found")
|
||||||
|
|
||||||
# Enhanced logger with Celery task context
|
# Enhanced logger with Celery task context
|
||||||
tlogger = logger.bind(transcript_id=transcript.id)
|
tlogger = logger.bind(transcript_id=transcript.id)
|
||||||
@@ -139,11 +142,15 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
|||||||
self._ws_manager = get_ws_manager()
|
self._ws_manager = get_ws_manager()
|
||||||
return self._ws_manager
|
return self._ws_manager
|
||||||
|
|
||||||
async def get_transcript(self) -> Transcript:
|
async def get_transcript(self, session: AsyncSession = None) -> Transcript:
|
||||||
# fetch the transcript
|
# fetch the transcript
|
||||||
result = await transcripts_controller.get_by_id(
|
if session:
|
||||||
transcript_id=self.transcript_id
|
result = await transcripts_controller.get_by_id(session, self.transcript_id)
|
||||||
)
|
else:
|
||||||
|
async with get_session_factory()() as session:
|
||||||
|
result = await transcripts_controller.get_by_id(
|
||||||
|
session, self.transcript_id
|
||||||
|
)
|
||||||
if not result:
|
if not result:
|
||||||
raise Exception("Transcript not found")
|
raise Exception("Transcript not found")
|
||||||
return result
|
return result
|
||||||
@@ -175,8 +182,8 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
|||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def transaction(self):
|
async def transaction(self):
|
||||||
async with self.lock_transaction():
|
async with self.lock_transaction():
|
||||||
async with transcripts_controller.transaction():
|
async with get_session_factory()() as session:
|
||||||
yield
|
yield session
|
||||||
|
|
||||||
@broadcast_to_sockets
|
@broadcast_to_sockets
|
||||||
async def on_status(self, status):
|
async def on_status(self, status):
|
||||||
@@ -207,13 +214,17 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
|||||||
|
|
||||||
# when the status of the pipeline changes, update the transcript
|
# when the status of the pipeline changes, update the transcript
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
return await transcripts_controller.set_status(self.transcript_id, status)
|
async with get_session_factory()() as session:
|
||||||
|
return await transcripts_controller.set_status(
|
||||||
|
session, self.transcript_id, status
|
||||||
|
)
|
||||||
|
|
||||||
@broadcast_to_sockets
|
@broadcast_to_sockets
|
||||||
async def on_transcript(self, data):
|
async def on_transcript(self, data):
|
||||||
async with self.transaction():
|
async with self.transaction() as session:
|
||||||
transcript = await self.get_transcript()
|
transcript = await self.get_transcript(session)
|
||||||
return await transcripts_controller.append_event(
|
return await transcripts_controller.append_event(
|
||||||
|
session,
|
||||||
transcript=transcript,
|
transcript=transcript,
|
||||||
event="TRANSCRIPT",
|
event="TRANSCRIPT",
|
||||||
data=TranscriptText(text=data.text, translation=data.translation),
|
data=TranscriptText(text=data.text, translation=data.translation),
|
||||||
@@ -230,10 +241,11 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
|||||||
)
|
)
|
||||||
if isinstance(data, TitleSummaryWithIdProcessorType):
|
if isinstance(data, TitleSummaryWithIdProcessorType):
|
||||||
topic.id = data.id
|
topic.id = data.id
|
||||||
async with self.transaction():
|
async with self.transaction() as session:
|
||||||
transcript = await self.get_transcript()
|
transcript = await self.get_transcript(session)
|
||||||
await transcripts_controller.upsert_topic(transcript, topic)
|
await transcripts_controller.upsert_topic(session, transcript, topic)
|
||||||
return await transcripts_controller.append_event(
|
return await transcripts_controller.append_event(
|
||||||
|
session,
|
||||||
transcript=transcript,
|
transcript=transcript,
|
||||||
event="TOPIC",
|
event="TOPIC",
|
||||||
data=topic,
|
data=topic,
|
||||||
@@ -242,16 +254,18 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
|||||||
@broadcast_to_sockets
|
@broadcast_to_sockets
|
||||||
async def on_title(self, data):
|
async def on_title(self, data):
|
||||||
final_title = TranscriptFinalTitle(title=data.title)
|
final_title = TranscriptFinalTitle(title=data.title)
|
||||||
async with self.transaction():
|
async with self.transaction() as session:
|
||||||
transcript = await self.get_transcript()
|
transcript = await self.get_transcript(session)
|
||||||
if not transcript.title:
|
if not transcript.title:
|
||||||
await transcripts_controller.update(
|
await transcripts_controller.update(
|
||||||
|
session,
|
||||||
transcript,
|
transcript,
|
||||||
{
|
{
|
||||||
"title": final_title.title,
|
"title": final_title.title,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return await transcripts_controller.append_event(
|
return await transcripts_controller.append_event(
|
||||||
|
session,
|
||||||
transcript=transcript,
|
transcript=transcript,
|
||||||
event="FINAL_TITLE",
|
event="FINAL_TITLE",
|
||||||
data=final_title,
|
data=final_title,
|
||||||
@@ -260,15 +274,17 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
|||||||
@broadcast_to_sockets
|
@broadcast_to_sockets
|
||||||
async def on_long_summary(self, data):
|
async def on_long_summary(self, data):
|
||||||
final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
|
final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
|
||||||
async with self.transaction():
|
async with self.transaction() as session:
|
||||||
transcript = await self.get_transcript()
|
transcript = await self.get_transcript(session)
|
||||||
await transcripts_controller.update(
|
await transcripts_controller.update(
|
||||||
|
session,
|
||||||
transcript,
|
transcript,
|
||||||
{
|
{
|
||||||
"long_summary": final_long_summary.long_summary,
|
"long_summary": final_long_summary.long_summary,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return await transcripts_controller.append_event(
|
return await transcripts_controller.append_event(
|
||||||
|
session,
|
||||||
transcript=transcript,
|
transcript=transcript,
|
||||||
event="FINAL_LONG_SUMMARY",
|
event="FINAL_LONG_SUMMARY",
|
||||||
data=final_long_summary,
|
data=final_long_summary,
|
||||||
@@ -279,15 +295,17 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
|||||||
final_short_summary = TranscriptFinalShortSummary(
|
final_short_summary = TranscriptFinalShortSummary(
|
||||||
short_summary=data.short_summary
|
short_summary=data.short_summary
|
||||||
)
|
)
|
||||||
async with self.transaction():
|
async with self.transaction() as session:
|
||||||
transcript = await self.get_transcript()
|
transcript = await self.get_transcript(session)
|
||||||
await transcripts_controller.update(
|
await transcripts_controller.update(
|
||||||
|
session,
|
||||||
transcript,
|
transcript,
|
||||||
{
|
{
|
||||||
"short_summary": final_short_summary.short_summary,
|
"short_summary": final_short_summary.short_summary,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return await transcripts_controller.append_event(
|
return await transcripts_controller.append_event(
|
||||||
|
session,
|
||||||
transcript=transcript,
|
transcript=transcript,
|
||||||
event="FINAL_SHORT_SUMMARY",
|
event="FINAL_SHORT_SUMMARY",
|
||||||
data=final_short_summary,
|
data=final_short_summary,
|
||||||
@@ -295,29 +313,30 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
|||||||
|
|
||||||
@broadcast_to_sockets
|
@broadcast_to_sockets
|
||||||
async def on_duration(self, data):
|
async def on_duration(self, data):
|
||||||
async with self.transaction():
|
async with self.transaction() as session:
|
||||||
duration = TranscriptDuration(duration=data)
|
duration = TranscriptDuration(duration=data)
|
||||||
|
|
||||||
transcript = await self.get_transcript()
|
transcript = await self.get_transcript(session)
|
||||||
await transcripts_controller.update(
|
await transcripts_controller.update(
|
||||||
|
session,
|
||||||
transcript,
|
transcript,
|
||||||
{
|
{
|
||||||
"duration": duration.duration,
|
"duration": duration.duration,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return await transcripts_controller.append_event(
|
return await transcripts_controller.append_event(
|
||||||
transcript=transcript, event="DURATION", data=duration
|
session, transcript=transcript, event="DURATION", data=duration
|
||||||
)
|
)
|
||||||
|
|
||||||
@broadcast_to_sockets
|
@broadcast_to_sockets
|
||||||
async def on_waveform(self, data):
|
async def on_waveform(self, data):
|
||||||
async with self.transaction():
|
async with self.transaction() as session:
|
||||||
waveform = TranscriptWaveform(waveform=data)
|
waveform = TranscriptWaveform(waveform=data)
|
||||||
|
|
||||||
transcript = await self.get_transcript()
|
transcript = await self.get_transcript(session)
|
||||||
|
|
||||||
return await transcripts_controller.append_event(
|
return await transcripts_controller.append_event(
|
||||||
transcript=transcript, event="WAVEFORM", data=waveform
|
session, transcript=transcript, event="WAVEFORM", data=waveform
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -535,7 +554,8 @@ async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Upload to external storage and delete the file
|
# Upload to external storage and delete the file
|
||||||
await transcripts_controller.move_mp3_to_storage(transcript)
|
async with get_session_factory()() as session:
|
||||||
|
await transcripts_controller.move_mp3_to_storage(session, transcript)
|
||||||
|
|
||||||
logger.info("Upload mp3 done")
|
logger.info("Upload mp3 done")
|
||||||
|
|
||||||
@@ -572,13 +592,20 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
|
|||||||
recording = None
|
recording = None
|
||||||
try:
|
try:
|
||||||
if transcript.recording_id:
|
if transcript.recording_id:
|
||||||
recording = await recordings_controller.get_by_id(transcript.recording_id)
|
async with get_session_factory()() as session:
|
||||||
if recording and recording.meeting_id:
|
recording = await recordings_controller.get_by_id(
|
||||||
meeting = await meetings_controller.get_by_id(recording.meeting_id)
|
session, transcript.recording_id
|
||||||
if meeting:
|
)
|
||||||
consent_denied = await meeting_consent_controller.has_any_denial(
|
if recording and recording.meeting_id:
|
||||||
meeting.id
|
meeting = await meetings_controller.get_by_id(
|
||||||
|
session, recording.meeting_id
|
||||||
)
|
)
|
||||||
|
if meeting:
|
||||||
|
consent_denied = (
|
||||||
|
await meeting_consent_controller.has_any_denial(
|
||||||
|
session, meeting.id
|
||||||
|
)
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get fetch consent: {e}", exc_info=e)
|
logger.error(f"Failed to get fetch consent: {e}", exc_info=e)
|
||||||
consent_denied = True
|
consent_denied = True
|
||||||
@@ -606,7 +633,10 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
|
|||||||
logger.error(f"Failed to delete Whereby recording: {e}", exc_info=e)
|
logger.error(f"Failed to delete Whereby recording: {e}", exc_info=e)
|
||||||
|
|
||||||
# non-transactional, files marked for deletion not actually deleted is possible
|
# non-transactional, files marked for deletion not actually deleted is possible
|
||||||
await transcripts_controller.update(transcript, {"audio_deleted": True})
|
async with get_session_factory()() as session:
|
||||||
|
await transcripts_controller.update(
|
||||||
|
session, transcript, {"audio_deleted": True}
|
||||||
|
)
|
||||||
# 2. Delete processed audio from transcript storage S3 bucket
|
# 2. Delete processed audio from transcript storage S3 bucket
|
||||||
if transcript.audio_location == "storage":
|
if transcript.audio_location == "storage":
|
||||||
storage = get_transcripts_storage()
|
storage = get_transcripts_storage()
|
||||||
@@ -638,21 +668,24 @@ async def pipeline_post_to_zulip(transcript: Transcript, logger: Logger):
|
|||||||
logger.info("Transcript has no recording")
|
logger.info("Transcript has no recording")
|
||||||
return
|
return
|
||||||
|
|
||||||
recording = await recordings_controller.get_by_id(transcript.recording_id)
|
async with get_session_factory()() as session:
|
||||||
if not recording:
|
recording = await recordings_controller.get_by_id(
|
||||||
logger.info("Recording not found")
|
session, transcript.recording_id
|
||||||
return
|
)
|
||||||
|
if not recording:
|
||||||
|
logger.info("Recording not found")
|
||||||
|
return
|
||||||
|
|
||||||
if not recording.meeting_id:
|
if not recording.meeting_id:
|
||||||
logger.info("Recording has no meeting")
|
logger.info("Recording has no meeting")
|
||||||
return
|
return
|
||||||
|
|
||||||
meeting = await meetings_controller.get_by_id(recording.meeting_id)
|
meeting = await meetings_controller.get_by_id(session, recording.meeting_id)
|
||||||
if not meeting:
|
if not meeting:
|
||||||
logger.info("No meeting found for this recording")
|
logger.info("No meeting found for this recording")
|
||||||
return
|
return
|
||||||
|
|
||||||
room = await rooms_controller.get_by_id(meeting.room_id)
|
room = await rooms_controller.get_by_id(session, meeting.room_id)
|
||||||
if not room:
|
if not room:
|
||||||
logger.error(f"Missing room for a meeting {meeting.id}")
|
logger.error(f"Missing room for a meeting {meeting.id}")
|
||||||
return
|
return
|
||||||
@@ -677,9 +710,10 @@ async def pipeline_post_to_zulip(transcript: Transcript, logger: Logger):
|
|||||||
response = await send_message_to_zulip(
|
response = await send_message_to_zulip(
|
||||||
room.zulip_stream, room.zulip_topic, message
|
room.zulip_stream, room.zulip_topic, message
|
||||||
)
|
)
|
||||||
await transcripts_controller.update(
|
async with get_session_factory()() as session:
|
||||||
transcript, {"zulip_message_id": response["id"]}
|
await transcripts_controller.update(
|
||||||
)
|
session, transcript, {"zulip_message_id": response["id"]}
|
||||||
|
)
|
||||||
|
|
||||||
logger.info("Posted to zulip")
|
logger.info("Posted to zulip")
|
||||||
|
|
||||||
|
|||||||
@@ -8,9 +8,10 @@ from fastapi_pagination import Page
|
|||||||
from fastapi_pagination.ext.sqlalchemy import paginate
|
from fastapi_pagination.ext.sqlalchemy import paginate
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from redis.exceptions import LockError
|
from redis.exceptions import LockError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
import reflector.auth as auth
|
import reflector.auth as auth
|
||||||
from reflector.db import get_session_factory
|
from reflector.db import get_session, get_session_factory
|
||||||
from reflector.db.calendar_events import calendar_events_controller
|
from reflector.db.calendar_events import calendar_events_controller
|
||||||
from reflector.db.meetings import meetings_controller
|
from reflector.db.meetings import meetings_controller
|
||||||
from reflector.db.rooms import rooms_controller
|
from reflector.db.rooms import rooms_controller
|
||||||
@@ -185,7 +186,7 @@ async def rooms_list(
|
|||||||
session_factory = get_session_factory()
|
session_factory = get_session_factory()
|
||||||
async with session_factory() as session:
|
async with session_factory() as session:
|
||||||
query = await rooms_controller.get_all(
|
query = await rooms_controller.get_all(
|
||||||
user_id=user_id, order_by="-created_at", return_query=True
|
session, user_id=user_id, order_by="-created_at", return_query=True
|
||||||
)
|
)
|
||||||
return await paginate(session, query)
|
return await paginate(session, query)
|
||||||
|
|
||||||
@@ -194,9 +195,10 @@ async def rooms_list(
|
|||||||
async def rooms_get(
|
async def rooms_get(
|
||||||
room_id: str,
|
room_id: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
room = await rooms_controller.get_by_id_for_http(room_id, user_id=user_id)
|
room = await rooms_controller.get_by_id_for_http(session, room_id, user_id=user_id)
|
||||||
if not room:
|
if not room:
|
||||||
raise HTTPException(status_code=404, detail="Room not found")
|
raise HTTPException(status_code=404, detail="Room not found")
|
||||||
return room
|
return room
|
||||||
@@ -206,9 +208,10 @@ async def rooms_get(
|
|||||||
async def rooms_get_by_name(
|
async def rooms_get_by_name(
|
||||||
room_name: str,
|
room_name: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
room = await rooms_controller.get_by_name(room_name)
|
room = await rooms_controller.get_by_name(session, room_name)
|
||||||
if not room:
|
if not room:
|
||||||
raise HTTPException(status_code=404, detail="Room not found")
|
raise HTTPException(status_code=404, detail="Room not found")
|
||||||
|
|
||||||
@@ -230,10 +233,12 @@ async def rooms_get_by_name(
|
|||||||
async def rooms_create(
|
async def rooms_create(
|
||||||
room: CreateRoom,
|
room: CreateRoom,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
|
|
||||||
return await rooms_controller.add(
|
return await rooms_controller.add(
|
||||||
|
session,
|
||||||
name=room.name,
|
name=room.name,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
zulip_auto_post=room.zulip_auto_post,
|
zulip_auto_post=room.zulip_auto_post,
|
||||||
@@ -257,13 +262,14 @@ async def rooms_update(
|
|||||||
room_id: str,
|
room_id: str,
|
||||||
info: UpdateRoom,
|
info: UpdateRoom,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
room = await rooms_controller.get_by_id_for_http(room_id, user_id=user_id)
|
room = await rooms_controller.get_by_id_for_http(session, room_id, user_id=user_id)
|
||||||
if not room:
|
if not room:
|
||||||
raise HTTPException(status_code=404, detail="Room not found")
|
raise HTTPException(status_code=404, detail="Room not found")
|
||||||
values = info.dict(exclude_unset=True)
|
values = info.dict(exclude_unset=True)
|
||||||
await rooms_controller.update(room, values)
|
await rooms_controller.update(session, room, values)
|
||||||
return room
|
return room
|
||||||
|
|
||||||
|
|
||||||
@@ -271,12 +277,13 @@ async def rooms_update(
|
|||||||
async def rooms_delete(
|
async def rooms_delete(
|
||||||
room_id: str,
|
room_id: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
room = await rooms_controller.get_by_id(room_id, user_id=user_id)
|
room = await rooms_controller.get_by_id(session, room_id, user_id=user_id)
|
||||||
if not room:
|
if not room:
|
||||||
raise HTTPException(status_code=404, detail="Room not found")
|
raise HTTPException(status_code=404, detail="Room not found")
|
||||||
await rooms_controller.remove_by_id(room.id, user_id=user_id)
|
await rooms_controller.remove_by_id(session, room.id, user_id=user_id)
|
||||||
return DeletionStatus(status="ok")
|
return DeletionStatus(status="ok")
|
||||||
|
|
||||||
|
|
||||||
@@ -285,9 +292,10 @@ async def rooms_create_meeting(
|
|||||||
room_name: str,
|
room_name: str,
|
||||||
info: CreateRoomMeeting,
|
info: CreateRoomMeeting,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
room = await rooms_controller.get_by_name(room_name)
|
room = await rooms_controller.get_by_name(session, room_name)
|
||||||
if not room:
|
if not room:
|
||||||
raise HTTPException(status_code=404, detail="Room not found")
|
raise HTTPException(status_code=404, detail="Room not found")
|
||||||
|
|
||||||
@@ -303,7 +311,7 @@ async def rooms_create_meeting(
|
|||||||
meeting = None
|
meeting = None
|
||||||
if not info.allow_duplicated:
|
if not info.allow_duplicated:
|
||||||
meeting = await meetings_controller.get_active(
|
meeting = await meetings_controller.get_active(
|
||||||
room=room, current_time=current_time
|
session, room=room, current_time=current_time
|
||||||
)
|
)
|
||||||
|
|
||||||
if meeting is None:
|
if meeting is None:
|
||||||
@@ -314,6 +322,7 @@ async def rooms_create_meeting(
|
|||||||
await upload_logo(whereby_meeting["roomName"], "./images/logo.png")
|
await upload_logo(whereby_meeting["roomName"], "./images/logo.png")
|
||||||
|
|
||||||
meeting = await meetings_controller.create(
|
meeting = await meetings_controller.create(
|
||||||
|
session,
|
||||||
id=whereby_meeting["meetingId"],
|
id=whereby_meeting["meetingId"],
|
||||||
room_name=whereby_meeting["roomName"],
|
room_name=whereby_meeting["roomName"],
|
||||||
room_url=whereby_meeting["roomUrl"],
|
room_url=whereby_meeting["roomUrl"],
|
||||||
@@ -340,11 +349,12 @@ async def rooms_create_meeting(
|
|||||||
async def rooms_test_webhook(
|
async def rooms_test_webhook(
|
||||||
room_id: str,
|
room_id: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
):
|
):
|
||||||
"""Test webhook configuration by sending a sample payload."""
|
"""Test webhook configuration by sending a sample payload."""
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
|
|
||||||
room = await rooms_controller.get_by_id(room_id)
|
room = await rooms_controller.get_by_id(session, room_id)
|
||||||
if not room:
|
if not room:
|
||||||
raise HTTPException(status_code=404, detail="Room not found")
|
raise HTTPException(status_code=404, detail="Room not found")
|
||||||
|
|
||||||
@@ -361,9 +371,10 @@ async def rooms_test_webhook(
|
|||||||
async def rooms_sync_ics(
|
async def rooms_sync_ics(
|
||||||
room_name: str,
|
room_name: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
room = await rooms_controller.get_by_name(room_name)
|
room = await rooms_controller.get_by_name(session, room_name)
|
||||||
|
|
||||||
if not room:
|
if not room:
|
||||||
raise HTTPException(status_code=404, detail="Room not found")
|
raise HTTPException(status_code=404, detail="Room not found")
|
||||||
@@ -390,9 +401,10 @@ async def rooms_sync_ics(
|
|||||||
async def rooms_ics_status(
|
async def rooms_ics_status(
|
||||||
room_name: str,
|
room_name: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
room = await rooms_controller.get_by_name(room_name)
|
room = await rooms_controller.get_by_name(session, room_name)
|
||||||
|
|
||||||
if not room:
|
if not room:
|
||||||
raise HTTPException(status_code=404, detail="Room not found")
|
raise HTTPException(status_code=404, detail="Room not found")
|
||||||
@@ -407,7 +419,7 @@ async def rooms_ics_status(
|
|||||||
next_sync = room.ics_last_sync + timedelta(seconds=room.ics_fetch_interval)
|
next_sync = room.ics_last_sync + timedelta(seconds=room.ics_fetch_interval)
|
||||||
|
|
||||||
events = await calendar_events_controller.get_by_room(
|
events = await calendar_events_controller.get_by_room(
|
||||||
room.id, include_deleted=False
|
session, room.id, include_deleted=False
|
||||||
)
|
)
|
||||||
|
|
||||||
return ICSStatus(
|
return ICSStatus(
|
||||||
@@ -423,15 +435,16 @@ async def rooms_ics_status(
|
|||||||
async def rooms_list_meetings(
|
async def rooms_list_meetings(
|
||||||
room_name: str,
|
room_name: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
room = await rooms_controller.get_by_name(room_name)
|
room = await rooms_controller.get_by_name(session, room_name)
|
||||||
|
|
||||||
if not room:
|
if not room:
|
||||||
raise HTTPException(status_code=404, detail="Room not found")
|
raise HTTPException(status_code=404, detail="Room not found")
|
||||||
|
|
||||||
events = await calendar_events_controller.get_by_room(
|
events = await calendar_events_controller.get_by_room(
|
||||||
room.id, include_deleted=False
|
session, room.id, include_deleted=False
|
||||||
)
|
)
|
||||||
|
|
||||||
if user_id != room.user_id:
|
if user_id != room.user_id:
|
||||||
@@ -449,15 +462,16 @@ async def rooms_list_upcoming_meetings(
|
|||||||
room_name: str,
|
room_name: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
minutes_ahead: int = 120,
|
minutes_ahead: int = 120,
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
room = await rooms_controller.get_by_name(room_name)
|
room = await rooms_controller.get_by_name(session, room_name)
|
||||||
|
|
||||||
if not room:
|
if not room:
|
||||||
raise HTTPException(status_code=404, detail="Room not found")
|
raise HTTPException(status_code=404, detail="Room not found")
|
||||||
|
|
||||||
events = await calendar_events_controller.get_upcoming(
|
events = await calendar_events_controller.get_upcoming(
|
||||||
room.id, minutes_ahead=minutes_ahead
|
session, room.id, minutes_ahead=minutes_ahead
|
||||||
)
|
)
|
||||||
|
|
||||||
if user_id != room.user_id:
|
if user_id != room.user_id:
|
||||||
@@ -472,16 +486,17 @@ async def rooms_list_upcoming_meetings(
|
|||||||
async def rooms_list_active_meetings(
|
async def rooms_list_active_meetings(
|
||||||
room_name: str,
|
room_name: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
room = await rooms_controller.get_by_name(room_name)
|
room = await rooms_controller.get_by_name(session, room_name)
|
||||||
|
|
||||||
if not room:
|
if not room:
|
||||||
raise HTTPException(status_code=404, detail="Room not found")
|
raise HTTPException(status_code=404, detail="Room not found")
|
||||||
|
|
||||||
current_time = datetime.now(timezone.utc)
|
current_time = datetime.now(timezone.utc)
|
||||||
meetings = await meetings_controller.get_all_active_for_room(
|
meetings = await meetings_controller.get_all_active_for_room(
|
||||||
room=room, current_time=current_time
|
session, room=room, current_time=current_time
|
||||||
)
|
)
|
||||||
|
|
||||||
# Hide host URLs from non-owners
|
# Hide host URLs from non-owners
|
||||||
@@ -497,15 +512,16 @@ async def rooms_get_meeting(
|
|||||||
room_name: str,
|
room_name: str,
|
||||||
meeting_id: str,
|
meeting_id: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
):
|
):
|
||||||
"""Get a single meeting by ID within a specific room."""
|
"""Get a single meeting by ID within a specific room."""
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
|
|
||||||
room = await rooms_controller.get_by_name(room_name)
|
room = await rooms_controller.get_by_name(session, room_name)
|
||||||
if not room:
|
if not room:
|
||||||
raise HTTPException(status_code=404, detail="Room not found")
|
raise HTTPException(status_code=404, detail="Room not found")
|
||||||
|
|
||||||
meeting = await meetings_controller.get_by_id(meeting_id)
|
meeting = await meetings_controller.get_by_id(session, meeting_id)
|
||||||
if not meeting:
|
if not meeting:
|
||||||
raise HTTPException(status_code=404, detail="Meeting not found")
|
raise HTTPException(status_code=404, detail="Meeting not found")
|
||||||
|
|
||||||
@@ -525,14 +541,15 @@ async def rooms_join_meeting(
|
|||||||
room_name: str,
|
room_name: str,
|
||||||
meeting_id: str,
|
meeting_id: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
room = await rooms_controller.get_by_name(room_name)
|
room = await rooms_controller.get_by_name(session, room_name)
|
||||||
|
|
||||||
if not room:
|
if not room:
|
||||||
raise HTTPException(status_code=404, detail="Room not found")
|
raise HTTPException(status_code=404, detail="Room not found")
|
||||||
|
|
||||||
meeting = await meetings_controller.get_by_id(meeting_id)
|
meeting = await meetings_controller.get_by_id(session, meeting_id)
|
||||||
|
|
||||||
if not meeting:
|
if not meeting:
|
||||||
raise HTTPException(status_code=404, detail="Meeting not found")
|
raise HTTPException(status_code=404, detail="Meeting not found")
|
||||||
|
|||||||
@@ -9,8 +9,10 @@ from typing import Annotated, Optional
|
|||||||
import httpx
|
import httpx
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
import reflector.auth as auth
|
import reflector.auth as auth
|
||||||
|
from reflector.db import get_session
|
||||||
from reflector.db.transcripts import AudioWaveform, transcripts_controller
|
from reflector.db.transcripts import AudioWaveform, transcripts_controller
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
from reflector.views.transcripts import ALGORITHM
|
from reflector.views.transcripts import ALGORITHM
|
||||||
@@ -48,7 +50,7 @@ async def transcript_get_audio_mp3(
|
|||||||
raise unauthorized_exception
|
raise unauthorized_exception
|
||||||
|
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
transcript_id, user_id=user_id
|
session, transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if transcript.audio_location == "storage":
|
if transcript.audio_location == "storage":
|
||||||
@@ -96,10 +98,11 @@ async def transcript_get_audio_mp3(
|
|||||||
async def transcript_get_audio_waveform(
|
async def transcript_get_audio_waveform(
|
||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
) -> AudioWaveform:
|
) -> AudioWaveform:
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
transcript_id, user_id=user_id
|
session, transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if not transcript.audio_waveform_filename.exists():
|
if not transcript.audio_waveform_filename.exists():
|
||||||
|
|||||||
@@ -8,8 +8,10 @@ from typing import Annotated, Optional
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
import reflector.auth as auth
|
import reflector.auth as auth
|
||||||
|
from reflector.db import get_session
|
||||||
from reflector.db.transcripts import TranscriptParticipant, transcripts_controller
|
from reflector.db.transcripts import TranscriptParticipant, transcripts_controller
|
||||||
from reflector.views.types import DeletionStatus
|
from reflector.views.types import DeletionStatus
|
||||||
|
|
||||||
@@ -37,10 +39,11 @@ class UpdateParticipant(BaseModel):
|
|||||||
async def transcript_get_participants(
|
async def transcript_get_participants(
|
||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
) -> list[Participant]:
|
) -> list[Participant]:
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
transcript_id, user_id=user_id
|
session, transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if transcript.participants is None:
|
if transcript.participants is None:
|
||||||
@@ -57,10 +60,11 @@ async def transcript_add_participant(
|
|||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
participant: CreateParticipant,
|
participant: CreateParticipant,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
) -> Participant:
|
) -> Participant:
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
transcript_id, user_id=user_id
|
session, transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# ensure the speaker is unique
|
# ensure the speaker is unique
|
||||||
@@ -83,10 +87,11 @@ async def transcript_get_participant(
|
|||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
participant_id: str,
|
participant_id: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
) -> Participant:
|
) -> Participant:
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
transcript_id, user_id=user_id
|
session, transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
for p in transcript.participants:
|
for p in transcript.participants:
|
||||||
@@ -102,10 +107,11 @@ async def transcript_update_participant(
|
|||||||
participant_id: str,
|
participant_id: str,
|
||||||
participant: UpdateParticipant,
|
participant: UpdateParticipant,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
) -> Participant:
|
) -> Participant:
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
transcript_id, user_id=user_id
|
session, transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# ensure the speaker is unique
|
# ensure the speaker is unique
|
||||||
@@ -139,10 +145,11 @@ async def transcript_delete_participant(
|
|||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
participant_id: str,
|
participant_id: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
) -> DeletionStatus:
|
) -> DeletionStatus:
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
transcript_id, user_id=user_id
|
session, transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
await transcripts_controller.delete_participant(transcript, participant_id)
|
await transcripts_controller.delete_participant(transcript, participant_id)
|
||||||
return DeletionStatus(status="ok")
|
return DeletionStatus(status="ok")
|
||||||
|
|||||||
@@ -3,8 +3,10 @@ from typing import Annotated, Optional
|
|||||||
import celery
|
import celery
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
import reflector.auth as auth
|
import reflector.auth as auth
|
||||||
|
from reflector.db import get_session
|
||||||
from reflector.db.transcripts import transcripts_controller
|
from reflector.db.transcripts import transcripts_controller
|
||||||
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
|
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
|
||||||
|
|
||||||
@@ -19,10 +21,11 @@ class ProcessStatus(BaseModel):
|
|||||||
async def transcript_process(
|
async def transcript_process(
|
||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
transcript_id, user_id=user_id
|
session, transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if transcript.locked:
|
if transcript.locked:
|
||||||
|
|||||||
@@ -8,8 +8,10 @@ from typing import Annotated, Optional
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
import reflector.auth as auth
|
import reflector.auth as auth
|
||||||
|
from reflector.db import get_session
|
||||||
from reflector.db.transcripts import transcripts_controller
|
from reflector.db.transcripts import transcripts_controller
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@@ -36,10 +38,11 @@ async def transcript_assign_speaker(
|
|||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
assignment: SpeakerAssignment,
|
assignment: SpeakerAssignment,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
) -> SpeakerAssignmentStatus:
|
) -> SpeakerAssignmentStatus:
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
transcript_id, user_id=user_id
|
session, transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if not transcript:
|
if not transcript:
|
||||||
@@ -100,6 +103,7 @@ async def transcript_assign_speaker(
|
|||||||
for topic in changed_topics:
|
for topic in changed_topics:
|
||||||
transcript.upsert_topic(topic)
|
transcript.upsert_topic(topic)
|
||||||
await transcripts_controller.update(
|
await transcripts_controller.update(
|
||||||
|
session,
|
||||||
transcript,
|
transcript,
|
||||||
{
|
{
|
||||||
"topics": transcript.topics_dump(),
|
"topics": transcript.topics_dump(),
|
||||||
@@ -114,10 +118,11 @@ async def transcript_merge_speaker(
|
|||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
merge: SpeakerMerge,
|
merge: SpeakerMerge,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
) -> SpeakerAssignmentStatus:
|
) -> SpeakerAssignmentStatus:
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
transcript_id, user_id=user_id
|
session, transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if not transcript:
|
if not transcript:
|
||||||
@@ -163,6 +168,7 @@ async def transcript_merge_speaker(
|
|||||||
for topic in changed_topics:
|
for topic in changed_topics:
|
||||||
transcript.upsert_topic(topic)
|
transcript.upsert_topic(topic)
|
||||||
await transcripts_controller.update(
|
await transcripts_controller.update(
|
||||||
|
session,
|
||||||
transcript,
|
transcript,
|
||||||
{
|
{
|
||||||
"topics": transcript.topics_dump(),
|
"topics": transcript.topics_dump(),
|
||||||
|
|||||||
@@ -3,8 +3,10 @@ from typing import Annotated, Optional
|
|||||||
import av
|
import av
|
||||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile
|
from fastapi import APIRouter, Depends, HTTPException, UploadFile
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
import reflector.auth as auth
|
import reflector.auth as auth
|
||||||
|
from reflector.db import get_session
|
||||||
from reflector.db.transcripts import transcripts_controller
|
from reflector.db.transcripts import transcripts_controller
|
||||||
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
|
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
|
||||||
|
|
||||||
@@ -22,10 +24,11 @@ async def transcript_record_upload(
|
|||||||
total_chunks: int,
|
total_chunks: int,
|
||||||
chunk: UploadFile,
|
chunk: UploadFile,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
transcript_id, user_id=user_id
|
session, transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if transcript.locked:
|
if transcript.locked:
|
||||||
@@ -89,7 +92,7 @@ async def transcript_record_upload(
|
|||||||
container.close()
|
container.close()
|
||||||
|
|
||||||
# set the status to "uploaded"
|
# set the status to "uploaded"
|
||||||
await transcripts_controller.update(transcript, {"status": "uploaded"})
|
await transcripts_controller.update(session, transcript, {"status": "uploaded"})
|
||||||
|
|
||||||
# launch a background task to process the file
|
# launch a background task to process the file
|
||||||
task_pipeline_file_process.delay(transcript_id=transcript_id)
|
task_pipeline_file_process.delay(transcript_id=transcript_id)
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
from typing import Annotated, Optional
|
from typing import Annotated, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
import reflector.auth as auth
|
import reflector.auth as auth
|
||||||
|
from reflector.db import get_session
|
||||||
from reflector.db.transcripts import transcripts_controller
|
from reflector.db.transcripts import transcripts_controller
|
||||||
|
|
||||||
from .rtc_offer import RtcOffer, rtc_offer_base
|
from .rtc_offer import RtcOffer, rtc_offer_base
|
||||||
@@ -16,10 +18,11 @@ async def transcript_record_webrtc(
|
|||||||
params: RtcOffer,
|
params: RtcOffer,
|
||||||
request: Request,
|
request: Request,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
transcript_id, user_id=user_id
|
session, transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if transcript.locked:
|
if transcript.locked:
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ async def transcript_events_websocket(
|
|||||||
# user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
# user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
):
|
):
|
||||||
# user_id = user["sub"] if user else None
|
# user_id = user["sub"] if user else None
|
||||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
||||||
if not transcript:
|
if not transcript:
|
||||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,4 @@
|
|||||||
import os
|
import os
|
||||||
from tempfile import NamedTemporaryFile
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -34,382 +32,283 @@ def docker_compose_file(pytestconfig):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def postgres_service(docker_ip, docker_services):
|
def docker_ip():
|
||||||
"""Ensure that PostgreSQL service is up and responsive."""
|
"""Get Docker IP address for test services"""
|
||||||
port = docker_services.port_for("postgres_test", 5432)
|
# For most Docker setups, localhost works
|
||||||
|
return "127.0.0.1"
|
||||||
def is_responsive():
|
|
||||||
try:
|
|
||||||
import psycopg2
|
|
||||||
|
|
||||||
conn = psycopg2.connect(
|
|
||||||
host=docker_ip,
|
|
||||||
port=port,
|
|
||||||
dbname="reflector_test",
|
|
||||||
user="test_user",
|
|
||||||
password="test_password",
|
|
||||||
)
|
|
||||||
conn.close()
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
docker_services.wait_until_responsive(timeout=30.0, pause=0.1, check=is_responsive)
|
|
||||||
|
|
||||||
# Return connection parameters
|
|
||||||
return {
|
|
||||||
"host": docker_ip,
|
|
||||||
"port": port,
|
|
||||||
"dbname": "reflector_test",
|
|
||||||
"user": "test_user",
|
|
||||||
"password": "test_password",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function", autouse=True)
|
# Only register docker_services dependent fixtures if docker plugin is available
|
||||||
@pytest.mark.asyncio
|
try:
|
||||||
|
import pytest_docker # noqa: F401
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def postgres_service(docker_ip, docker_services):
|
||||||
|
"""Ensure that PostgreSQL service is up and responsive."""
|
||||||
|
port = docker_services.port_for("postgres_test", 5432)
|
||||||
|
|
||||||
|
def is_responsive():
|
||||||
|
try:
|
||||||
|
import psycopg2
|
||||||
|
|
||||||
|
conn = psycopg2.connect(
|
||||||
|
host=docker_ip,
|
||||||
|
port=port,
|
||||||
|
dbname="reflector_test",
|
||||||
|
user="test_user",
|
||||||
|
password="test_password",
|
||||||
|
)
|
||||||
|
conn.close()
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
docker_services.wait_until_responsive(
|
||||||
|
timeout=30.0, pause=0.1, check=is_responsive
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return connection parameters
|
||||||
|
return {
|
||||||
|
"host": docker_ip,
|
||||||
|
"port": port,
|
||||||
|
"database": "reflector_test",
|
||||||
|
"user": "test_user",
|
||||||
|
"password": "test_password",
|
||||||
|
}
|
||||||
|
except ImportError:
|
||||||
|
# Docker plugin not available, provide a dummy fixture
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def postgres_service(docker_ip):
|
||||||
|
"""Dummy postgres service when docker plugin is not available"""
|
||||||
|
return {
|
||||||
|
"host": docker_ip,
|
||||||
|
"port": 15432, # Default test postgres port
|
||||||
|
"database": "reflector_test",
|
||||||
|
"user": "test_user",
|
||||||
|
"password": "test_password",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
async def setup_database(postgres_service):
|
async def setup_database(postgres_service):
|
||||||
from reflector.db import get_engine
|
"""Setup database and run migrations"""
|
||||||
from reflector.db.base import metadata
|
from sqlalchemy.ext.asyncio import create_async_engine
|
||||||
|
|
||||||
async_engine = get_engine()
|
from reflector.db import Base
|
||||||
|
|
||||||
async with async_engine.begin() as conn:
|
# Build database URL from connection params
|
||||||
await conn.run_sync(metadata.drop_all)
|
db_config = postgres_service
|
||||||
await conn.run_sync(metadata.create_all)
|
DATABASE_URL = (
|
||||||
|
f"postgresql+asyncpg://{db_config['user']}:{db_config['password']}"
|
||||||
|
f"@{db_config['host']}:{db_config['port']}/{db_config['database']}"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
# Override settings
|
||||||
yield
|
from reflector.settings import settings
|
||||||
finally:
|
|
||||||
await async_engine.dispose()
|
settings.DATABASE_URL = DATABASE_URL
|
||||||
|
|
||||||
|
# Create engine and tables
|
||||||
|
engine = create_async_engine(DATABASE_URL, echo=False)
|
||||||
|
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
# Drop all tables first to ensure clean state
|
||||||
|
await conn.run_sync(Base.metadata.drop_all)
|
||||||
|
# Create all tables
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def session():
|
async def session(setup_database):
|
||||||
|
"""Provide a transactional database session for tests"""
|
||||||
from reflector.db import get_session_factory
|
from reflector.db import get_session_factory
|
||||||
|
|
||||||
async with get_session_factory()() as session:
|
async with get_session_factory()() as session:
|
||||||
yield session
|
yield session
|
||||||
|
await session.rollback()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def dummy_processors():
|
def fake_mp3_upload(tmp_path):
|
||||||
with (
|
"""Create a temporary MP3 file for upload testing"""
|
||||||
patch(
|
mp3_file = tmp_path / "test.mp3"
|
||||||
"reflector.processors.transcript_topic_detector.TranscriptTopicDetectorProcessor.get_topic"
|
# Create a minimal valid MP3 file (ID3v2 header + minimal frame)
|
||||||
) as mock_topic,
|
mp3_data = b"ID3\x04\x00\x00\x00\x00\x00\x00" + b"\xff\xfb" + b"\x00" * 100
|
||||||
patch(
|
mp3_file.write_bytes(mp3_data)
|
||||||
"reflector.processors.transcript_final_title.TranscriptFinalTitleProcessor.get_title"
|
return mp3_file
|
||||||
) as mock_title,
|
|
||||||
patch(
|
|
||||||
"reflector.processors.transcript_final_summary.TranscriptFinalSummaryProcessor.get_long_summary"
|
|
||||||
) as mock_long_summary,
|
|
||||||
patch(
|
|
||||||
"reflector.processors.transcript_final_summary.TranscriptFinalSummaryProcessor.get_short_summary"
|
|
||||||
) as mock_short_summary,
|
|
||||||
):
|
|
||||||
from reflector.processors.transcript_topic_detector import TopicResponse
|
|
||||||
|
|
||||||
mock_topic.return_value = TopicResponse(
|
|
||||||
title="LLM TITLE", summary="LLM SUMMARY"
|
|
||||||
)
|
|
||||||
mock_title.return_value = "LLM Title"
|
|
||||||
mock_long_summary.return_value = "LLM LONG SUMMARY"
|
|
||||||
mock_short_summary.return_value = "LLM SHORT SUMMARY"
|
|
||||||
yield (
|
|
||||||
mock_topic,
|
|
||||||
mock_title,
|
|
||||||
mock_long_summary,
|
|
||||||
mock_short_summary,
|
|
||||||
) # noqa
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def whisper_transcript():
|
def dummy_transcript():
|
||||||
from reflector.processors.audio_transcript_whisper import (
|
"""Mock transcript processor response"""
|
||||||
AudioTranscriptWhisperProcessor,
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"reflector.processors.audio_transcript_auto"
|
|
||||||
".AudioTranscriptAutoProcessor.__new__"
|
|
||||||
) as mock_audio:
|
|
||||||
mock_audio.return_value = AudioTranscriptWhisperProcessor()
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def dummy_transcript():
|
|
||||||
from reflector.processors.audio_transcript import AudioTranscriptProcessor
|
|
||||||
from reflector.processors.types import AudioFile, Transcript, Word
|
|
||||||
|
|
||||||
class TestAudioTranscriptProcessor(AudioTranscriptProcessor):
|
|
||||||
_time_idx = 0
|
|
||||||
|
|
||||||
async def _transcript(self, data: AudioFile):
|
|
||||||
i = self._time_idx
|
|
||||||
self._time_idx += 2
|
|
||||||
return Transcript(
|
|
||||||
text="Hello world.",
|
|
||||||
words=[
|
|
||||||
Word(start=i, end=i + 1, text="Hello", speaker=0),
|
|
||||||
Word(start=i + 1, end=i + 2, text=" world.", speaker=0),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"reflector.processors.audio_transcript_auto"
|
|
||||||
".AudioTranscriptAutoProcessor.__new__"
|
|
||||||
) as mock_audio:
|
|
||||||
mock_audio.return_value = TestAudioTranscriptProcessor()
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def dummy_diarization():
|
|
||||||
from reflector.processors.audio_diarization import AudioDiarizationProcessor
|
|
||||||
|
|
||||||
class TestAudioDiarizationProcessor(AudioDiarizationProcessor):
|
|
||||||
_time_idx = 0
|
|
||||||
|
|
||||||
async def _diarize(self, data):
|
|
||||||
i = self._time_idx
|
|
||||||
self._time_idx += 2
|
|
||||||
return [
|
|
||||||
{"start": i, "end": i + 1, "speaker": 0},
|
|
||||||
{"start": i + 1, "end": i + 2, "speaker": 1},
|
|
||||||
]
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"reflector.processors.audio_diarization_auto"
|
|
||||||
".AudioDiarizationAutoProcessor.__new__"
|
|
||||||
) as mock_audio:
|
|
||||||
mock_audio.return_value = TestAudioDiarizationProcessor()
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def dummy_file_transcript():
|
|
||||||
from reflector.processors.file_transcript import FileTranscriptProcessor
|
|
||||||
from reflector.processors.types import Transcript, Word
|
from reflector.processors.types import Transcript, Word
|
||||||
|
|
||||||
class TestFileTranscriptProcessor(FileTranscriptProcessor):
|
return Transcript(
|
||||||
async def _transcript(self, data):
|
text="Hello world this is a test",
|
||||||
return Transcript(
|
words=[
|
||||||
text="Hello world. How are you today?",
|
Word(word="Hello", start=0.0, end=0.5, speaker=0),
|
||||||
words=[
|
Word(word="world", start=0.5, end=1.0, speaker=0),
|
||||||
Word(start=0.0, end=0.5, text="Hello", speaker=0),
|
Word(word="this", start=1.0, end=1.5, speaker=0),
|
||||||
Word(start=0.5, end=0.6, text=" ", speaker=0),
|
Word(word="is", start=1.5, end=1.8, speaker=0),
|
||||||
Word(start=0.6, end=1.0, text="world", speaker=0),
|
Word(word="a", start=1.8, end=2.0, speaker=0),
|
||||||
Word(start=1.0, end=1.1, text=".", speaker=0),
|
Word(word="test", start=2.0, end=2.5, speaker=0),
|
||||||
Word(start=1.1, end=1.2, text=" ", speaker=0),
|
],
|
||||||
Word(start=1.2, end=1.5, text="How", speaker=0),
|
|
||||||
Word(start=1.5, end=1.6, text=" ", speaker=0),
|
|
||||||
Word(start=1.6, end=1.8, text="are", speaker=0),
|
|
||||||
Word(start=1.8, end=1.9, text=" ", speaker=0),
|
|
||||||
Word(start=1.9, end=2.1, text="you", speaker=0),
|
|
||||||
Word(start=2.1, end=2.2, text=" ", speaker=0),
|
|
||||||
Word(start=2.2, end=2.5, text="today", speaker=0),
|
|
||||||
Word(start=2.5, end=2.6, text="?", speaker=0),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"reflector.processors.file_transcript_auto.FileTranscriptAutoProcessor.__new__"
|
|
||||||
) as mock_auto:
|
|
||||||
mock_auto.return_value = TestFileTranscriptProcessor()
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def dummy_file_diarization():
|
|
||||||
from reflector.processors.file_diarization import (
|
|
||||||
FileDiarizationOutput,
|
|
||||||
FileDiarizationProcessor,
|
|
||||||
)
|
)
|
||||||
from reflector.processors.types import DiarizationSegment
|
|
||||||
|
|
||||||
class TestFileDiarizationProcessor(FileDiarizationProcessor):
|
|
||||||
async def _diarize(self, data):
|
|
||||||
return FileDiarizationOutput(
|
|
||||||
diarization=[
|
|
||||||
DiarizationSegment(start=0.0, end=1.1, speaker=0),
|
|
||||||
DiarizationSegment(start=1.2, end=2.6, speaker=1),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"reflector.processors.file_diarization_auto.FileDiarizationAutoProcessor.__new__"
|
|
||||||
) as mock_auto:
|
|
||||||
mock_auto.return_value = TestFileDiarizationProcessor()
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def dummy_transcript_translator():
|
def dummy_transcript_translator():
|
||||||
from reflector.processors.transcript_translator import TranscriptTranslatorProcessor
|
"""Mock transcript translation"""
|
||||||
|
return "Hola mundo esto es una prueba"
|
||||||
class TestTranscriptTranslatorProcessor(TranscriptTranslatorProcessor):
|
|
||||||
async def _translate(self, text: str) -> str:
|
|
||||||
source_language = self.get_pref("audio:source_language", "en")
|
|
||||||
target_language = self.get_pref("audio:target_language", "en")
|
|
||||||
return f"{source_language}:{target_language}:{text}"
|
|
||||||
|
|
||||||
def mock_new(cls, *args, **kwargs):
|
|
||||||
return TestTranscriptTranslatorProcessor(*args, **kwargs)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"reflector.processors.transcript_translator_auto"
|
|
||||||
".TranscriptTranslatorAutoProcessor.__new__",
|
|
||||||
mock_new,
|
|
||||||
):
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def dummy_llm():
|
def dummy_diarization():
|
||||||
from reflector.llm import LLM
|
"""Mock diarization processor response"""
|
||||||
|
from reflector.processors.types import DiarizationOutput, DiarizationSegment
|
||||||
|
|
||||||
class TestLLM(LLM):
|
return DiarizationOutput(
|
||||||
def __init__(self):
|
diarization=[
|
||||||
self.model_name = "DUMMY MODEL"
|
DiarizationSegment(speaker=0, start=0.0, end=1.0),
|
||||||
self.llm_tokenizer = "DUMMY TOKENIZER"
|
DiarizationSegment(speaker=1, start=1.0, end=2.5),
|
||||||
|
]
|
||||||
# LLM doesn't have get_instance anymore, mocking constructor instead
|
)
|
||||||
with patch("reflector.llm.LLM") as mock_llm:
|
|
||||||
mock_llm.return_value = TestLLM()
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def dummy_storage():
|
def dummy_file_transcript():
|
||||||
from reflector.storage.base import Storage
|
"""Mock file transcript processor response"""
|
||||||
|
from reflector.processors.types import Transcript, Word
|
||||||
|
|
||||||
class DummyStorage(Storage):
|
return Transcript(
|
||||||
async def _put_file(self, *args, **kwargs):
|
text="This is a complete file transcript with multiple speakers",
|
||||||
pass
|
words=[
|
||||||
|
Word(word="This", start=0.0, end=0.5, speaker=0),
|
||||||
async def _delete_file(self, *args, **kwargs):
|
Word(word="is", start=0.5, end=0.8, speaker=0),
|
||||||
pass
|
Word(word="a", start=0.8, end=1.0, speaker=0),
|
||||||
|
Word(word="complete", start=1.0, end=1.5, speaker=1),
|
||||||
async def _get_file_url(self, *args, **kwargs):
|
Word(word="file", start=1.5, end=1.8, speaker=1),
|
||||||
return "http://fake_server/audio.mp3"
|
Word(word="transcript", start=1.8, end=2.3, speaker=1),
|
||||||
|
Word(word="with", start=2.3, end=2.5, speaker=0),
|
||||||
async def _get_file(self, *args, **kwargs):
|
Word(word="multiple", start=2.5, end=3.0, speaker=0),
|
||||||
from pathlib import Path
|
Word(word="speakers", start=3.0, end=3.5, speaker=0),
|
||||||
|
],
|
||||||
test_mp3 = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
|
)
|
||||||
return test_mp3.read_bytes()
|
|
||||||
|
|
||||||
dummy = DummyStorage()
|
|
||||||
with (
|
|
||||||
patch("reflector.storage.base.Storage.get_instance") as mock_storage,
|
|
||||||
patch("reflector.storage.get_transcripts_storage") as mock_get_transcripts,
|
|
||||||
patch(
|
|
||||||
"reflector.pipelines.main_file_pipeline.get_transcripts_storage"
|
|
||||||
) as mock_get_transcripts2,
|
|
||||||
):
|
|
||||||
mock_storage.return_value = dummy
|
|
||||||
mock_get_transcripts.return_value = dummy
|
|
||||||
mock_get_transcripts2.return_value = dummy
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def celery_enable_logging():
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def celery_config():
|
|
||||||
with NamedTemporaryFile() as f:
|
|
||||||
yield {
|
|
||||||
"broker_url": "memory://",
|
|
||||||
"result_backend": f"db+sqlite:///{f.name}",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def celery_includes():
|
|
||||||
return [
|
|
||||||
"reflector.pipelines.main_live_pipeline",
|
|
||||||
"reflector.pipelines.main_file_pipeline",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def client():
|
def dummy_file_diarization():
|
||||||
from httpx import AsyncClient
|
"""Mock file diarization processor response"""
|
||||||
|
from reflector.processors.types import DiarizationOutput, DiarizationSegment
|
||||||
|
|
||||||
from reflector.app import app
|
return DiarizationOutput(
|
||||||
|
diarization=[
|
||||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
DiarizationSegment(speaker=0, start=0.0, end=1.0),
|
||||||
yield ac
|
DiarizationSegment(speaker=1, start=1.0, end=2.3),
|
||||||
|
DiarizationSegment(speaker=0, start=2.3, end=3.5),
|
||||||
|
]
|
||||||
@pytest.fixture(scope="session")
|
)
|
||||||
def fake_mp3_upload():
|
|
||||||
with patch(
|
|
||||||
"reflector.db.transcripts.TranscriptController.move_mp3_to_storage"
|
|
||||||
) as mock_move:
|
|
||||||
mock_move.return_value = True
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def fake_transcript_with_topics(tmpdir, client):
|
def fake_transcript_with_topics():
|
||||||
import shutil
|
"""Create a transcript with topics for testing"""
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from reflector.db.transcripts import TranscriptTopic
|
from reflector.db.transcripts import TranscriptTopic
|
||||||
from reflector.processors.types import Word
|
from reflector.processors.types import Word
|
||||||
from reflector.settings import settings
|
|
||||||
from reflector.views.transcripts import transcripts_controller
|
|
||||||
|
|
||||||
settings.DATA_DIR = Path(tmpdir)
|
topics = [
|
||||||
|
|
||||||
# create a transcript
|
|
||||||
response = await client.post("/transcripts", json={"name": "Test audio download"})
|
|
||||||
assert response.status_code == 200
|
|
||||||
tid = response.json()["id"]
|
|
||||||
|
|
||||||
transcript = await transcripts_controller.get_by_id(tid)
|
|
||||||
assert transcript is not None
|
|
||||||
|
|
||||||
await transcripts_controller.update(transcript, {"status": "ended"})
|
|
||||||
|
|
||||||
# manually copy a file at the expected location
|
|
||||||
audio_filename = transcript.audio_mp3_filename
|
|
||||||
path = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
|
|
||||||
audio_filename.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
shutil.copy(path, audio_filename)
|
|
||||||
|
|
||||||
# create some topics
|
|
||||||
await transcripts_controller.upsert_topic(
|
|
||||||
transcript,
|
|
||||||
TranscriptTopic(
|
TranscriptTopic(
|
||||||
title="Topic 1",
|
id="topic1",
|
||||||
summary="Topic 1 summary",
|
title="Introduction",
|
||||||
timestamp=0,
|
summary="Opening remarks and introductions",
|
||||||
transcript="Hello world",
|
timestamp=0.0,
|
||||||
|
duration=30.0,
|
||||||
words=[
|
words=[
|
||||||
Word(text="Hello", start=0, end=1, speaker=0),
|
Word(word="Hello", start=0.0, end=0.5, speaker=0),
|
||||||
Word(text="world", start=1, end=2, speaker=0),
|
Word(word="everyone", start=0.5, end=1.0, speaker=0),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
)
|
|
||||||
await transcripts_controller.upsert_topic(
|
|
||||||
transcript,
|
|
||||||
TranscriptTopic(
|
TranscriptTopic(
|
||||||
title="Topic 2",
|
id="topic2",
|
||||||
summary="Topic 2 summary",
|
title="Main Discussion",
|
||||||
timestamp=2,
|
summary="Core topics and key points",
|
||||||
transcript="Hello world",
|
timestamp=30.0,
|
||||||
|
duration=60.0,
|
||||||
words=[
|
words=[
|
||||||
Word(text="Hello", start=2, end=3, speaker=0),
|
Word(word="Let's", start=30.0, end=30.3, speaker=1),
|
||||||
Word(text="world", start=3, end=4, speaker=0),
|
Word(word="discuss", start=30.3, end=30.8, speaker=1),
|
||||||
|
Word(word="the", start=30.8, end=31.0, speaker=1),
|
||||||
|
Word(word="agenda", start=31.0, end=31.5, speaker=1),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
)
|
]
|
||||||
|
return topics
|
||||||
|
|
||||||
yield transcript
|
|
||||||
|
@pytest.fixture
|
||||||
|
def dummy_processors(
|
||||||
|
dummy_transcript,
|
||||||
|
dummy_transcript_translator,
|
||||||
|
dummy_diarization,
|
||||||
|
dummy_file_transcript,
|
||||||
|
dummy_file_diarization,
|
||||||
|
):
|
||||||
|
"""Mock all processor responses"""
|
||||||
|
return {
|
||||||
|
"transcript": dummy_transcript,
|
||||||
|
"translator": dummy_transcript_translator,
|
||||||
|
"diarization": dummy_diarization,
|
||||||
|
"file_transcript": dummy_file_transcript,
|
||||||
|
"file_diarization": dummy_file_diarization,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def dummy_storage():
|
||||||
|
"""Mock storage backend"""
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
storage = AsyncMock()
|
||||||
|
storage.get_file_url.return_value = "https://example.com/test-audio.mp3"
|
||||||
|
storage.put_file.return_value = None
|
||||||
|
storage.delete_file.return_value = None
|
||||||
|
return storage
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def dummy_llm():
|
||||||
|
"""Mock LLM responses"""
|
||||||
|
return {
|
||||||
|
"title": "Test Meeting Title",
|
||||||
|
"summary": "This is a test meeting summary with key discussion points.",
|
||||||
|
"short_summary": "Brief test summary.",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def whisper_transcript():
|
||||||
|
"""Mock Whisper API response format"""
|
||||||
|
return {
|
||||||
|
"text": "Hello world this is a test",
|
||||||
|
"segments": [
|
||||||
|
{
|
||||||
|
"start": 0.0,
|
||||||
|
"end": 2.5,
|
||||||
|
"text": "Hello world this is a test",
|
||||||
|
"words": [
|
||||||
|
{"word": "Hello", "start": 0.0, "end": 0.5},
|
||||||
|
{"word": "world", "start": 0.5, "end": 1.0},
|
||||||
|
{"word": "this", "start": 1.0, "end": 1.5},
|
||||||
|
{"word": "is", "start": 1.5, "end": 1.8},
|
||||||
|
{"word": "a", "start": 1.8, "end": 2.0},
|
||||||
|
{"word": "test", "start": 2.0, "end": 2.5},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"language": "en",
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from unittest.mock import AsyncMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from reflector.db import get_session_factory
|
||||||
from reflector.db.rooms import rooms_controller
|
from reflector.db.rooms import rooms_controller
|
||||||
from reflector.services.ics_sync import ICSSyncService
|
from reflector.services.ics_sync import ICSSyncService
|
||||||
|
|
||||||
@@ -17,21 +18,22 @@ async def test_attendee_parsing_bug():
|
|||||||
instead of properly parsed email addresses.
|
instead of properly parsed email addresses.
|
||||||
"""
|
"""
|
||||||
# Create a test room
|
# Create a test room
|
||||||
room = await rooms_controller.add(
|
async with get_session_factory()() as session:
|
||||||
session,
|
room = await rooms_controller.add(
|
||||||
name="test-room",
|
session,
|
||||||
user_id="test-user",
|
name="test-room",
|
||||||
zulip_auto_post=False,
|
user_id="test-user",
|
||||||
zulip_stream="",
|
zulip_auto_post=False,
|
||||||
zulip_topic="",
|
zulip_stream="",
|
||||||
is_locked=False,
|
zulip_topic="",
|
||||||
room_mode="normal",
|
is_locked=False,
|
||||||
recording_type="cloud",
|
room_mode="normal",
|
||||||
recording_trigger="automatic-2nd-participant",
|
recording_type="cloud",
|
||||||
is_shared=False,
|
recording_trigger="automatic-2nd-participant",
|
||||||
ics_url="http://test.com/test.ics",
|
is_shared=False,
|
||||||
ics_enabled=True,
|
ics_url="http://test.com/test.ics",
|
||||||
)
|
ics_enabled=True,
|
||||||
|
)
|
||||||
|
|
||||||
# Read the test ICS file that reproduces the bug and update it with current time
|
# Read the test ICS file that reproduces the bug and update it with current time
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
@@ -95,99 +97,14 @@ async def test_attendee_parsing_bug():
|
|||||||
# This is where the bug manifests - check the attendees
|
# This is where the bug manifests - check the attendees
|
||||||
attendees = event["attendees"]
|
attendees = event["attendees"]
|
||||||
|
|
||||||
# Print attendee info for debugging
|
# Debug output to see what's happening
|
||||||
print(f"Number of attendees found: {len(attendees)}")
|
print(f"Number of attendees: {len(attendees)}")
|
||||||
for i, attendee in enumerate(attendees):
|
for i, attendee in enumerate(attendees):
|
||||||
print(
|
print(f"Attendee {i}: {attendee}")
|
||||||
f"Attendee {i}: email='{attendee.get('email')}', name='{attendee.get('name')}'"
|
|
||||||
)
|
|
||||||
|
|
||||||
# With the fix, we should now get properly parsed email addresses
|
# The bug would cause 29 attendees (length of "MAILIN01234567890@allo.coop")
|
||||||
# Check that no single characters are parsed as emails
|
# instead of 1 attendee
|
||||||
single_char_emails = [
|
assert len(attendees) == 1, f"Expected 1 attendee, got {len(attendees)}"
|
||||||
att for att in attendees if att.get("email") and len(att["email"]) == 1
|
|
||||||
]
|
|
||||||
|
|
||||||
if single_char_emails:
|
# Verify the single attendee has correct email
|
||||||
print(
|
assert attendees[0]["email"] == "MAILIN01234567890@allo.coop"
|
||||||
f"BUG DETECTED: Found {len(single_char_emails)} single-character emails:"
|
|
||||||
)
|
|
||||||
for att in single_char_emails:
|
|
||||||
print(f" - '{att['email']}'")
|
|
||||||
|
|
||||||
# Should have attendees but not single-character emails
|
|
||||||
assert len(attendees) > 0
|
|
||||||
assert (
|
|
||||||
len(single_char_emails) == 0
|
|
||||||
), f"Found {len(single_char_emails)} single-character emails, parsing is still buggy"
|
|
||||||
|
|
||||||
# Check that all emails are valid (contain @ symbol)
|
|
||||||
valid_emails = [
|
|
||||||
att for att in attendees if att.get("email") and "@" in att["email"]
|
|
||||||
]
|
|
||||||
assert len(valid_emails) == len(
|
|
||||||
attendees
|
|
||||||
), "Some attendees don't have valid email addresses"
|
|
||||||
|
|
||||||
# We expect around 29 attendees (28 from the comma-separated list + 1 organizer)
|
|
||||||
assert (
|
|
||||||
len(attendees) >= 25
|
|
||||||
), f"Expected around 29 attendees, got {len(attendees)}"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_correct_attendee_parsing():
|
|
||||||
"""
|
|
||||||
Test what correct attendee parsing should look like.
|
|
||||||
"""
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
|
|
||||||
from icalendar import Event
|
|
||||||
|
|
||||||
from reflector.services.ics_sync import ICSFetchService
|
|
||||||
|
|
||||||
service = ICSFetchService()
|
|
||||||
|
|
||||||
# Create a properly formatted event with multiple attendees
|
|
||||||
event = Event()
|
|
||||||
event.add("uid", "test-correct-attendees")
|
|
||||||
event.add("summary", "Test Meeting")
|
|
||||||
event.add("location", "http://test.com/test")
|
|
||||||
event.add("dtstart", datetime.now(timezone.utc))
|
|
||||||
event.add("dtend", datetime.now(timezone.utc))
|
|
||||||
|
|
||||||
# Add attendees the correct way (separate ATTENDEE lines)
|
|
||||||
event.add("attendee", "mailto:alice@example.com", parameters={"CN": "Alice"})
|
|
||||||
event.add("attendee", "mailto:bob@example.com", parameters={"CN": "Bob"})
|
|
||||||
event.add("attendee", "mailto:charlie@example.com", parameters={"CN": "Charlie"})
|
|
||||||
event.add(
|
|
||||||
"organizer", "mailto:organizer@example.com", parameters={"CN": "Organizer"}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Parse the event
|
|
||||||
result = service._parse_event(event)
|
|
||||||
|
|
||||||
assert result is not None
|
|
||||||
attendees = result["attendees"]
|
|
||||||
|
|
||||||
# Should have 4 attendees (3 attendees + 1 organizer)
|
|
||||||
assert len(attendees) == 4
|
|
||||||
|
|
||||||
# Check that all emails are valid email addresses
|
|
||||||
emails = [att["email"] for att in attendees if att.get("email")]
|
|
||||||
expected_emails = [
|
|
||||||
"alice@example.com",
|
|
||||||
"bob@example.com",
|
|
||||||
"charlie@example.com",
|
|
||||||
"organizer@example.com",
|
|
||||||
]
|
|
||||||
|
|
||||||
for email in emails:
|
|
||||||
assert "@" in email, f"Invalid email format: {email}"
|
|
||||||
assert len(email) > 5, f"Email too short: {email}"
|
|
||||||
|
|
||||||
# Check that we have the expected emails
|
|
||||||
assert "alice@example.com" in emails
|
|
||||||
assert "bob@example.com" in emails
|
|
||||||
assert "charlie@example.com" in emails
|
|
||||||
assert "organizer@example.com" in emails
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from reflector.db.transcripts import (
|
|||||||
SourceKind,
|
SourceKind,
|
||||||
TranscriptController,
|
TranscriptController,
|
||||||
TranscriptTopic,
|
TranscriptTopic,
|
||||||
|
transcripts_controller,
|
||||||
)
|
)
|
||||||
from reflector.processors.types import Word
|
from reflector.processors.types import Word
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user