feat: search backend (#537)

* docs: transient docs

* chore: cleanup

* webvtt WIP

* webvtt field

* chore: webvtt tests comments

* chore: remove useless tests

* feat: search TASK.md

* feat: full text search by title/webvtt

* chore: search api task

* feat: search api

* feat: search API

* chore: rm task md

* chore: roll back unnecessary validators

* chore: pr review WIP

* chore: pr review WIP

* chore: pr review

* chore: top imports

* feat: better lint + ci

* feat: better lint + ci

* feat: better lint + ci

* feat: better lint + ci

* chore: lint

* chore: lint

* fix: db datetime definitions

* fix: flush() params

* fix: update transcript mutability expectation / test

* fix: update transcript mutability expectation / test

* chore: auto review

* chore: new controller extraction

* chore: new controller extraction

* chore: cleanup

* chore: review WIP

* chore: pr WIP

* chore: remove ci lint

* chore: openapi regeneration

* chore: openapi regeneration

* chore: postgres test doc

* fix: .dockerignore for arm binaries

* fix: .dockerignore for arm binaries

* fix: cap test loops

* fix: cap test loops

* fix: cap test loops

* fix: get_transcript_topics

* chore: remove flow.md docs and claude guidance

* chore: remove claude.md db doc

* chore: remove claude.md db doc

* chore: remove claude.md db doc

* chore: remove claude.md db doc
This commit is contained in:
Igor Loskutov
2025-08-13 10:03:38 -04:00
committed by GitHub
parent a42ed12982
commit 6fb5cb21c2
29 changed files with 3213 additions and 1493 deletions

View File

@@ -0,0 +1,233 @@
"""Search functionality for transcripts and other entities."""
import logging
from datetime import datetime
from io import StringIO
from typing import Annotated, Any, Dict
import sqlalchemy
import webvtt
from pydantic import BaseModel, Field, constr, field_serializer
from reflector.db import database
from reflector.db.transcripts import SourceKind, transcripts
from reflector.db.utils import is_postgresql
logger = logging.getLogger(__name__)
DEFAULT_SEARCH_LIMIT = 20
SNIPPET_CONTEXT_LENGTH = 50 # Characters before/after match to include
DEFAULT_SNIPPET_MAX_LENGTH = 150
DEFAULT_MAX_SNIPPETS = 3
SearchQueryBase = constr(min_length=1, strip_whitespace=True)
SearchLimitBase = Annotated[int, Field(ge=1, le=100)]
SearchOffsetBase = Annotated[int, Field(ge=0)]
SearchTotalBase = Annotated[int, Field(ge=0)]
SearchQuery = Annotated[SearchQueryBase, Field(description="Search query text")]
SearchLimit = Annotated[SearchLimitBase, Field(description="Results per page")]
SearchOffset = Annotated[
SearchOffsetBase, Field(description="Number of results to skip")
]
SearchTotal = Annotated[
SearchTotalBase, Field(description="Total number of search results")
]
class SearchParameters(BaseModel):
"""Validated search parameters for full-text search."""
query_text: SearchQuery
limit: SearchLimit = DEFAULT_SEARCH_LIMIT
offset: SearchOffset = 0
user_id: str | None = None
room_id: str | None = None
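
The aliases above bake the constraints into the types, so constructing SearchParameters validates everything up front. A minimal sketch of the expected behavior (illustrative values):

from pydantic import ValidationError

# strip_whitespace runs before the min_length check, so an all-whitespace
# query is rejected here rather than reaching the database
params = SearchParameters(query_text="  kickoff meeting  ")
assert params.query_text == "kickoff meeting"

try:
    SearchParameters(query_text="   ")
except ValidationError:
    pass  # empty after stripping: fails min_length=1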
class SearchResultDB(BaseModel):
"""Intermediate model for validating raw database results."""
id: str = Field(..., min_length=1)
created_at: datetime
status: str = Field(..., min_length=1)
duration: float | None = Field(None, ge=0)
user_id: str | None = None
title: str | None = None
source_kind: SourceKind
room_id: str | None = None
rank: float = Field(..., ge=0, le=1)
class SearchResult(BaseModel):
"""Public search result model with computed fields."""
id: str = Field(..., min_length=1)
title: str | None = None
user_id: str | None = None
room_id: str | None = None
created_at: datetime
status: str = Field(..., min_length=1)
rank: float = Field(..., ge=0, le=1)
duration: float | None = Field(..., ge=0, description="Duration in seconds")
search_snippets: list[str] = Field(
description="Text snippets around search matches"
)
@field_serializer("created_at", when_used="json")
def serialize_datetime(self, dt: datetime) -> str:
if dt.tzinfo is None:
return dt.isoformat() + "Z"
return dt.isoformat()
class SearchController:
"""Controller for search operations across different entities."""
@staticmethod
def _extract_webvtt_text(webvtt_content: str) -> str:
"""Extract plain text from WebVTT content using webvtt library."""
if not webvtt_content:
return ""
try:
buffer = StringIO(webvtt_content)
vtt = webvtt.read_buffer(buffer)
return " ".join(caption.text for caption in vtt if caption.text)
except (webvtt.errors.MalformedFileError, UnicodeDecodeError, ValueError) as e:
logger.warning(f"Failed to parse WebVTT content: {e}", exc_info=e)
return ""
except AttributeError as e:
logger.warning(f"WebVTT parsing error - unexpected format: {e}", exc_info=e)
return ""
@staticmethod
def _generate_snippets(
text: str,
q: SearchQuery,
max_length: int = DEFAULT_SNIPPET_MAX_LENGTH,
max_snippets: int = DEFAULT_MAX_SNIPPETS,
) -> list[str]:
"""Generate multiple snippets around all occurrences of search term."""
if not text or not q:
return []
snippets = []
lower_text = text.lower()
search_lower = q.lower()
last_snippet_end = 0
start_pos = 0
while len(snippets) < max_snippets:
match_pos = lower_text.find(search_lower, start_pos)
if match_pos == -1:
if not snippets and search_lower.split():
first_word = search_lower.split()[0]
match_pos = lower_text.find(first_word, start_pos)
if match_pos == -1:
break
else:
break
snippet_start = max(0, match_pos - SNIPPET_CONTEXT_LENGTH)
snippet_end = min(
len(text), match_pos + max_length - SNIPPET_CONTEXT_LENGTH
)
if snippet_start < last_snippet_end:
start_pos = match_pos + len(search_lower)
continue
snippet = text[snippet_start:snippet_end]
if snippet_start > 0:
snippet = "..." + snippet
if snippet_end < len(text):
snippet = snippet + "..."
snippet = snippet.strip()
if snippet:
snippets.append(snippet)
last_snippet_end = snippet_end
start_pos = match_pos + len(search_lower)
if start_pos >= len(text):
break
return snippets
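
A worked example of the windowing above (illustrative text; offsets are character-based, so a window edge can land mid-word):

text = "The quarterly planning meeting covered roadmap items and budget review."
snippets = SearchController._generate_snippets(text, "budget")
# One match at offset 57: 50 chars of left context are kept and the right
# edge is capped at match_pos + max_length - SNIPPET_CONTEXT_LENGTH, giving:
# ['...rterly planning meeting covered roadmap items and budget review.']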
@classmethod
async def search_transcripts(
cls, params: SearchParameters
) -> tuple[list[SearchResult], int]:
"""
Full-text search for transcripts using PostgreSQL tsvector.
Returns (results, total_count).
"""
if not is_postgresql():
logger.warning(
"Full-text search requires PostgreSQL. Returning empty results."
)
return [], 0
search_query = sqlalchemy.func.websearch_to_tsquery(
"english", params.query_text
)
base_query = sqlalchemy.select(
[
transcripts.c.id,
transcripts.c.title,
transcripts.c.created_at,
transcripts.c.duration,
transcripts.c.status,
transcripts.c.user_id,
transcripts.c.room_id,
transcripts.c.source_kind,
transcripts.c.webvtt,
sqlalchemy.func.ts_rank(
transcripts.c.search_vector_en,
search_query,
32, # normalization flag: rank/(rank+1) for 0-1 range
).label("rank"),
]
).where(transcripts.c.search_vector_en.op("@@")(search_query))
if params.user_id:
base_query = base_query.where(transcripts.c.user_id == params.user_id)
if params.room_id:
base_query = base_query.where(transcripts.c.room_id == params.room_id)
query = (
base_query.order_by(sqlalchemy.desc(sqlalchemy.text("rank")))
.limit(params.limit)
.offset(params.offset)
)
rs = await database.fetch_all(query)
count_query = sqlalchemy.select([sqlalchemy.func.count()]).select_from(
base_query.alias("search_results")
)
total = await database.fetch_val(count_query)
def _process_result(r) -> SearchResult:
r_dict: Dict[str, Any] = dict(r)
webvtt_content: str | None = r_dict.pop("webvtt", None)  # avoid shadowing the webvtt module
db_result = SearchResultDB.model_validate(r_dict)
snippets = []
if webvtt_content:
plain_text = cls._extract_webvtt_text(webvtt_content)
snippets = cls._generate_snippets(plain_text, params.query_text)
return SearchResult(**db_result.model_dump(), search_snippets=snippets)
results = [_process_result(r) for r in rs]
return results, total
search_controller = SearchController()
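
Intended call pattern for the controller (a sketch; assumes a connected database, a PostgreSQL backend, and an async context):

params = SearchParameters(query_text="budget review", user_id="user-123")
results, total = await search_controller.search_transcripts(params)
for r in results:
    print(r.title, round(r.rank, 3), r.search_snippets)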

View File

@@ -1,9 +1,10 @@
import enum
import json
import logging
import os
import shutil
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Literal
@@ -11,13 +12,19 @@ import sqlalchemy
from fastapi import HTTPException
from pydantic import BaseModel, ConfigDict, Field, field_serializer
from sqlalchemy import Enum
from sqlalchemy.dialects.postgresql import TSVECTOR
from sqlalchemy.sql import false, or_
from reflector.db import database, metadata
from reflector.db.rooms import rooms
from reflector.db.utils import is_postgresql
from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings
from reflector.storage import get_transcripts_storage
from reflector.utils import generate_uuid4
from reflector.utils.webvtt import topics_to_webvtt
logger = logging.getLogger(__name__)
class SourceKind(enum.StrEnum):
@@ -76,6 +83,7 @@ transcripts = sqlalchemy.Table(
# the same field could have lived on recording/meeting; duplicating it there when needed is acceptable
sqlalchemy.Column("audio_deleted", sqlalchemy.Boolean),
sqlalchemy.Column("room_id", sqlalchemy.String),
sqlalchemy.Column("webvtt", sqlalchemy.Text),
sqlalchemy.Index("idx_transcript_recording_id", "recording_id"),
sqlalchemy.Index("idx_transcript_user_id", "user_id"),
sqlalchemy.Index("idx_transcript_created_at", "created_at"),
@@ -83,6 +91,29 @@ transcripts = sqlalchemy.Table(
sqlalchemy.Index("idx_transcript_room_id", "room_id"),
)
# Add PostgreSQL-specific full-text search column
# This matches the migration in migrations/versions/116b2f287eab_add_full_text_search.py
if is_postgresql():
transcripts.append_column(
sqlalchemy.Column(
"search_vector_en",
TSVECTOR,
sqlalchemy.Computed(
"setweight(to_tsvector('english', coalesce(title, '')), 'A') || "
"setweight(to_tsvector('english', coalesce(webvtt, '')), 'B')",
persisted=True,
),
)
)
# Add GIN index for the search vector
transcripts.append_constraint(
sqlalchemy.Index(
"idx_transcript_search_vector_en",
"search_vector_en",
postgresql_using="gin",
)
)
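
For orientation, the DDL this computed column corresponds to would be roughly the following (a hedged sketch: the real statements live in migrations/versions/116b2f287eab_add_full_text_search.py, and the table name "transcript" is an assumption here):

from alembic import op

def upgrade() -> None:
    # generated column mirroring the SQLAlchemy Computed() definition above
    op.execute(
        """
        ALTER TABLE transcript
        ADD COLUMN search_vector_en tsvector
        GENERATED ALWAYS AS (
            setweight(to_tsvector('english', coalesce(title, '')), 'A') ||
            setweight(to_tsvector('english', coalesce(webvtt, '')), 'B')
        ) STORED
        """
    )
    op.execute(
        "CREATE INDEX idx_transcript_search_vector_en "
        "ON transcript USING gin (search_vector_en)"
    )

Title matches carry weight 'A' and body text 'B', so ts_rank favors title hits over transcript-body hits for the same term.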
def generate_transcript_name() -> str:
now = datetime.now(timezone.utc)
@@ -147,14 +178,18 @@ class TranscriptParticipant(BaseModel):
class Transcript(BaseModel):
"""Full transcript model with all fields."""
id: str = Field(default_factory=generate_uuid4)
user_id: str | None = None
name: str = Field(default_factory=generate_transcript_name)
status: str = "idle"
locked: bool = False
duration: float = 0
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
title: str | None = None
source_kind: SourceKind
room_id: str | None = None
locked: bool = False
short_summary: str | None = None
long_summary: str | None = None
topics: list[TranscriptTopic] = []
@@ -168,9 +203,8 @@ class Transcript(BaseModel):
meeting_id: str | None = None
recording_id: str | None = None
zulip_message_id: int | None = None
source_kind: SourceKind
audio_deleted: bool | None = None
room_id: str | None = None
webvtt: str | None = None
@field_serializer("created_at", when_used="json")
def serialize_datetime(self, dt: datetime) -> str:
@@ -271,10 +305,12 @@ class Transcript(BaseModel):
# we need to create an url to be used for diarization
# we can't use the audio_mp3_filename because it's not accessible
# from the diarization processor
from datetime import timedelta
from reflector.app import app
from reflector.views.transcripts import create_access_token
# TODO don't import app in db
from reflector.app import app # noqa: PLC0415
# TODO: make this a util + don't import views in db
from reflector.views.transcripts import create_access_token # noqa: PLC0415
path = app.url_path_for(
"transcript_get_audio_mp3",
@@ -335,7 +371,6 @@ class TranscriptController:
- `room_id`: filter transcripts by room ID
- `search_term`: filter transcripts by search term
"""
from reflector.db.rooms import rooms
query = transcripts.select().join(
rooms, transcripts.c.room_id == rooms.c.id, isouter=True
@@ -502,10 +537,17 @@ class TranscriptController:
await database.execute(query)
return transcript
async def update(self, transcript: Transcript, values: dict, mutate=True):
# TODO: investigate why mutate= is used; it has a single call site currently, possibly because of ORM field updates.
# using mutate=True is discouraged
async def update(
self, transcript: Transcript, values: dict, mutate=False
) -> Transcript:
"""
Update a transcript fields with key/values in values
Update a transcript fields with key/values in values.
Returns a copy of the transcript with updated values.
"""
values = TranscriptController._handle_topics_update(values)
query = (
transcripts.update()
.where(transcripts.c.id == transcript.id)
@@ -516,6 +558,28 @@ class TranscriptController:
for key, value in values.items():
setattr(transcript, key, value)
updated_transcript = transcript.model_copy(update=values)
return updated_transcript
@staticmethod
def _handle_topics_update(values: dict) -> dict:
"""Auto-update WebVTT when topics are updated."""
if values.get("webvtt") is not None:
logger.warning("trying to update read-only webvtt column")
topics_data = values.get("topics")
if topics_data is None:
return values
return {
**values,
"webvtt": topics_to_webvtt(
[TranscriptTopic(**topic_dict) for topic_dict in topics_data]
),
}
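
With this hook in place, callers never write webvtt directly; any topics update regenerates it. A sketch of the flow:

# updating topics transparently refreshes the derived webvtt column;
# _handle_topics_update injects a "webvtt" entry built by topics_to_webvtt()
await transcripts_controller.update(
    transcript,
    {"topics": transcript.topics_dump()},
)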
async def remove_by_id(
self,
transcript_id: str,
@@ -558,11 +622,7 @@ class TranscriptController:
Append an event to a transcript
"""
resp = transcript.add_event(event=event, data=data)
await self.update(
transcript,
{"events": transcript.events_dump()},
mutate=False,
)
await self.update(transcript, {"events": transcript.events_dump()})
return resp
async def upsert_topic(
@@ -574,11 +634,7 @@ class TranscriptController:
Upsert topics to a transcript
"""
transcript.upsert_topic(topic)
await self.update(
transcript,
{"topics": transcript.topics_dump()},
mutate=False,
)
await self.update(transcript, {"topics": transcript.topics_dump()})
async def move_mp3_to_storage(self, transcript: Transcript):
"""
@@ -603,7 +659,8 @@ class TranscriptController:
)
# indicate on the transcript that the audio is now on storage
await self.update(transcript, {"audio_location": "storage"})
# mutates transcript argument
await self.update(transcript, {"audio_location": "storage"}, mutate=True)
# unlink the local file
transcript.audio_mp3_filename.unlink(missing_ok=True)
@@ -627,11 +684,7 @@ class TranscriptController:
Add/update a participant to a transcript
"""
result = transcript.upsert_participant(participant)
await self.update(
transcript,
{"participants": transcript.participants_dump()},
mutate=False,
)
await self.update(transcript, {"participants": transcript.participants_dump()})
return result
async def delete_participant(
@@ -643,11 +696,7 @@ class TranscriptController:
Delete a participant from a transcript
"""
transcript.delete_participant(participant_id)
await self.update(
transcript,
{"participants": transcript.participants_dump()},
mutate=False,
)
await self.update(transcript, {"participants": transcript.participants_dump()})
transcripts_controller = TranscriptController()

View File

@@ -0,0 +1,7 @@
"""Database utility functions."""
from reflector.db import database
def is_postgresql() -> bool:
return bool(database.url.scheme and database.url.scheme.startswith("postgresql"))

View File

@@ -14,12 +14,15 @@ It is directly linked to our data model.
import asyncio
import functools
from contextlib import asynccontextmanager
from typing import Generic
import av
import boto3
from celery import chord, current_task, group, shared_task
from pydantic import BaseModel
from structlog import BoundLogger as Logger
from reflector.db import database
from reflector.db.meetings import meeting_consent_controller, meetings_controller
from reflector.db.recordings import recordings_controller
from reflector.db.rooms import rooms_controller
@@ -35,7 +38,7 @@ from reflector.db.transcripts import (
transcripts_controller,
)
from reflector.logger import logger
from reflector.pipelines.runner import PipelineRunner
from reflector.pipelines.runner import PipelineMessage, PipelineRunner
from reflector.processors import (
AudioChunkerProcessor,
AudioDiarizationAutoProcessor,
@@ -69,8 +72,6 @@ def asynctask(f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
async def run_with_db():
from reflector.db import database
await database.connect()
try:
return await f(*args, **kwargs)
@@ -144,7 +145,7 @@ class StrValue(BaseModel):
value: str
class PipelineMainBase(PipelineRunner):
class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]):
transcript_id: str
ws_room_id: str | None = None
ws_manager: WebsocketManager | None = None
@@ -164,7 +165,11 @@ class PipelineMainBase(PipelineRunner):
raise Exception("Transcript not found")
return result
def get_transcript_topics(self, transcript: Transcript) -> list[TranscriptTopic]:
@staticmethod
def wrap_transcript_topics(
topics: list[TranscriptTopic],
) -> list[TitleSummaryWithIdProcessorType]:
# transformation to a pipe-supported format
return [
TitleSummaryWithIdProcessorType(
id=topic.id,
@@ -174,7 +179,7 @@ class PipelineMainBase(PipelineRunner):
duration=topic.duration,
transcript=TranscriptProcessorType(words=topic.words),
)
for topic in transcript.topics
for topic in topics
]
@asynccontextmanager
@@ -380,7 +385,7 @@ class PipelineMainLive(PipelineMainBase):
pipeline_post(transcript_id=self.transcript_id)
class PipelineMainDiarization(PipelineMainBase):
class PipelineMainDiarization(PipelineMainBase[AudioDiarizationInput]):
"""
Diarize the audio and update topics
"""
@@ -404,11 +409,10 @@ class PipelineMainDiarization(PipelineMainBase):
pipeline.logger.info("Audio is local, skipping diarization")
return
topics = self.get_transcript_topics(transcript)
audio_url = await transcript.get_audio_url()
audio_diarization_input = AudioDiarizationInput(
audio_url=audio_url,
topics=topics,
topics=self.wrap_transcript_topics(transcript.topics),
)
# as tempting as pipeline.push is, prefer to use the runner
@@ -421,7 +425,7 @@ class PipelineMainDiarization(PipelineMainBase):
return pipeline
class PipelineMainFromTopics(PipelineMainBase):
class PipelineMainFromTopics(PipelineMainBase[TitleSummaryWithIdProcessorType]):
"""
Pseudo class for generating a pipeline from topics
"""
@@ -443,7 +447,7 @@ class PipelineMainFromTopics(PipelineMainBase):
pipeline.logger.info(f"{self.__class__.__name__} pipeline created")
# push topics
topics = self.get_transcript_topics(transcript)
topics = PipelineMainBase.wrap_transcript_topics(transcript.topics)
for topic in topics:
await self.push(topic)
@@ -524,8 +528,6 @@ async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
# Convert to mp3
mp3_filename = transcript.audio_mp3_filename
import av
with av.open(wav_filename.as_posix()) as in_container:
in_stream = in_container.streams.audio[0]
with av.open(mp3_filename.as_posix(), "w") as out_container:
@@ -604,7 +606,7 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
meeting.id
)
except Exception as e:
logger.error(f"Failed to get fetch consent: {e}")
logger.error(f"Failed to get fetch consent: {e}", exc_info=e)
consent_denied = True
if not consent_denied:
@@ -627,7 +629,7 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
f"Deleted original Whereby recording: {recording.bucket_name}/{recording.object_key}"
)
except Exception as e:
logger.error(f"Failed to delete Whereby recording: {e}")
logger.error(f"Failed to delete Whereby recording: {e}", exc_info=e)
# non-transactional: files may be marked for deletion without actually being deleted
await transcripts_controller.update(transcript, {"audio_deleted": True})
@@ -640,7 +642,7 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
f"Deleted processed audio from storage: {transcript.storage_audio_path}"
)
except Exception as e:
logger.error(f"Failed to delete processed audio: {e}")
logger.error(f"Failed to delete processed audio: {e}", exc_info=e)
# 3. Delete local audio files
try:
@@ -649,7 +651,7 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
if hasattr(transcript, "audio_wav_filename") and transcript.audio_wav_filename:
transcript.audio_wav_filename.unlink(missing_ok=True)
except Exception as e:
logger.error(f"Failed to delete local audio files: {e}")
logger.error(f"Failed to delete local audio files: {e}", exc_info=e)
logger.info("Consent cleanup done")
@@ -794,8 +796,6 @@ def pipeline_post(*, transcript_id: str):
@get_transcript
async def pipeline_process(transcript: Transcript, logger: Logger):
import av
try:
if transcript.audio_location == "storage":
await transcripts_controller.download_mp3_from_storage(transcript)

View File

@@ -16,14 +16,17 @@ During its lifecycle, it will emit the following status:
"""
import asyncio
from typing import Generic, TypeVar
from pydantic import BaseModel, ConfigDict
from reflector.logger import logger
from reflector.processors import Pipeline
PipelineMessage = TypeVar("PipelineMessage")
class PipelineRunner(BaseModel):
class PipelineRunner(BaseModel, Generic[PipelineMessage]):
model_config = ConfigDict(arbitrary_types_allowed=True)
status: str = "idle"
@@ -67,7 +70,7 @@ class PipelineRunner(BaseModel):
coro = self.run()
asyncio.run(coro)
async def push(self, data):
async def push(self, data: PipelineMessage):
"""
Push data to the pipeline
"""
@@ -92,7 +95,11 @@ class PipelineRunner(BaseModel):
pass
async def _add_cmd(
self, cmd: str, data, max_retries: int = 3, retry_time_limit: int = 3
self,
cmd: str,
data: PipelineMessage,
max_retries: int = 3,
retry_time_limit: int = 3,
):
"""
Enqueue a command to be executed in the runner.
@@ -143,7 +150,10 @@ class PipelineRunner(BaseModel):
cmd, data = await self._q_cmd.get()
func = getattr(self, f"cmd_{cmd.lower()}")
if func:
await func(data)
if cmd.upper() == "FLUSH":
await func()
else:
await func(data)
else:
raise Exception(f"Unknown command {cmd}")
except Exception:
@@ -152,13 +162,13 @@ class PipelineRunner(BaseModel):
self._ev_done.set()
raise
async def cmd_push(self, data):
async def cmd_push(self, data: PipelineMessage):
if self._is_first_push:
await self._set_status("push")
self._is_first_push = False
await self.pipeline.push(data)
async def cmd_flush(self, data):
async def cmd_flush(self):
await self._set_status("flush")
await self.pipeline.flush()
await self._set_status("ended")

View File

@@ -2,9 +2,10 @@ import io
import re
import tempfile
from pathlib import Path
from typing import Annotated
from profanityfilter import ProfanityFilter
from pydantic import BaseModel, PrivateAttr
from pydantic import BaseModel, Field, PrivateAttr
from reflector.redis_cache import redis_cache
@@ -48,20 +49,70 @@ class AudioFile(BaseModel):
self._path.unlink()
# non-negative seconds, fractional part allowed
Seconds = Annotated[float, Field(ge=0.0, description="Non-negative time in seconds")]
class Word(BaseModel):
text: str
start: float
end: float
start: Seconds
end: Seconds
speaker: int = 0
class TranscriptSegment(BaseModel):
text: str
start: float
end: float
start: Seconds
end: Seconds
speaker: int = 0
def words_to_segments(words: list[Word]) -> list[TranscriptSegment]:
# from a list of words, build a list of segments:
# append each word to the current segment, and close the segment when the
# speaker changes, or at sentence punctuation (. , ; : ? !) once the segment
# exceeds MAX_SEGMENT_LENGTH characters
segments = []
current_segment = None
MAX_SEGMENT_LENGTH = 120
for word in words:
if current_segment is None:
current_segment = TranscriptSegment(
text=word.text,
start=word.start,
end=word.end,
speaker=word.speaker,
)
continue
# If the word is attached to another speaker, push the current segment
# and start a new one
if word.speaker != current_segment.speaker:
segments.append(current_segment)
current_segment = TranscriptSegment(
text=word.text,
start=word.start,
end=word.end,
speaker=word.speaker,
)
continue
# if the word is the end of a sentence, and we have enough content,
# add the word to the current segment and push it
current_segment.text += word.text
current_segment.end = word.end
have_punc = PUNC_RE.search(word.text)
if have_punc and (len(current_segment.text) > MAX_SEGMENT_LENGTH):
segments.append(current_segment)
current_segment = None
if current_segment:
segments.append(current_segment)
return segments
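
A worked example of the segmentation rules (illustrative words):

words = [
    Word(text="Hello", start=0.0, end=0.4, speaker=0),
    Word(text=" there.", start=0.4, end=0.9, speaker=0),
    Word(text=" Hi!", start=1.0, end=1.3, speaker=1),
]
segments = words_to_segments(words)
# The speaker change at the third word forces a split, yielding two
# segments: "Hello there." (speaker 0) and " Hi!" (speaker 1). Punctuation
# alone only splits once a segment exceeds MAX_SEGMENT_LENGTH (120 chars).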
class Transcript(BaseModel):
translation: str | None = None
words: list[Word] = None
@@ -117,49 +168,7 @@ class Transcript(BaseModel):
return Transcript(text=self.text, translation=self.translation, words=words)
def as_segments(self) -> list[TranscriptSegment]:
# from a list of words, build a list of segments:
# close the segment when the speaker changes, or at sentence punctuation
# (. , ; : ? !) once the segment exceeds MAX_SEGMENT_LENGTH characters
segments = []
current_segment = None
MAX_SEGMENT_LENGTH = 120
for word in self.words:
if current_segment is None:
current_segment = TranscriptSegment(
text=word.text,
start=word.start,
end=word.end,
speaker=word.speaker,
)
continue
# If the word is attached to another speaker, push the current segment
# and start a new one
if word.speaker != current_segment.speaker:
segments.append(current_segment)
current_segment = TranscriptSegment(
text=word.text,
start=word.start,
end=word.end,
speaker=word.speaker,
)
continue
# if the word is the end of a sentence, and we have enough content,
# add the word to the current segment and push it
current_segment.text += word.text
current_segment.end = word.end
have_punc = PUNC_RE.search(word.text)
if have_punc and (len(current_segment.text) > MAX_SEGMENT_LENGTH):
segments.append(current_segment)
current_segment = None
if current_segment:
segments.append(current_segment)
return segments
return words_to_segments(self.words)
class TitleSummary(BaseModel):

View File

@@ -0,0 +1,63 @@
"""WebVTT utilities for generating subtitle files from transcript data."""
from typing import TYPE_CHECKING, Annotated
import webvtt
from reflector.processors.types import Seconds, Word, words_to_segments
if TYPE_CHECKING:
from reflector.db.transcripts import TranscriptTopic
VttTimestamp = Annotated[str, "vtt_timestamp"]
WebVTTStr = Annotated[str, "webvtt_str"]
def _seconds_to_timestamp(seconds: Seconds) -> VttTimestamp:
# the webvtt library has no seconds-to-timestamp helper, so format manually
hours = int(seconds // 3600)
minutes = int((seconds % 3600) // 60)
secs = int(seconds % 60)
milliseconds = int((seconds % 1) * 1000)
return f"{hours:02d}:{minutes:02d}:{secs:02d}.{milliseconds:03d}"
def words_to_webvtt(words: list[Word]) -> WebVTTStr:
"""Convert words to WebVTT using existing segmentation logic."""
vtt = webvtt.WebVTT()
if not words:
return vtt.content
segments = words_to_segments(words)
for segment in segments:
text = segment.text.strip()
# the webvtt library doesn't add voice tags, so prepend one manually
text = f"<v Speaker{segment.speaker}>{text}"
caption = webvtt.Caption(
start=_seconds_to_timestamp(segment.start),
end=_seconds_to_timestamp(segment.end),
text=text,
)
vtt.captions.append(caption)
return vtt.content
def topics_to_webvtt(topics: list["TranscriptTopic"]) -> WebVTTStr:
if not topics:
return webvtt.WebVTT().content
all_words: list[Word] = []
for topic in topics:
all_words.extend(topic.words)
# assert words are in chronological order across topics
for i in range(len(all_words) - 1):
assert (
all_words[i].start <= all_words[i + 1].start
), f"Words are not in sequence: {all_words[i].text} and {all_words[i + 1].text} are not consecutive: {all_words[i].start} > {all_words[i + 1].start}"
return words_to_webvtt(all_words)
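
End to end, the helpers emit a standard WebVTT document with voice tags; a sketch of the expected output shape:

words = [
    Word(text="Hello", start=0.0, end=0.4, speaker=0),
    Word(text=" world.", start=0.4, end=1.25, speaker=0),
]
print(words_to_webvtt(words))
# WEBVTT
#
# 00:00:00.000 --> 00:00:01.250
# <v Speaker0>Hello world.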

View File

@@ -1,15 +1,29 @@
from datetime import datetime, timedelta, timezone
from typing import Annotated, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi_pagination import Page
from fastapi_pagination.ext.databases import paginate
from jose import jwt
from pydantic import BaseModel, Field, field_serializer
import reflector.auth as auth
from reflector.db import database
from reflector.db.meetings import meetings_controller
from reflector.db.rooms import rooms_controller
from reflector.db.search import (
DEFAULT_SEARCH_LIMIT,
SearchLimit,
SearchLimitBase,
SearchOffset,
SearchOffsetBase,
SearchParameters,
SearchQuery,
SearchQueryBase,
SearchResult,
SearchTotal,
search_controller,
)
from reflector.db.transcripts import (
SourceKind,
TranscriptParticipant,
@@ -100,6 +114,21 @@ class DeletionStatus(BaseModel):
status: str
SearchQueryParam = Annotated[SearchQueryBase, Query(description="Search query text")]
SearchLimitParam = Annotated[SearchLimitBase, Query(description="Results per page")]
SearchOffsetParam = Annotated[
SearchOffsetBase, Query(description="Number of results to skip")
]
class SearchResponse(BaseModel):
results: list[SearchResult]
total: SearchTotal
query: SearchQuery
limit: SearchLimit
offset: SearchOffset
@router.get("/transcripts", response_model=Page[GetTranscriptMinimal])
async def transcripts_list(
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
@@ -107,8 +136,6 @@ async def transcripts_list(
room_id: str | None = None,
search_term: str | None = None,
):
from reflector.db import database
if not user and not settings.PUBLIC_MODE:
raise HTTPException(status_code=401, detail="Not authenticated")
@@ -127,6 +154,39 @@ async def transcripts_list(
)
@router.get("/transcripts/search", response_model=SearchResponse)
async def transcripts_search(
q: SearchQueryParam,
limit: SearchLimitParam = DEFAULT_SEARCH_LIMIT,
offset: SearchOffsetParam = 0,
room_id: Optional[str] = None,
user: Annotated[
Optional[auth.UserInfo], Depends(auth.current_user_optional)
] = None,
):
"""
Full-text search across transcript titles and content.
"""
if not user and not settings.PUBLIC_MODE:
raise HTTPException(status_code=401, detail="Not authenticated")
user_id = user["sub"] if user else None
search_params = SearchParameters(
query_text=q, limit=limit, offset=offset, user_id=user_id, room_id=room_id
)
results, total = await search_controller.search_transcripts(search_params)
return SearchResponse(
results=results,
total=total,
query=search_params.query_text,
limit=search_params.limit,
offset=search_params.offset,
)
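
Client-side, the endpoint takes the query parameters defined above. A hedged example (base URL, port, and auth header are deployment-specific assumptions; a version prefix may also apply):

import httpx

async def search_transcripts_api(query: str) -> dict:
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        resp = await client.get(
            "/transcripts/search",
            params={"q": query, "limit": 20, "offset": 0},
            # headers={"Authorization": "Bearer <token>"},  # when not PUBLIC_MODE
        )
        resp.raise_for_status()
        return resp.json()  # {"results": [...], "total": ..., "query": ..., ...}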
@router.post("/transcripts", response_model=GetTranscript)
async def transcripts_create(
info: CreateTranscript,
@@ -273,8 +333,8 @@ async def transcript_update(
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
values = info.dict(exclude_unset=True)
await transcripts_controller.update(transcript, values)
return transcript
updated_transcript = await transcripts_controller.update(transcript, values)
return updated_transcript
@router.delete("/transcripts/{transcript_id}", response_model=DeletionStatus)