Files
reflector/server/reflector/db/transcripts.py
Mathieu Virbel e0c71c5548 refactor: migrate to SQLAlchemy 2.0 ORM-style patterns
- Replace __table__.join() with ORM-style joins using select_from().outerjoin()
- Replace __table__.delete() with delete(Model) in tests
- Migrate from **row.__dict__ to model_validate() with ConfigDict(from_attributes=True)
- Add ConfigDict(from_attributes=True) to all Pydantic models for proper SQLAlchemy model conversion
- Update all controller methods to use model_validate() instead of dict unpacking

This completes the migration to SQLAlchemy 2.0 recommended patterns while maintaining
backwards compatibility and improving code consistency.
2025-09-23 16:46:37 -06:00

719 lines
23 KiB
Python

import enum
import json
import os
import shutil
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Literal
from fastapi import HTTPException
from pydantic import BaseModel, ConfigDict, Field, field_serializer
from sqlalchemy import delete, insert, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import or_
from reflector.db.base import RoomModel, TranscriptModel
from reflector.db.recordings import recordings_controller
from reflector.logger import logger
from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings
from reflector.storage import get_recordings_storage, get_transcripts_storage
from reflector.utils import generate_uuid4
from reflector.utils.webvtt import topics_to_webvtt
class SourceKind(enum.StrEnum):
ROOM = enum.auto()
LIVE = enum.auto()
FILE = enum.auto()
def generate_transcript_name() -> str:
    """Return a default transcript name stamped with the current UTC time.

    The result has the form ``"Transcript YYYY-MM-DD HH:MM:SS"``.
    """
    stamp = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
    return "Transcript " + stamp
# Closed set of values accepted for a transcript's `status` field.
# New transcripts start as "idle" (the default on Transcript.status).
TranscriptStatus = Literal[
    "idle", "uploaded", "recording", "processing", "error", "ended"
]
# Single-string payload model. TranscriptController.set_status() wraps the
# new status in it to build the data of a "STATUS" event.
class StrValue(BaseModel):
    value: str
# Waveform samples as returned by the Transcript.audio_waveform property,
# which fills `data` with the content loaded from the waveform JSON file.
class AudioWaveform(BaseModel):
    data: list[float]
# A piece of transcribed text together with its translation.
# `translation` has no default: it must be provided, but may be None.
class TranscriptText(BaseModel):
    text: str
    translation: str | None
# One text segment attributed to a numbered speaker.
# NOTE(review): the unit of `timestamp` (presumably seconds) is not visible
# in this file -- confirm against the producer of these segments.
class TranscriptSegmentTopic(BaseModel):
    speaker: int
    text: str
    timestamp: float
# A topic of a transcript, stored in Transcript.topics.
# - `id` is generated with generate_uuid4 when not supplied; it is the key
#   Transcript.upsert_topic uses to decide between replace and append.
# - `words` holds the processor-level Word entries; Transcript.find_empty_speaker
#   reads their `speaker` attribute.
# TranscriptController._handle_topics_update rebuilds these objects from plain
# dicts in order to regenerate the WebVTT text.
class TranscriptTopic(BaseModel):
    id: str = Field(default_factory=generate_uuid4)
    title: str
    summary: str
    timestamp: float
    duration: float | None = 0
    transcript: str | None = None
    words: list[ProcessorWord] = []
# Single-field model carrying the short summary of a transcript.
# NOTE(review): presumably used as an event payload; its callers are outside
# this file -- confirm.
class TranscriptFinalShortSummary(BaseModel):
    short_summary: str
# Single-field model carrying the long summary of a transcript.
# NOTE(review): presumably used as an event payload; its callers are outside
# this file -- confirm.
class TranscriptFinalLongSummary(BaseModel):
    long_summary: str
# Single-field model carrying the title of a transcript.
# NOTE(review): presumably used as an event payload; its callers are outside
# this file -- confirm.
class TranscriptFinalTitle(BaseModel):
    title: str
# Single-field model carrying a duration value.
# NOTE(review): unit (presumably seconds) and callers are not visible in this
# file -- confirm.
class TranscriptDuration(BaseModel):
    duration: float
# Single-field model carrying waveform samples as a list of floats.
# NOTE(review): presumably used as an event payload; its callers are outside
# this file -- confirm.
class TranscriptWaveform(BaseModel):
    waveform: list[float]
# An entry of Transcript.events: the event name plus its payload as a plain
# dict. Transcript.add_event builds it from a model's model_dump().
class TranscriptEvent(BaseModel):
    event: str
    data: dict
# A named participant of a transcript, stored in Transcript.participants.
# - `id` is generated with generate_uuid4 when not supplied; it is the key
#   used by Transcript.upsert_participant / delete_participant.
# - `speaker` has no default: it must be provided, but may be None.
class TranscriptParticipant(BaseModel):
    # Allows building the model from attribute-bearing objects, not only dicts.
    model_config = ConfigDict(from_attributes=True)
    id: str = Field(default_factory=generate_uuid4)
    speaker: int | None
    name: str
class Transcript(BaseModel):
    """Full transcript model with all fields."""

    # Lets the controller build instances directly from ORM rows with
    # Transcript.model_validate(row).
    model_config = ConfigDict(from_attributes=True)

    id: str = Field(default_factory=generate_uuid4)
    # None means the transcript is anonymous (see get_by_id_for_http).
    user_id: str | None = None
    name: str = Field(default_factory=generate_transcript_name)
    status: TranscriptStatus = "idle"
    duration: float = 0
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    title: str | None = None
    source_kind: SourceKind
    room_id: str | None = None
    locked: bool = False
    short_summary: str | None = None
    long_summary: str | None = None
    topics: list[TranscriptTopic] = []
    events: list[TranscriptEvent] = []
    # May be None; the participant helpers below treat None like "no participants".
    participants: list[TranscriptParticipant] | None = []
    source_language: str = "en"
    target_language: str = "en"
    share_mode: Literal["private", "semi-private", "public"] = "private"
    # "local" or "storage"; any other value makes get_audio_url raise.
    audio_location: str = "local"
    reviewed: bool = False
    meeting_id: str | None = None
    recording_id: str | None = None
    zulip_message_id: int | None = None
    audio_deleted: bool | None = None
    webvtt: str | None = None

    @field_serializer("created_at", when_used="json")
    def serialize_datetime(self, dt: datetime) -> str:
        """Serialize `created_at` as ISO 8601, treating naive values as UTC."""
        if dt.tzinfo is None:
            dt = dt.replace(tzinfo=timezone.utc)
        return dt.isoformat()

    def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
        """Append an event built from `data.model_dump()` and return it."""
        ev = TranscriptEvent(event=event, data=data.model_dump())
        self.events.append(ev)
        return ev

    def upsert_topic(self, topic: TranscriptTopic):
        """Replace the topic with the same id, or append it if there is none."""
        index = next((i for i, t in enumerate(self.topics) if t.id == topic.id), None)
        if index is not None:
            self.topics[index] = topic
        else:
            self.topics.append(topic)

    def upsert_participant(self, participant: TranscriptParticipant):
        """Replace the participant with the same id, or add it; return it."""
        if not self.participants:
            # None or empty: start a fresh list.
            self.participants = [participant]
            return participant

        index = next(
            (i for i, p in enumerate(self.participants) if p.id == participant.id),
            None,
        )
        if index is not None:
            self.participants[index] = participant
        else:
            self.participants.append(participant)
        return participant

    def delete_participant(self, participant_id: str):
        """Remove the participant with the given id; no-op if it is absent."""
        if not self.participants:
            # `participants` may be None (or empty): nothing to delete.
            return
        index = next(
            (i for i, p in enumerate(self.participants) if p.id == participant_id),
            None,
        )
        if index is not None:
            del self.participants[index]

    def events_dump(self, mode="json"):
        """Return all events as a list of dumped dicts."""
        return [event.model_dump(mode=mode) for event in self.events]

    def topics_dump(self, mode="json"):
        """Return all topics as a list of dumped dicts."""
        return [topic.model_dump(mode=mode) for topic in self.topics]

    def participants_dump(self, mode="json"):
        """Return all participants as a list of dumped dicts.

        A `participants` value of None is dumped as an empty list.
        """
        return [
            participant.model_dump(mode=mode)
            for participant in (self.participants or [])
        ]

    def unlink(self):
        """Delete the transcript's local data directory, if it exists."""
        if self.data_path.is_dir():
            shutil.rmtree(self.data_path)

    @property
    def data_path(self):
        """Local directory of this transcript: `settings.DATA_DIR / id`."""
        return Path(settings.DATA_DIR) / self.id

    @property
    def audio_wav_filename(self):
        """Path of the local WAV audio file."""
        return self.data_path / "audio.wav"

    @property
    def audio_mp3_filename(self):
        """Path of the local MP3 audio file."""
        return self.data_path / "audio.mp3"

    @property
    def audio_waveform_filename(self):
        """Path of the local waveform JSON file."""
        return self.data_path / "audio.json"

    @property
    def storage_audio_path(self):
        """Key of the MP3 audio inside the transcripts storage."""
        return f"{self.id}/audio.mp3"

    @property
    def audio_waveform(self):
        """Load the waveform from disk.

        Returns an AudioWaveform, or None when the file contains invalid JSON
        (the corrupted file is then removed). Only JSON decoding errors are
        handled here; a missing file makes `open` raise.
        """
        try:
            with open(self.audio_waveform_filename) as fd:
                data = json.load(fd)
        except json.JSONDecodeError:
            # unlink file if it's corrupted
            self.audio_waveform_filename.unlink(missing_ok=True)
            return None
        return AudioWaveform(data=data)

    async def get_audio_url(self) -> str:
        """Return a URL to the audio, depending on `audio_location`.

        Raises:
            Exception: if `audio_location` is neither "local" nor "storage".
        """
        if self.audio_location == "local":
            return self._generate_local_audio_link()
        elif self.audio_location == "storage":
            return await self._generate_storage_audio_link()
        raise Exception(f"Unknown audio location {self.audio_location}")

    async def _generate_storage_audio_link(self) -> str:
        """Ask the transcripts storage for the URL of the stored MP3."""
        return await get_transcripts_storage().get_file_url(self.storage_audio_path)

    def _generate_local_audio_link(self) -> str:
        """Build the application URL serving the local MP3.

        When the transcript has an owner, a 15-minute access token for that
        user is appended as the `token` query parameter.
        """
        # we need to create an url to be used for diarization
        # we can't use the audio_mp3_filename because it's not accessible
        # from the diarization processor
        # TODO don't import app in db
        from reflector.app import app  # noqa: PLC0415

        # TODO a util + don't import views in db
        from reflector.views.transcripts import create_access_token  # noqa: PLC0415

        path = app.url_path_for(
            "transcript_get_audio_mp3",
            transcript_id=self.id,
        )
        url = f"{settings.BASE_URL}{path}"
        if self.user_id:
            # we pass token only if the user_id is set
            # otherwise, the audio is public
            token = create_access_token(
                {"sub": self.user_id},
                expires_delta=timedelta(minutes=15),
            )
            url += f"?token={token}"
        return url

    def find_empty_speaker(self) -> int:
        """
        Find an empty speaker seat

        Returns the smallest non-negative integer that is not used as a
        speaker number by any word of any topic.
        """
        used_speakers = {
            word.speaker
            for topic in self.topics
            for word in topic.words
            if word.speaker is not None
        }
        candidate = 0
        # Always terminates: the set of used speakers is finite.
        while candidate in used_speakers:
            candidate += 1
        return candidate
class TranscriptController:
    """Database operations for transcripts, on top of an AsyncSession."""

    @staticmethod
    def _order_by_clause(order_by: str):
        """Turn an ordering spec into an ORDER BY expression on TranscriptModel.

        "-field" sorts descending; "field" and "+field" sort ascending.
        An unknown field name makes `getattr` raise AttributeError.
        """
        descending = order_by.startswith("-")
        # Only strip an explicit sign, never the first letter of the field.
        field_name = order_by[1:] if order_by[:1] in ("-", "+") else order_by
        column = getattr(TranscriptModel, field_name)
        return column.desc() if descending else column

    async def get_all(
        self,
        session: AsyncSession,
        user_id: str | None = None,
        order_by: str | None = None,
        filter_empty: bool | None = False,
        filter_recording: bool | None = False,
        source_kind: SourceKind | None = None,
        room_id: str | None = None,
        search_term: str | None = None,
        return_query: bool = False,
        exclude_columns: list[str] | None = None,
    ) -> list[Transcript]:
        """
        Get all transcripts

        If `user_id` is specified, return the transcripts that belong to the
        user plus those of shared rooms. Otherwise, only transcripts of shared
        rooms are returned.

        Parameters:
        - `order_by`: field to order by, e.g. "-created_at" (descending) or
          "created_at" (ascending)
        - `filter_empty`: filter out empty transcripts (status "idle")
        - `filter_recording`: filter out transcripts that are currently recording
        - `source_kind`: filter transcripts by source kind
        - `room_id`: filter transcripts by room ID
        - `search_term`: filter transcripts whose title contains the term
          (case-insensitive)
        - `return_query`: return the built query instead of executing it
        - `exclude_columns`: transcript columns left out of the result;
          defaults to the heavy JSON columns topics/events/participants

        NOTE: despite the return annotation, the executed form returns one
        plain dict per row (selected transcript columns plus `room_name`),
        and `return_query=True` returns the query object itself.
        """
        if exclude_columns is None:
            exclude_columns = ["topics", "events", "participants"]

        query = select(TranscriptModel).join(
            RoomModel, TranscriptModel.room_id == RoomModel.id, isouter=True
        )

        if user_id:
            query = query.where(
                or_(TranscriptModel.user_id == user_id, RoomModel.is_shared)
            )
        else:
            query = query.where(RoomModel.is_shared)

        if source_kind:
            query = query.where(TranscriptModel.source_kind == source_kind)
        if room_id:
            query = query.where(TranscriptModel.room_id == room_id)
        if search_term:
            query = query.where(TranscriptModel.title.ilike(f"%{search_term}%"))

        # Exclude heavy JSON columns from list queries
        # Get all ORM column attributes except excluded ones
        transcript_columns = [
            getattr(TranscriptModel, col.name)
            for col in TranscriptModel.__table__.c
            if col.name not in exclude_columns
        ]
        query = query.with_only_columns(
            *transcript_columns,
            RoomModel.name.label("room_name"),
        )

        if order_by is not None:
            query = query.order_by(self._order_by_clause(order_by))
        if filter_empty:
            query = query.where(TranscriptModel.status != "idle")
        if filter_recording:
            query = query.where(TranscriptModel.status != "recording")

        if return_query:
            return query

        result = await session.execute(query)
        return [dict(row) for row in result.mappings().all()]

    async def get_by_id(
        self, session: AsyncSession, transcript_id: str, **kwargs
    ) -> Transcript | None:
        """
        Get a transcript by id

        If a `user_id` keyword is given (even None), the transcript must also
        match that user_id. Returns None when nothing matches.
        """
        query = select(TranscriptModel).where(TranscriptModel.id == transcript_id)
        if "user_id" in kwargs:
            query = query.where(TranscriptModel.user_id == kwargs["user_id"])
        result = await session.execute(query)
        row = result.scalar_one_or_none()
        if not row:
            return None
        return Transcript.model_validate(row)

    async def get_by_recording_id(
        self, session: AsyncSession, recording_id: str, **kwargs
    ) -> Transcript | None:
        """
        Get a transcript by recording_id

        If a `user_id` keyword is given (even None), the transcript must also
        match that user_id. Returns None when nothing matches.
        """
        query = select(TranscriptModel).where(
            TranscriptModel.recording_id == recording_id
        )
        if "user_id" in kwargs:
            query = query.where(TranscriptModel.user_id == kwargs["user_id"])
        result = await session.execute(query)
        row = result.scalar_one_or_none()
        if not row:
            return None
        return Transcript.model_validate(row)

    async def get_by_room_id(
        self, session: AsyncSession, room_id: str, **kwargs
    ) -> list[Transcript]:
        """
        Get transcripts by room_id (direct access without joins)

        Optional keywords:
        - `user_id`: also require this user_id (applied even when None)
        - `order_by`: ordering spec, e.g. "-created_at"
        """
        query = select(TranscriptModel).where(TranscriptModel.room_id == room_id)
        if "user_id" in kwargs:
            query = query.where(TranscriptModel.user_id == kwargs["user_id"])
        order_by = kwargs.get("order_by")
        if order_by:
            query = query.order_by(self._order_by_clause(order_by))
        results = await session.execute(query)
        # The query selects the ORM entity, so read the entities themselves
        # (scalars) and validate from attributes, as get_by_id does.
        return [Transcript.model_validate(row) for row in results.scalars().all()]

    async def get_by_id_for_http(
        self,
        session: AsyncSession,
        transcript_id: str,
        user_id: str | None,
    ) -> Transcript:
        """
        Get a transcript by ID for HTTP request.

        If not found, it will raise a 404 error.
        If the user is not allowed to access the transcript, it will raise a 403 error.

        This method checks the share mode of the transcript and the user_id
        to determine if the user can access the transcript.
        """
        query = select(TranscriptModel).where(TranscriptModel.id == transcript_id)
        result = await session.execute(query)
        row = result.scalar_one_or_none()
        if not row:
            raise HTTPException(status_code=404, detail="Transcript not found")

        # if the transcript is anonymous, share mode is not checked
        transcript = Transcript.model_validate(row)
        if transcript.user_id is None:
            return transcript

        if transcript.share_mode == "private":
            # in private mode, only the owner can access the transcript
            if transcript.user_id == user_id:
                return transcript
        elif transcript.share_mode == "semi-private":
            # in semi-private mode, only the owner and the users with the link
            # can access the transcript
            if user_id is not None:
                return transcript
        elif transcript.share_mode == "public":
            # in public mode, everyone can access the transcript
            return transcript

        raise HTTPException(status_code=403, detail="Transcript access denied")

    async def add(
        self,
        session: AsyncSession,
        name: str,
        source_kind: SourceKind,
        source_language: str = "en",
        target_language: str = "en",
        user_id: str | None = None,
        recording_id: str | None = None,
        share_mode: str = "private",
        meeting_id: str | None = None,
        room_id: str | None = None,
    ):
        """
        Add a new transcript

        Builds a Transcript (which validates the values and fills defaults),
        inserts it, commits, and returns it.
        """
        transcript = Transcript(
            name=name,
            source_kind=source_kind,
            source_language=source_language,
            target_language=target_language,
            user_id=user_id,
            recording_id=recording_id,
            share_mode=share_mode,
            meeting_id=meeting_id,
            room_id=room_id,
        )
        query = insert(TranscriptModel).values(**transcript.model_dump())
        await session.execute(query)
        await session.commit()
        return transcript

    # TODO investigate why mutate= is used. it's used in one place currently, maybe because of ORM field updates.
    # using mutate=True is discouraged
    async def update(
        self, session: AsyncSession, transcript: Transcript, values: dict, mutate=False
    ) -> Transcript:
        """
        Update a transcript fields with key/values in values.
        Returns a copy of the transcript with updated values.

        When `values` contains topics, the WebVTT text is regenerated and
        written along with them. With `mutate=True` the passed `transcript`
        is also updated in place. The session is committed.
        """
        values = TranscriptController._handle_topics_update(values)
        query = (
            update(TranscriptModel)
            .where(TranscriptModel.id == transcript.id)
            .values(**values)
        )
        await session.execute(query)
        await session.commit()
        if mutate:
            for key, value in values.items():
                setattr(transcript, key, value)
        updated_transcript = transcript.model_copy(update=values)
        return updated_transcript

    @staticmethod
    def _handle_topics_update(values: dict) -> dict:
        """Auto-update WebVTT when topics are updated.

        Returns `values` unchanged when it carries no topics; otherwise a new
        dict with `webvtt` regenerated from the topics. An explicitly passed
        `webvtt` value only triggers a warning; it is not removed here.
        """
        if values.get("webvtt") is not None:
            logger.warning("trying to update read-only webvtt column")
        topics_data = values.get("topics")
        if topics_data is None:
            return values
        return {
            **values,
            "webvtt": topics_to_webvtt(
                [TranscriptTopic(**topic_dict) for topic_dict in topics_data]
            ),
        }

    async def remove_by_id(
        self,
        session: AsyncSession,
        transcript_id: str,
        user_id: str | None = None,
    ) -> None:
        """
        Remove a transcript by id

        Does nothing when the transcript does not exist, or when `user_id` is
        given and does not own it. Stored audio and the linked recording are
        deleted on a best-effort basis: failures are logged and do not stop
        the removal of the transcript row.
        """
        transcript = await self.get_by_id(session, transcript_id)
        if not transcript:
            return
        if user_id is not None and transcript.user_id != user_id:
            return

        if transcript.audio_location == "storage" and not transcript.audio_deleted:
            try:
                await get_transcripts_storage().delete_file(
                    transcript.storage_audio_path
                )
            except Exception as e:
                logger.warning(
                    "Failed to delete transcript audio from storage",
                    exc_info=e,
                    transcript_id=transcript.id,
                )

        # remove the local data directory
        transcript.unlink()

        if transcript.recording_id:
            try:
                recording = await recordings_controller.get_by_id(
                    session, transcript.recording_id
                )
                if recording:
                    try:
                        await get_recordings_storage().delete_file(recording.object_key)
                    except Exception as e:
                        logger.warning(
                            "Failed to delete recording object from S3",
                            exc_info=e,
                            recording_id=transcript.recording_id,
                        )
                    await recordings_controller.remove_by_id(
                        session, transcript.recording_id
                    )
            except Exception as e:
                logger.warning(
                    "Failed to delete recording row",
                    exc_info=e,
                    recording_id=transcript.recording_id,
                )

        query = delete(TranscriptModel).where(TranscriptModel.id == transcript_id)
        await session.execute(query)
        await session.commit()

    async def remove_by_recording_id(self, session: AsyncSession, recording_id: str):
        """
        Remove a transcript by recording_id

        Only the database row(s) are deleted; the session is committed.
        """
        query = delete(TranscriptModel).where(
            TranscriptModel.recording_id == recording_id
        )
        await session.execute(query)
        await session.commit()

    @asynccontextmanager
    async def transaction(self, session: AsyncSession):
        """
        A context manager for database transaction
        """
        async with session.begin():
            yield

    async def append_event(
        self,
        session: AsyncSession,
        transcript: Transcript,
        event: str,
        data: Any,
    ) -> TranscriptEvent:
        """
        Append an event to a transcript

        `data` must provide `model_dump()` (Transcript.add_event calls it).
        The full events list is then written back to the database.
        """
        resp = transcript.add_event(event=event, data=data)
        await self.update(session, transcript, {"events": transcript.events_dump()})
        return resp

    async def upsert_topic(
        self,
        session: AsyncSession,
        transcript: Transcript,
        topic: TranscriptTopic,
    ) -> None:
        """
        Upsert topics to a transcript

        The full topics list is written back to the database (which also
        regenerates the WebVTT text). Returns nothing.
        """
        transcript.upsert_topic(topic)
        await self.update(session, transcript, {"topics": transcript.topics_dump()})

    async def move_mp3_to_storage(self, session: AsyncSession, transcript: Transcript):
        """
        Move mp3 file to storage

        Raises:
            FileNotFoundError: if the audio is marked as deleted, or if the
                local mp3 file is missing while `audio_location` is "local".
        """
        if transcript.audio_deleted:
            raise FileNotFoundError(
                f"Invalid state of transcript {transcript.id}: audio_deleted mark is set true"
            )

        if transcript.audio_location == "local":
            # store the audio on external storage if it's not already there
            if not transcript.audio_mp3_filename.exists():
                raise FileNotFoundError(
                    f"Audio file not found: {transcript.audio_mp3_filename}"
                )
            await get_transcripts_storage().put_file(
                transcript.storage_audio_path,
                transcript.audio_mp3_filename.read_bytes(),
            )
            # indicate on the transcript that the audio is now on storage
            # mutates transcript argument
            await self.update(
                session, transcript, {"audio_location": "storage"}, mutate=True
            )

        # unlink the local file
        transcript.audio_mp3_filename.unlink(missing_ok=True)

    async def download_mp3_from_storage(
        self, session: AsyncSession, transcript: Transcript
    ):
        """
        Download audio from storage

        Writes the stored mp3 to the transcript's local mp3 path. `session`
        is not used.
        """
        transcript.audio_mp3_filename.write_bytes(
            await get_transcripts_storage().get_file(
                transcript.storage_audio_path,
            )
        )

    async def upsert_participant(
        self,
        session: AsyncSession,
        transcript: Transcript,
        participant: TranscriptParticipant,
    ) -> TranscriptParticipant:
        """
        Add/update a participant to a transcript
        """
        result = transcript.upsert_participant(participant)
        await self.update(
            session, transcript, {"participants": transcript.participants_dump()}
        )
        return result

    async def delete_participant(
        self,
        session: AsyncSession,
        transcript: Transcript,
        participant_id: str,
    ):
        """
        Delete a participant from a transcript
        """
        transcript.delete_participant(participant_id)
        await self.update(
            session, transcript, {"participants": transcript.participants_dump()}
        )

    async def set_status(
        self, session: AsyncSession, transcript_id: str, status: TranscriptStatus
    ) -> TranscriptEvent | None:
        """
        Update the status of a transcript

        Will add an event STATUS + update the status field of transcript.
        Returns the new event, or None when the status is already `status`.

        Raises:
            Exception: if the transcript does not exist.
        """
        # NOTE(review): append_event() and update() each commit the session
        # while this block is still inside session.begin(); confirm that this
        # nesting behaves as intended with the configured session.
        async with self.transaction(session):
            transcript = await self.get_by_id(session, transcript_id)
            if not transcript:
                raise Exception(f"Transcript {transcript_id} not found")
            if transcript.status == status:
                return None
            resp = await self.append_event(
                session,
                transcript=transcript,
                event="STATUS",
                data=StrValue(value=status),
            )
            await self.update(session, transcript, {"status": status})
            return resp
# Shared module-level controller instance. TranscriptController defines no
# instance attributes; every method receives its AsyncSession as an argument.
transcripts_controller = TranscriptController()