import enum
import json
import os
import shutil
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, Sequence

if TYPE_CHECKING:
    from reflector.ws_events import TranscriptEventName

import sqlalchemy
from fastapi import HTTPException
from pydantic import BaseModel, ConfigDict, Field, field_serializer
from sqlalchemy import Enum
from sqlalchemy.dialects.postgresql import TSVECTOR
from sqlalchemy.sql import false, or_

from reflector.db import get_database, metadata
from reflector.db.recordings import recordings_controller
from reflector.db.rooms import rooms
from reflector.db.utils import is_postgresql
from reflector.logger import logger
from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings
from reflector.storage import get_transcripts_storage
from reflector.utils import generate_uuid4
from reflector.utils.webvtt import topics_to_webvtt


class SourceKind(enum.StrEnum):
    """Origin of a transcript; stored in the ``source_kind`` column."""

    ROOM = enum.auto()
    LIVE = enum.auto()
    FILE = enum.auto()


transcript_change_seq = sqlalchemy.Sequence("transcript_change_seq", metadata=metadata)
transcripts = sqlalchemy.Table(
    "transcript",
    metadata,
    sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
    sqlalchemy.Column("name", sqlalchemy.String),
    sqlalchemy.Column("status", sqlalchemy.String),
    sqlalchemy.Column("locked", sqlalchemy.Boolean),
    sqlalchemy.Column("duration", sqlalchemy.Float),
    sqlalchemy.Column("created_at", sqlalchemy.DateTime(timezone=True)),
    sqlalchemy.Column("title", sqlalchemy.String),
    sqlalchemy.Column("short_summary", sqlalchemy.String),
    sqlalchemy.Column("long_summary", sqlalchemy.String),
    sqlalchemy.Column("action_items", sqlalchemy.JSON),
    sqlalchemy.Column("topics", sqlalchemy.JSON),
    sqlalchemy.Column("events", sqlalchemy.JSON),
    sqlalchemy.Column("participants", sqlalchemy.JSON),
    sqlalchemy.Column("source_language", sqlalchemy.String),
    sqlalchemy.Column("target_language", sqlalchemy.String),
    sqlalchemy.Column(
        "reviewed", sqlalchemy.Boolean, nullable=False, server_default=false()
    ),
    sqlalchemy.Column(
        "audio_location",
        sqlalchemy.String,
        nullable=False,
        server_default="local",
    ),
    # with user attached, optional
    sqlalchemy.Column("user_id", sqlalchemy.String),
    sqlalchemy.Column(
        "share_mode",
        sqlalchemy.String,
        nullable=False,
        server_default="private",
    ),
    sqlalchemy.Column(
        "meeting_id",
        sqlalchemy.String,
    ),
    sqlalchemy.Column("recording_id", sqlalchemy.String),
    sqlalchemy.Column("zulip_message_id", sqlalchemy.Integer),
    sqlalchemy.Column(
        "source_kind",
        Enum(SourceKind, values_callable=lambda obj: [e.value for e in obj]),
        nullable=False,
    ),
    # indicative field: whether associated audio is deleted
    # the main "audio deleted" is the presence of the audio itself / consents not-given
    # same field could've been in recording/meeting, and it's maybe even ok to dupe it at need
    sqlalchemy.Column("audio_deleted", sqlalchemy.Boolean),
    sqlalchemy.Column("room_id", sqlalchemy.String),
    sqlalchemy.Column("webvtt", sqlalchemy.Text),
    # Hatchet workflow run ID for resumption of failed workflows
    sqlalchemy.Column("workflow_run_id", sqlalchemy.String),
    sqlalchemy.Column("deleted_at", sqlalchemy.DateTime(timezone=True), nullable=True),
    sqlalchemy.Column(
        "change_seq",
        sqlalchemy.BigInteger,
        transcript_change_seq,
        server_default=transcript_change_seq.next_value(),
    ),
    sqlalchemy.Index("idx_transcript_recording_id", "recording_id"),
    sqlalchemy.Index("idx_transcript_user_id", "user_id"),
    sqlalchemy.Index("idx_transcript_created_at", "created_at"),
    sqlalchemy.Index("idx_transcript_user_id_recording_id", "user_id", "recording_id"),
    sqlalchemy.Index("idx_transcript_room_id", "room_id"),
    sqlalchemy.Index("idx_transcript_source_kind", "source_kind"),
    sqlalchemy.Index("idx_transcript_room_id_created_at", "room_id", "created_at"),
)

# Add PostgreSQL-specific full-text search column
# This matches the migration in migrations/versions/116b2f287eab_add_full_text_search.py
if is_postgresql():
    transcripts.append_column(
        sqlalchemy.Column(
            "search_vector_en",
            TSVECTOR,
            sqlalchemy.Computed(
                "setweight(to_tsvector('english', coalesce(title, '')), 'A') || "
                "setweight(to_tsvector('english', coalesce(long_summary, '')), 'B') || "
                "setweight(to_tsvector('english', coalesce(webvtt, '')), 'C')",
                persisted=True,
            ),
        )
    )
    # Add GIN index for the search vector
    transcripts.append_constraint(
        sqlalchemy.Index(
            "idx_transcript_search_vector_en",
            "search_vector_en",
            postgresql_using="gin",
        )
    )

# Heavy JSON columns left out of list queries unless the caller says otherwise.
_DEFAULT_EXCLUDED_LIST_COLUMNS: tuple[str, ...] = (
    "topics",
    "events",
    "participants",
    "action_items",
)


def _order_by_clause(order_by: str):
    """Turn ``"field"`` / ``"-field"`` into an ascending / descending column.

    Raises ``AttributeError`` when the named column does not exist on the
    ``transcripts`` table.
    """
    if order_by.startswith("-"):
        return getattr(transcripts.c, order_by[1:]).desc()
    return getattr(transcripts.c, order_by)


def generate_transcript_name() -> str:
    """Return a default transcript name built from the current UTC time."""
    now = datetime.now(timezone.utc)
    return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"


TranscriptStatus = Literal[
    "idle", "uploaded", "recording", "processing", "error", "ended"
]


class StrValue(BaseModel):
    value: str


class AudioWaveform(BaseModel):
    data: list[float]


class TranscriptText(BaseModel):
    text: str
    translation: str | None


class TranscriptSegmentTopic(BaseModel):
    speaker: int
    text: str
    timestamp: float


class TranscriptTopic(BaseModel):
    id: str = Field(default_factory=generate_uuid4)
    title: str
    summary: str
    timestamp: float
    duration: float | None = 0
    transcript: str | None = None
    words: list[ProcessorWord] = []


class TranscriptFinalShortSummary(BaseModel):
    short_summary: str


class TranscriptFinalLongSummary(BaseModel):
    long_summary: str


class TranscriptActionItems(BaseModel):
    action_items: dict


class TranscriptFinalTitle(BaseModel):
    title: str


class TranscriptDuration(BaseModel):
    duration: float


class TranscriptWaveform(BaseModel):
    waveform: Sequence[float]


class TranscriptEvent(BaseModel):
    event: str  # Typed at call sites via ws_events.TranscriptEventName; str here for DB compat
    data: dict


class TranscriptParticipant(BaseModel):
    model_config = ConfigDict(from_attributes=True)
    id: str = Field(default_factory=generate_uuid4)
    speaker: int | None
    name: str
    user_id: str | None = None


class Transcript(BaseModel):
    """Full transcript model with all fields."""

    id: str = Field(default_factory=generate_uuid4)
    user_id: str | None = None
    name: str = Field(default_factory=generate_transcript_name)
    status: TranscriptStatus = "idle"
    duration: float = 0
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    title: str | None = None
    source_kind: SourceKind
    room_id: str | None = None
    locked: bool = False
    short_summary: str | None = None
    long_summary: str | None = None
    action_items: dict | None = None
    topics: list[TranscriptTopic] = []
    events: list[TranscriptEvent] = []
    participants: list[TranscriptParticipant] | None = []
    source_language: str = "en"
    target_language: str = "en"
    share_mode: Literal["private", "semi-private", "public"] = "private"
    audio_location: str = "local"
    reviewed: bool = False
    meeting_id: str | None = None
    recording_id: str | None = None
    zulip_message_id: int | None = None
    audio_deleted: bool | None = None
    webvtt: str | None = None
    workflow_run_id: str | None = None  # Hatchet workflow run ID for resumption
    change_seq: int | None = None
    deleted_at: datetime | None = None

    @field_serializer("created_at", when_used="json")
    def serialize_datetime(self, dt: datetime) -> str:
        """Serialize ``created_at`` as ISO 8601, treating naive values as UTC."""
        if dt.tzinfo is None:
            dt = dt.replace(tzinfo=timezone.utc)
        return dt.isoformat()

    def add_event(
        self, event: "TranscriptEventName", data: BaseModel
    ) -> TranscriptEvent:
        """Append an event (with ``data`` dumped to a dict) and return it."""
        ev = TranscriptEvent(event=event, data=data.model_dump())
        self.events.append(ev)
        return ev

    def upsert_topic(self, topic: TranscriptTopic) -> None:
        """Replace the topic with the same id, or append it if not present."""
        index = next((i for i, t in enumerate(self.topics) if t.id == topic.id), None)
        if index is not None:
            self.topics[index] = topic
        else:
            self.topics.append(topic)

    def upsert_participant(
        self, participant: TranscriptParticipant
    ) -> TranscriptParticipant:
        """Replace the participant with the same id, or append it; return it."""
        if self.participants:
            index = next(
                (i for i, p in enumerate(self.participants) if p.id == participant.id),
                None,
            )
            if index is not None:
                self.participants[index] = participant
            else:
                self.participants.append(participant)
        else:
            # participants may be None or empty: start a fresh list
            self.participants = [participant]
        return participant

    def delete_participant(self, participant_id: str) -> None:
        """Remove the participant with the given id; no-op when absent."""
        # participants is nullable; nothing to delete in that case
        if not self.participants:
            return
        index = next(
            (i for i, p in enumerate(self.participants) if p.id == participant_id),
            None,
        )
        if index is not None:
            del self.participants[index]

    def events_dump(self, mode="json"):
        """Return all events as a list of dumped dicts."""
        return [event.model_dump(mode=mode) for event in self.events]

    def topics_dump(self, mode="json"):
        """Return all topics as a list of dumped dicts."""
        return [topic.model_dump(mode=mode) for topic in self.topics]

    def participants_dump(self, mode="json"):
        """Return all participants as dumped dicts (empty list when None)."""
        return [
            participant.model_dump(mode=mode)
            for participant in self.participants or []
        ]

    def unlink(self) -> None:
        """Delete the transcript's local data directory, if it exists."""
        # isdir() is already False for a missing path
        if os.path.isdir(self.data_path):
            shutil.rmtree(self.data_path)

    @property
    def data_path(self):
        return Path(settings.DATA_DIR) / self.id

    @property
    def audio_wav_filename(self):
        return self.data_path / "audio.wav"

    @property
    def audio_mp3_filename(self):
        return self.data_path / "audio.mp3"

    @property
    def audio_waveform_filename(self):
        return self.data_path / "audio.json"

    @property
    def storage_audio_path(self):
        return f"{self.id}/audio.mp3"

    @property
    def audio_waveform(self):
        """Load the waveform JSON file.

        Returns None (and deletes the file) when the JSON is corrupted.
        Only ``json.JSONDecodeError`` is handled here; errors from opening
        the file propagate to the caller.
        """
        try:
            with open(self.audio_waveform_filename) as fd:
                data = json.load(fd)
        except json.JSONDecodeError:
            # unlink file if it's corrupted
            self.audio_waveform_filename.unlink(missing_ok=True)
            return None

        return AudioWaveform(data=data)

    async def get_audio_url(self) -> str:
        """Return a URL for the audio, depending on ``audio_location``.

        Raises:
            Exception: if ``audio_location`` is neither "local" nor "storage".
        """
        if self.audio_location == "local":
            return self._generate_local_audio_link()
        elif self.audio_location == "storage":
            return await self._generate_storage_audio_link()
        raise Exception(f"Unknown audio location {self.audio_location}")

    async def _generate_storage_audio_link(self) -> str:
        return await get_transcripts_storage().get_file_url(self.storage_audio_path)

    def _generate_local_audio_link(self) -> str:
        # we need to create an url to be used for diarization
        # we can't use the audio_mp3_filename because it's not accessible
        # from the diarization processor

        # TODO don't import app in db
        from reflector.app import app  # noqa: PLC0415

        # TODO a util + don't import views in db
        from reflector.views.transcripts import create_access_token  # noqa: PLC0415

        path = app.url_path_for(
            "transcript_get_audio_mp3",
            transcript_id=self.id,
        )
        url = f"{settings.BASE_URL}{path}"
        if self.user_id:
            # we pass token only if the user_id is set
            # otherwise, the audio is public
            token = create_access_token(
                {"sub": self.user_id},
                expires_delta=timedelta(minutes=15),
            )
            url += f"?token={token}"
        return url

    def find_empty_speaker(self) -> int:
        """
        Find an empty speaker seat

        Returns the smallest non-negative integer that is not used as a
        speaker by any word of any topic.
        """
        speakers = {
            word.speaker
            for topic in self.topics
            for word in topic.words
            if word.speaker is not None
        }
        candidate = 0
        while candidate in speakers:
            candidate += 1
        return candidate


class TranscriptController:
    """Database access layer for the ``transcript`` table."""

    async def get_all(
        self,
        user_id: str | None = None,
        order_by: str | None = None,
        filter_empty: bool | None = False,
        filter_recording: bool | None = False,
        source_kind: SourceKind | None = None,
        room_id: str | None = None,
        search_term: str | None = None,
        change_seq_from: int | None = None,
        return_query: bool = False,
        exclude_columns: list[str] | None = None,
    ) -> list[Transcript]:
        """
        Get all non-deleted transcripts

        If `user_id` is specified, return transcripts that belong to the user
        or whose room is shared. Otherwise, unless `settings.PUBLIC_MODE` is
        set, return only transcripts whose room is shared.

        Parameters:
        - `order_by`: field to order by, e.g. "-created_at"
        - `filter_empty`: filter out transcripts whose status is "idle"
        - `filter_recording`: filter out transcripts that are currently recording
        - `source_kind`: filter transcripts by source kind
        - `room_id`: filter transcripts by room ID
        - `search_term`: filter transcripts whose title contains the term
          (case-insensitive)
        - `change_seq_from`: filter transcripts with change_seq > this value
        - `return_query`: return the un-executed query instead of the rows
        - `exclude_columns`: transcript columns to leave out of the result;
          defaults to the heavy JSON columns (topics, events, participants,
          action_items)

        Returns the rows fetched from the database (each also carrying a
        `room_name` column), or the query itself when `return_query` is true.
        """
        excluded = (
            _DEFAULT_EXCLUDED_LIST_COLUMNS
            if exclude_columns is None
            else exclude_columns
        )

        query = transcripts.select().join(
            rooms, transcripts.c.room_id == rooms.c.id, isouter=True
        )

        query = query.where(transcripts.c.deleted_at.is_(None))

        if user_id:
            query = query.where(
                or_(transcripts.c.user_id == user_id, rooms.c.is_shared)
            )
        elif not settings.PUBLIC_MODE:
            query = query.where(rooms.c.is_shared)

        if source_kind:
            query = query.where(transcripts.c.source_kind == source_kind)

        if room_id:
            query = query.where(transcripts.c.room_id == room_id)

        if search_term:
            query = query.where(transcripts.c.title.ilike(f"%{search_term}%"))

        if change_seq_from is not None:
            query = query.where(transcripts.c.change_seq > change_seq_from)

        # Exclude heavy JSON columns from list queries
        transcript_columns = [col for col in transcripts.c if col.name not in excluded]

        query = query.with_only_columns(
            transcript_columns
            + [
                rooms.c.name.label("room_name"),
            ]
        )

        if order_by is not None:
            query = query.order_by(_order_by_clause(order_by))

        if filter_empty:
            query = query.filter(transcripts.c.status != "idle")

        if filter_recording:
            query = query.filter(transcripts.c.status != "recording")

        if return_query:
            return query

        results = await get_database().fetch_all(query)
        return results

    async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
        """
        Get a transcript by id

        If a `user_id` keyword is passed (even None), the transcript must
        also match that user_id. Soft-deleted transcripts are not filtered out.
        """
        query = transcripts.select().where(transcripts.c.id == transcript_id)
        if "user_id" in kwargs:
            query = query.where(transcripts.c.user_id == kwargs["user_id"])
        result = await get_database().fetch_one(query)
        if not result:
            return None
        return Transcript(**result)

    async def get_by_recording_id(
        self, recording_id: str, **kwargs
    ) -> Transcript | None:
        """
        Get a transcript by recording_id

        If a `user_id` keyword is passed, the transcript must also match it.
        """
        query = transcripts.select().where(transcripts.c.recording_id == recording_id)
        if "user_id" in kwargs:
            query = query.where(transcripts.c.user_id == kwargs["user_id"])
        result = await get_database().fetch_one(query)
        if not result:
            return None
        return Transcript(**result)

    async def get_by_room_id(self, room_id: str, **kwargs) -> list[Transcript]:
        """
        Get non-deleted transcripts by room_id (direct access without joins)

        Optional keywords:
        - `user_id`: only transcripts with this user_id
        - `order_by`: field to order by; prefix with "-" for descending,
          e.g. "-created_at"
        """
        query = transcripts.select().where(
            transcripts.c.room_id == room_id,
            transcripts.c.deleted_at.is_(None),
        )
        if "user_id" in kwargs:
            query = query.where(transcripts.c.user_id == kwargs["user_id"])
        if "order_by" in kwargs:
            # Only strip the leading "-" when it is actually there; an
            # unprefixed name is an ascending sort on that column.
            query = query.order_by(_order_by_clause(kwargs["order_by"]))
        results = await get_database().fetch_all(query)
        return [Transcript(**result) for result in results]

    async def get_by_id_for_http(
        self,
        transcript_id: str,
        user_id: str | None,
    ) -> Transcript:
        """
        Get a transcript by ID for HTTP request.

        If not found (or soft-deleted), it will raise a 404 error.
        If the user is not allowed to access the transcript, it will raise a 403 error.

        This method checks the share mode of the transcript and the user_id
        to determine if the user can access the transcript.
        """
        query = transcripts.select().where(transcripts.c.id == transcript_id)
        result = await get_database().fetch_one(query)
        if not result:
            raise HTTPException(status_code=404, detail="Transcript not found")

        transcript = Transcript(**result)

        if transcript.deleted_at is not None:
            raise HTTPException(status_code=404, detail="Transcript not found")

        # if the transcript is anonymous, share mode is not checked
        if transcript.user_id is None:
            return transcript

        if transcript.share_mode == "private":
            # in private mode, only the owner can access the transcript
            if transcript.user_id == user_id:
                return transcript

        elif transcript.share_mode == "semi-private":
            # in semi-private mode, any authenticated user (user_id set)
            # can access the transcript
            if user_id is not None:
                return transcript

        elif transcript.share_mode == "public":
            # in public mode, everyone can access the transcript
            return transcript

        raise HTTPException(status_code=403, detail="Transcript access denied")

    async def add(
        self,
        name: str,
        source_kind: SourceKind,
        source_language: str = "en",
        target_language: str = "en",
        user_id: str | None = None,
        recording_id: str | None = None,
        share_mode: str = "private",
        meeting_id: str | None = None,
        room_id: str | None = None,
    ):
        """
        Add a new transcript

        Builds a `Transcript`, inserts it and returns the model.
        """
        transcript = Transcript(
            name=name,
            source_kind=source_kind,
            source_language=source_language,
            target_language=target_language,
            user_id=user_id,
            recording_id=recording_id,
            share_mode=share_mode,
            meeting_id=meeting_id,
            room_id=room_id,
        )
        query = transcripts.insert().values(**transcript.model_dump())
        await get_database().execute(query)
        return transcript

    # TODO investigate why mutate= is used. it's used in one place currently, maybe because of ORM field updates.
    # using mutate=True is discouraged
    async def update(
        self, transcript: Transcript, values: dict, mutate=False
    ) -> Transcript:
        """
        Update a transcript fields with key/values in values.
        Returns a copy of the transcript with updated values.

        When `values` contains "topics", "webvtt" is regenerated and written
        too. With `mutate=True` the passed transcript is also updated in place.
        """
        values = TranscriptController._handle_topics_update(values)

        query = (
            transcripts.update()
            .where(transcripts.c.id == transcript.id)
            .values(**values)
        )
        await get_database().execute(query)
        if mutate:
            for key, value in values.items():
                setattr(transcript, key, value)

        updated_transcript = transcript.model_copy(update=values)
        return updated_transcript

    @staticmethod
    def _handle_topics_update(values: dict) -> dict:
        """Auto-update WebVTT when topics are updated."""

        if values.get("webvtt") is not None:
            # The value is still passed through; this only flags the misuse.
            logger.warning("trying to update read-only webvtt column")

        topics_data = values.get("topics")
        if topics_data is None:
            return values

        return {
            **values,
            "webvtt": topics_to_webvtt(
                [TranscriptTopic(**topic_dict) for topic_dict in topics_data]
            ),
        }

    async def remove_by_id(
        self,
        transcript_id: str,
        user_id: str | None = None,
    ) -> None:
        """
        Soft-delete a transcript by id.

        Sets deleted_at on the transcript and its associated recording.
        All files (S3 and local) are preserved for later retrieval.

        Does nothing when the transcript is missing, already deleted, or
        `user_id` is given and does not match the transcript's user_id.
        """
        transcript = await self.get_by_id(transcript_id)
        if not transcript:
            return
        if user_id is not None and transcript.user_id != user_id:
            return
        if transcript.deleted_at is not None:
            return

        now = datetime.now(timezone.utc)

        # Soft-delete the associated recording (keeps S3 files intact)
        if transcript.recording_id:
            try:
                await recordings_controller.remove_by_id(transcript.recording_id)
            except Exception as e:
                # best effort: the transcript is still soft-deleted below
                logger.warning(
                    "Failed to soft-delete recording",
                    exc_info=e,
                    recording_id=transcript.recording_id,
                )

        # Soft-delete the transcript (keeps all files intact)
        query = (
            transcripts.update()
            .where(transcripts.c.id == transcript_id)
            .values(deleted_at=now)
        )
        await get_database().execute(query)

    async def remove_by_recording_id(self, recording_id: str):
        """
        Soft-delete a transcript by recording_id
        """
        query = (
            transcripts.update()
            .where(transcripts.c.recording_id == recording_id)
            .values(deleted_at=datetime.now(timezone.utc))
        )
        await get_database().execute(query)

    @staticmethod
    def user_can_mutate(transcript: Transcript, user_id: str | None) -> bool:
        """
        Returns True if the given user is allowed to modify the transcript.

        Policy:
        - Anonymous transcripts (user_id is None) cannot be modified via API
        - Only the owner (matching user_id) can modify their transcript
        """
        if transcript.user_id is None:
            return False
        # bool() so the declared return type holds for None / "" user_id
        return bool(user_id) and transcript.user_id == user_id

    @staticmethod
    def check_can_mutate(transcript: Transcript, user_id: str | None) -> None:
        """
        Raises HTTP 403 if the user cannot mutate the transcript.

        Policy:
        - Anonymous transcripts (user_id is None) are editable by anyone
        - Owned transcripts can only be mutated by their owner
        """
        if transcript.user_id is not None and transcript.user_id != user_id:
            raise HTTPException(status_code=403, detail="Not authorized")

    @asynccontextmanager
    async def transaction(self):
        """
        A context manager for database transaction
        (opened with isolation="serializable")
        """
        async with get_database().transaction(isolation="serializable"):
            yield

    async def append_event(
        self,
        transcript: Transcript,
        event: "TranscriptEventName",
        data: Any,
    ) -> TranscriptEvent:
        """
        Append an event to a transcript and persist the events column
        """
        resp = transcript.add_event(event=event, data=data)
        await self.update(transcript, {"events": transcript.events_dump()})
        return resp

    async def upsert_topic(
        self,
        transcript: Transcript,
        topic: TranscriptTopic,
    ) -> None:
        """
        Upsert a topic on a transcript and persist the topics column
        """
        transcript.upsert_topic(topic)
        await self.update(transcript, {"topics": transcript.topics_dump()})

    async def move_mp3_to_storage(self, transcript: Transcript):
        """
        Move mp3 file to storage

        Raises FileNotFoundError when the transcript is marked audio_deleted,
        or when the audio is local but the mp3 file does not exist.
        """
        if transcript.audio_deleted:
            raise FileNotFoundError(
                f"Invalid state of transcript {transcript.id}: audio_deleted mark is set true"
            )

        if transcript.audio_location == "local":
            # store the audio on external storage if it's not already there
            if not transcript.audio_mp3_filename.exists():
                raise FileNotFoundError(
                    f"Audio file not found: {transcript.audio_mp3_filename}"
                )

            await get_transcripts_storage().put_file(
                transcript.storage_audio_path,
                transcript.audio_mp3_filename.read_bytes(),
            )

            # indicate on the transcript that the audio is now on storage
            # mutates transcript argument
            await self.update(transcript, {"audio_location": "storage"}, mutate=True)

        # unlink the local file
        transcript.audio_mp3_filename.unlink(missing_ok=True)

    async def download_mp3_from_storage(self, transcript: Transcript):
        """
        Download audio from storage

        On any failure the partially written local file is removed and the
        exception is re-raised.
        """
        storage = get_transcripts_storage()
        try:
            with open(transcript.audio_mp3_filename, "wb") as f:
                await storage.stream_to_fileobj(transcript.storage_audio_path, f)
        except Exception:
            transcript.audio_mp3_filename.unlink(missing_ok=True)
            raise

    async def upsert_participant(
        self,
        transcript: Transcript,
        participant: TranscriptParticipant,
    ) -> TranscriptParticipant:
        """
        Add/update a participant to a transcript
        """
        result = transcript.upsert_participant(participant)
        await self.update(transcript, {"participants": transcript.participants_dump()})
        return result

    async def delete_participant(
        self,
        transcript: Transcript,
        participant_id: str,
    ):
        """
        Delete a participant from a transcript
        """
        transcript.delete_participant(participant_id)
        await self.update(transcript, {"participants": transcript.participants_dump()})

    async def set_status(
        self, transcript_id: str, status: TranscriptStatus
    ) -> TranscriptEvent | None:
        """
        Update the status of a transcript

        Will add an event STATUS + update the status field of transcript.
        Returns None without writing when the status is unchanged.

        Raises:
            Exception: if the transcript does not exist.
        """
        async with self.transaction():
            transcript = await self.get_by_id(transcript_id)
            if not transcript:
                raise Exception(f"Transcript {transcript_id} not found")
            if transcript.status == status:
                return None
            resp = await self.append_event(
                transcript=transcript,
                event="STATUS",
                data=StrValue(value=status),
            )
            await self.update(transcript, {"status": status})
            return resp


transcripts_controller = TranscriptController()