import enum
import json
import os
import shutil
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Literal

import sqlalchemy
from fastapi import HTTPException
from pydantic import BaseModel, ConfigDict, Field, field_serializer
from sqlalchemy import Enum
from sqlalchemy.sql import false, or_

from reflector.db import database, metadata
from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings
from reflector.storage import get_transcripts_storage
from reflector.utils import generate_uuid4


class SourceKind(enum.StrEnum):
    """Origin of a transcript; values are the lowercase member names."""

    ROOM = enum.auto()
    LIVE = enum.auto()
    FILE = enum.auto()


transcripts = sqlalchemy.Table(
    "transcript",
    metadata,
    sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
    sqlalchemy.Column("name", sqlalchemy.String),
    sqlalchemy.Column("status", sqlalchemy.String),
    sqlalchemy.Column("locked", sqlalchemy.Boolean),
    sqlalchemy.Column("duration", sqlalchemy.Float),
    sqlalchemy.Column("created_at", sqlalchemy.DateTime),
    sqlalchemy.Column("title", sqlalchemy.String),
    sqlalchemy.Column("short_summary", sqlalchemy.String),
    sqlalchemy.Column("long_summary", sqlalchemy.String),
    sqlalchemy.Column("topics", sqlalchemy.JSON),
    sqlalchemy.Column("events", sqlalchemy.JSON),
    sqlalchemy.Column("participants", sqlalchemy.JSON),
    sqlalchemy.Column("source_language", sqlalchemy.String),
    sqlalchemy.Column("target_language", sqlalchemy.String),
    sqlalchemy.Column(
        "reviewed", sqlalchemy.Boolean, nullable=False, server_default=false()
    ),
    sqlalchemy.Column(
        "audio_location",
        sqlalchemy.String,
        nullable=False,
        server_default="local",
    ),
    # with user attached, optional
    sqlalchemy.Column("user_id", sqlalchemy.String),
    sqlalchemy.Column(
        "share_mode",
        sqlalchemy.String,
        nullable=False,
        server_default="private",
    ),
    sqlalchemy.Column(
        "meeting_id",
        sqlalchemy.String,
    ),
    sqlalchemy.Column("recording_id", sqlalchemy.String),
    sqlalchemy.Column("zulip_message_id", sqlalchemy.Integer),
    sqlalchemy.Column(
        "source_kind",
        Enum(SourceKind, values_callable=lambda obj: [e.value for e in obj]),
        nullable=False,
    ),
    # Indicative field: whether the associated audio is deleted.
    # The primary "audio deleted" signal is the presence of the audio itself /
    # consents not given. The same field could have lived on recording/meeting,
    # and duplicating it there if needed is probably fine.
    sqlalchemy.Column("audio_deleted", sqlalchemy.Boolean),
    sqlalchemy.Column("room_id", sqlalchemy.String),
    sqlalchemy.Index("idx_transcript_recording_id", "recording_id"),
    sqlalchemy.Index("idx_transcript_user_id", "user_id"),
    sqlalchemy.Index("idx_transcript_created_at", "created_at"),
    sqlalchemy.Index("idx_transcript_user_id_recording_id", "user_id", "recording_id"),
    sqlalchemy.Index("idx_transcript_room_id", "room_id"),
)

# Heavy JSON columns left out of list queries unless the caller says otherwise.
_DEFAULT_EXCLUDED_COLUMNS = ("topics", "events", "participants")


def _order_by_clause(order_by: str):
    """
    Translate an ordering spec into a SQLAlchemy ORDER BY element.

    A leading "-" selects descending order and a leading "+" (or no prefix)
    ascending order, e.g. "-created_at" or "created_at".

    Raises AttributeError if the name is not an attribute of `transcripts.c`.
    """
    descending = order_by.startswith("-")
    # Strip only an explicit direction prefix; a bare column name is kept
    # whole (blindly dropping the first character broke ascending ordering).
    column_name = order_by[1:] if order_by[:1] in ("-", "+") else order_by
    column = getattr(transcripts.c, column_name)
    return column.desc() if descending else column


def generate_transcript_name() -> str:
    """Return a default transcript name stamped with the current UTC time."""
    now = datetime.now(timezone.utc)
    return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"


class AudioWaveform(BaseModel):
    """Waveform payload as loaded from the transcript's `audio.json` file."""

    data: list[float]


class TranscriptText(BaseModel):
    """A piece of transcribed text with its optional translation."""

    text: str
    translation: str | None


class TranscriptSegmentTopic(BaseModel):
    """A speaker-attributed text segment with its timestamp."""

    speaker: int
    text: str
    timestamp: float


class TranscriptTopic(BaseModel):
    """A topic of a transcript; `id` defaults to a freshly generated UUID."""

    id: str = Field(default_factory=generate_uuid4)
    title: str
    summary: str
    timestamp: float
    duration: float | None = 0
    transcript: str | None = None
    words: list[ProcessorWord] = []


class TranscriptFinalShortSummary(BaseModel):
    """Payload carrying a short summary."""

    short_summary: str


class TranscriptFinalLongSummary(BaseModel):
    """Payload carrying a long summary."""

    long_summary: str


class TranscriptFinalTitle(BaseModel):
    """Payload carrying a title."""

    title: str


class TranscriptDuration(BaseModel):
    """Payload carrying a duration."""

    duration: float


class TranscriptWaveform(BaseModel):
    """Payload carrying waveform samples."""

    waveform: list[float]


class TranscriptEvent(BaseModel):
    """A named event and its data, as stored in `Transcript.events`."""

    event: str
    data: dict


class TranscriptParticipant(BaseModel):
    """A participant of a transcript, optionally bound to a speaker index."""

    model_config = ConfigDict(from_attributes=True)
    id: str = Field(default_factory=generate_uuid4)
    speaker: int | None
    name: str


class Transcript(BaseModel):
    """
    A transcript row of the `transcript` table, plus helpers to mutate its
    topics/events/participants in memory and to locate its audio files.

    The mutating helpers only change this object; persisting the change is
    done by `TranscriptController`.
    """

    id: str = Field(default_factory=generate_uuid4)
    user_id: str | None = None
    name: str = Field(default_factory=generate_transcript_name)
    status: str = "idle"
    locked: bool = False
    duration: float = 0
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    title: str | None = None
    short_summary: str | None = None
    long_summary: str | None = None
    topics: list[TranscriptTopic] = []
    events: list[TranscriptEvent] = []
    participants: list[TranscriptParticipant] | None = []
    source_language: str = "en"
    target_language: str = "en"
    share_mode: Literal["private", "semi-private", "public"] = "private"
    audio_location: str = "local"
    reviewed: bool = False
    meeting_id: str | None = None
    recording_id: str | None = None
    zulip_message_id: int | None = None
    source_kind: SourceKind
    audio_deleted: bool | None = None
    room_id: str | None = None

    @field_serializer("created_at", when_used="json")
    def serialize_datetime(self, dt: datetime) -> str:
        """Serialize `created_at` as ISO 8601, treating naive values as UTC."""
        if dt.tzinfo is None:
            dt = dt.replace(tzinfo=timezone.utc)
        return dt.isoformat()

    def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
        """Append an event built from `data.model_dump()` and return it."""
        ev = TranscriptEvent(event=event, data=data.model_dump())
        self.events.append(ev)
        return ev

    def upsert_topic(self, topic: TranscriptTopic):
        """Replace the topic with the same id, or append it if not present."""
        index = next((i for i, t in enumerate(self.topics) if t.id == topic.id), None)
        if index is not None:
            self.topics[index] = topic
        else:
            self.topics.append(topic)

    def upsert_participant(self, participant: TranscriptParticipant):
        """
        Replace the participant with the same id, or append it if not present.

        When `participants` is None or empty, a new one-element list is
        assigned. Returns the participant that was passed in.
        """
        if self.participants:
            index = next(
                (i for i, p in enumerate(self.participants) if p.id == participant.id),
                None,
            )
            if index is not None:
                self.participants[index] = participant
            else:
                self.participants.append(participant)
        else:
            self.participants = [participant]
        return participant

    def delete_participant(self, participant_id: str):
        """
        Remove the participant with the given id, if any.

        Unknown ids are ignored, and so is a None/empty participant list
        (`participants` is declared nullable).
        """
        if not self.participants:
            return
        index = next(
            (i for i, p in enumerate(self.participants) if p.id == participant_id),
            None,
        )
        if index is not None:
            del self.participants[index]

    def events_dump(self, mode="json"):
        """Return the events as a list of dumped dicts."""
        return [event.model_dump(mode=mode) for event in self.events]

    def topics_dump(self, mode="json"):
        """Return the topics as a list of dumped dicts."""
        return [topic.model_dump(mode=mode) for topic in self.topics]

    def participants_dump(self, mode="json"):
        """
        Return the participants as a list of dumped dicts.

        A None participant list is dumped as an empty list.
        """
        return [
            participant.model_dump(mode=mode)
            for participant in self.participants or []
        ]

    def unlink(self):
        """Recursively delete the transcript's data directory if it exists."""
        # isdir() is already False for a missing path, so no exists() check.
        if os.path.isdir(self.data_path):
            shutil.rmtree(self.data_path)

    @property
    def data_path(self):
        """Directory holding this transcript's files: `DATA_DIR/<id>`."""
        return Path(settings.DATA_DIR) / self.id

    @property
    def audio_wav_filename(self):
        """Path of the WAV audio file inside the data directory."""
        return self.data_path / "audio.wav"

    @property
    def audio_mp3_filename(self):
        """Path of the MP3 audio file inside the data directory."""
        return self.data_path / "audio.mp3"

    @property
    def audio_waveform_filename(self):
        """Path of the waveform JSON file inside the data directory."""
        return self.data_path / "audio.json"

    @property
    def storage_audio_path(self):
        """Key of the MP3 file on the transcripts storage: `<id>/audio.mp3`."""
        return f"{self.id}/audio.mp3"

    @property
    def audio_waveform(self):
        """
        Load the waveform from `audio_waveform_filename`.

        Returns None (and deletes the file) when the file is not valid JSON.
        A missing file is not handled here: `open()` raises
        FileNotFoundError.
        """
        try:
            with open(self.audio_waveform_filename) as fd:
                data = json.load(fd)
        except json.JSONDecodeError:
            # unlink file if it's corrupted
            self.audio_waveform_filename.unlink(missing_ok=True)
            return None

        return AudioWaveform(data=data)

    async def get_audio_url(self) -> str:
        """
        Return a URL for the audio, depending on `audio_location`.

        Raises Exception when `audio_location` is neither "local" nor
        "storage".
        """
        if self.audio_location == "local":
            return self._generate_local_audio_link()
        elif self.audio_location == "storage":
            return await self._generate_storage_audio_link()
        raise Exception(f"Unknown audio location {self.audio_location}")

    async def _generate_storage_audio_link(self) -> str:
        """Ask the transcripts storage for the URL of the stored MP3."""
        return await get_transcripts_storage().get_file_url(self.storage_audio_path)

    def _generate_local_audio_link(self) -> str:
        """Build a URL to the app's own MP3 route for this transcript."""
        # we need to create an url to be used for diarization
        # we can't use the audio_mp3_filename because it's not accessible
        # from the diarization processor

        # NOTE(review): imported lazily, presumably to avoid a circular
        # import with the app/views modules - confirm before hoisting.
        from datetime import timedelta

        from reflector.app import app
        from reflector.views.transcripts import create_access_token

        path = app.url_path_for(
            "transcript_get_audio_mp3",
            transcript_id=self.id,
        )
        url = f"{settings.BASE_URL}{path}"
        if self.user_id:
            # we pass token only if the user_id is set
            # otherwise, the audio is public
            token = create_access_token(
                {"sub": self.user_id},
                expires_delta=timedelta(minutes=15),
            )
            url += f"?token={token}"
        return url

    def find_empty_speaker(self) -> int:
        """
        Find an empty speaker seat

        Returns the smallest non-negative integer that is not used as a
        speaker by any word of any topic.
        """
        speakers = {
            word.speaker
            for topic in self.topics
            for word in topic.words
            if word.speaker is not None
        }
        i = 0
        while i in speakers:
            i += 1
        return i


class TranscriptController:
    """Database access layer for the `transcript` table."""

    async def get_all(
        self,
        user_id: str | None = None,
        order_by: str | None = None,
        filter_empty: bool | None = False,
        filter_recording: bool | None = False,
        source_kind: SourceKind | None = None,
        room_id: str | None = None,
        search_term: str | None = None,
        return_query: bool = False,
        exclude_columns: list[str] | None = None,
    ) -> list[Transcript]:
        """
        Get all transcripts

        If `user_id` is specified, return transcripts that belong to the user
        or to a shared room. Otherwise, return only transcripts that belong
        to a shared room.

        Parameters:
        - `order_by`: field to order by, e.g. "-created_at" (descending) or
          "created_at" (ascending)
        - `filter_empty`: filter out transcripts whose status is "idle"
        - `filter_recording`: filter out transcripts whose status is
          "recording"
        - `source_kind`: filter transcripts by source kind
        - `room_id`: filter transcripts by room ID
        - `search_term`: case-insensitive substring match on the title
        - `return_query`: return the query itself instead of executing it
        - `exclude_columns`: transcript columns to leave out of the result;
          defaults to the heavy JSON columns topics/events/participants

        Returns the rows as fetched from the database (each with an extra
        `room_name` column), not validated `Transcript` models, or the query
        when `return_query` is set.
        """
        from reflector.db.rooms import rooms

        if exclude_columns is None:
            exclude_columns = _DEFAULT_EXCLUDED_COLUMNS

        query = transcripts.select().join(
            rooms, transcripts.c.room_id == rooms.c.id, isouter=True
        )

        if user_id:
            query = query.where(
                or_(transcripts.c.user_id == user_id, rooms.c.is_shared)
            )
        else:
            query = query.where(rooms.c.is_shared)

        if source_kind:
            query = query.where(transcripts.c.source_kind == source_kind)

        if room_id:
            query = query.where(transcripts.c.room_id == room_id)

        if search_term:
            query = query.where(transcripts.c.title.ilike(f"%{search_term}%"))

        # Exclude heavy JSON columns from list queries
        transcript_columns = [
            col for col in transcripts.c if col.name not in exclude_columns
        ]

        query = query.with_only_columns(
            transcript_columns
            + [
                rooms.c.name.label("room_name"),
            ]
        )

        if order_by is not None:
            query = query.order_by(_order_by_clause(order_by))

        if filter_empty:
            query = query.filter(transcripts.c.status != "idle")

        if filter_recording:
            query = query.filter(transcripts.c.status != "recording")

        # Debugging aid:
        # print(query.compile(compile_kwargs={"literal_binds": True}))

        if return_query:
            return query

        results = await database.fetch_all(query)
        return results

    async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
        """
        Get a transcript by id

        If a `user_id` keyword is given (even None), the transcript must also
        match that user_id. Returns None when nothing matches.
        """
        query = transcripts.select().where(transcripts.c.id == transcript_id)
        if "user_id" in kwargs:
            query = query.where(transcripts.c.user_id == kwargs["user_id"])
        result = await database.fetch_one(query)
        if not result:
            return None
        return Transcript(**result)

    async def get_by_recording_id(
        self, recording_id: str, **kwargs
    ) -> Transcript | None:
        """
        Get a transcript by recording_id

        If a `user_id` keyword is given (even None), the transcript must also
        match that user_id. Returns None when nothing matches.
        """
        query = transcripts.select().where(transcripts.c.recording_id == recording_id)
        if "user_id" in kwargs:
            query = query.where(transcripts.c.user_id == kwargs["user_id"])
        result = await database.fetch_one(query)
        if not result:
            return None
        return Transcript(**result)

    async def get_by_room_id(self, room_id: str, **kwargs) -> list[Transcript]:
        """
        Get transcripts by room_id (direct access without joins)

        Optional keywords:
        - `user_id`: only transcripts matching that user_id
        - `order_by`: field to order by, e.g. "-created_at" (descending) or
          "created_at" (ascending); None is ignored
        """
        query = transcripts.select().where(transcripts.c.room_id == room_id)
        if "user_id" in kwargs:
            query = query.where(transcripts.c.user_id == kwargs["user_id"])
        order_by = kwargs.get("order_by")
        if order_by is not None:
            query = query.order_by(_order_by_clause(order_by))
        results = await database.fetch_all(query)
        return [Transcript(**result) for result in results]

    async def get_by_id_for_http(
        self,
        transcript_id: str,
        user_id: str | None,
    ) -> Transcript:
        """
        Get a transcript by ID for HTTP request.

        If not found, it will raise a 404 error.
        If the user is not allowed to access the transcript, it will raise a
        403 error.

        This method checks the share mode of the transcript and the user_id
        to determine if the user can access the transcript:
        - anonymous transcript (no owner): always accessible
        - "private": owner only
        - "semi-private": any request with a non-None `user_id`
        - "public": everyone
        """
        query = transcripts.select().where(transcripts.c.id == transcript_id)
        result = await database.fetch_one(query)
        if not result:
            raise HTTPException(status_code=404, detail="Transcript not found")

        # if the transcript is anonymous, share mode is not checked
        transcript = Transcript(**result)
        if transcript.user_id is None:
            return transcript

        if transcript.share_mode == "private":
            # in private mode, only the owner can access the transcript
            if transcript.user_id == user_id:
                return transcript

        elif transcript.share_mode == "semi-private":
            # in semi-private mode, only the owner and the users with the link
            # can access the transcript
            if user_id is not None:
                return transcript

        elif transcript.share_mode == "public":
            # in public mode, everyone can access the transcript
            return transcript

        raise HTTPException(status_code=403, detail="Transcript access denied")

    async def add(
        self,
        name: str,
        source_kind: SourceKind,
        source_language: str = "en",
        target_language: str = "en",
        user_id: str | None = None,
        recording_id: str | None = None,
        share_mode: str = "private",
        meeting_id: str | None = None,
        room_id: str | None = None,
    ) -> Transcript:
        """
        Add a new transcript

        Builds a `Transcript` from the arguments (remaining fields take their
        model defaults), inserts it and returns it.
        """
        transcript = Transcript(
            name=name,
            source_kind=source_kind,
            source_language=source_language,
            target_language=target_language,
            user_id=user_id,
            recording_id=recording_id,
            share_mode=share_mode,
            meeting_id=meeting_id,
            room_id=room_id,
        )
        query = transcripts.insert().values(**transcript.model_dump())
        await database.execute(query)
        return transcript

    async def update(self, transcript: Transcript, values: dict, mutate=True) -> None:
        """
        Update a transcript fields with key/values in values

        When `mutate` is true, the same key/values are also set on the
        in-memory `transcript` after the database update.
        """
        query = (
            transcripts.update()
            .where(transcripts.c.id == transcript.id)
            .values(**values)
        )
        await database.execute(query)
        if mutate:
            for key, value in values.items():
                setattr(transcript, key, value)

    async def remove_by_id(
        self,
        transcript_id: str,
        user_id: str | None = None,
    ) -> None:
        """
        Remove a transcript by id

        Does nothing when the transcript does not exist, or when `user_id` is
        given and is not the transcript's user_id. Otherwise the local data
        directory is deleted before the row.
        """
        transcript = await self.get_by_id(transcript_id)
        if not transcript:
            return
        if user_id is not None and transcript.user_id != user_id:
            return
        transcript.unlink()
        query = transcripts.delete().where(transcripts.c.id == transcript_id)
        await database.execute(query)

    async def remove_by_recording_id(self, recording_id: str) -> None:
        """
        Remove a transcript by recording_id

        Only the database row is deleted; local files are not touched.
        """
        query = transcripts.delete().where(transcripts.c.recording_id == recording_id)
        await database.execute(query)

    @asynccontextmanager
    async def transaction(self):
        """
        A context manager for database transaction

        The transaction is opened with isolation="serializable".
        """
        async with database.transaction(isolation="serializable"):
            yield

    async def append_event(
        self,
        transcript: Transcript,
        event: str,
        data: Any,
    ) -> TranscriptEvent:
        """
        Append an event to a transcript

        The event is added to the in-memory transcript, then the whole events
        list is written to the database. Returns the new event.
        """
        resp = transcript.add_event(event=event, data=data)
        await self.update(
            transcript,
            {"events": transcript.events_dump()},
            # the in-memory transcript was already updated by add_event()
            mutate=False,
        )
        return resp

    async def upsert_topic(
        self,
        transcript: Transcript,
        topic: TranscriptTopic,
    ) -> None:
        """
        Upsert topics to a transcript

        The topic is upserted on the in-memory transcript, then the whole
        topics list is written to the database.
        """
        transcript.upsert_topic(topic)
        await self.update(
            transcript,
            {"topics": transcript.topics_dump()},
            # the in-memory transcript was already updated by upsert_topic()
            mutate=False,
        )

    async def move_mp3_to_storage(self, transcript: Transcript):
        """
        Move mp3 file to storage

        Raises FileNotFoundError when the transcript is marked
        `audio_deleted`, or when the audio is "local" but the MP3 file is
        missing. The local MP3 is removed afterwards in all other cases.
        """
        if transcript.audio_deleted:
            raise FileNotFoundError(
                f"Invalid state of transcript {transcript.id}: audio_deleted mark is set true"
            )

        if transcript.audio_location == "local":
            # store the audio on external storage if it's not already there
            if not transcript.audio_mp3_filename.exists():
                raise FileNotFoundError(
                    f"Audio file not found: {transcript.audio_mp3_filename}"
                )

            await get_transcripts_storage().put_file(
                transcript.storage_audio_path,
                transcript.audio_mp3_filename.read_bytes(),
            )

            # indicate on the transcript that the audio is now on storage
            await self.update(transcript, {"audio_location": "storage"})

        # unlink the local file
        transcript.audio_mp3_filename.unlink(missing_ok=True)

    async def download_mp3_from_storage(self, transcript: Transcript):
        """
        Download audio from storage

        The file is written to the transcript's local MP3 path.
        """
        content = await get_transcripts_storage().get_file(
            transcript.storage_audio_path,
        )
        # write_bytes() fails if the transcript directory is missing
        transcript.data_path.mkdir(parents=True, exist_ok=True)
        transcript.audio_mp3_filename.write_bytes(content)

    async def upsert_participant(
        self,
        transcript: Transcript,
        participant: TranscriptParticipant,
    ) -> TranscriptParticipant:
        """
        Add/update a participant to a transcript

        The whole participants list is written to the database. Returns the
        participant.
        """
        result = transcript.upsert_participant(participant)
        await self.update(
            transcript,
            {"participants": transcript.participants_dump()},
            # the in-memory transcript was already updated above
            mutate=False,
        )
        return result

    async def delete_participant(
        self,
        transcript: Transcript,
        participant_id: str,
    ):
        """
        Delete a participant from a transcript

        The whole participants list is written to the database afterwards.
        """
        transcript.delete_participant(participant_id)
        await self.update(
            transcript,
            {"participants": transcript.participants_dump()},
            # the in-memory transcript was already updated above
            mutate=False,
        )


transcripts_controller = TranscriptController()