feat: search backend (#537)

* docs: transient docs

* chore: cleanup

* webvtt WIP

* webvtt field

* chore: webvtt tests comments

* chore: remove useless tests

* feat: search TASK.md

* feat: full text search by title/webvtt

* chore: search api task

* feat: search api

* feat: search API

* chore: rm task md

* chore: roll back unnecessary validators

* chore: pr review WIP

* chore: pr review WIP

* chore: pr review

* chore: top imports

* feat: better lint + ci

* feat: better lint + ci

* feat: better lint + ci

* feat: better lint + ci

* chore: lint

* chore: lint

* fix: db datetime definitions

* fix: flush() params

* fix: update transcript mutability expectation / test

* fix: update transcript mutability expectation / test

* chore: auto review

* chore: new controller extraction

* chore: new controller extraction

* chore: cleanup

* chore: review WIP

* chore: pr WIP

* chore: remove ci lint

* chore: openapi regeneration

* chore: openapi regeneration

* chore: postgres test doc

* fix: .dockerignore for arm binaries

* fix: .dockerignore for arm binaries

* fix: cap test loops

* fix: cap test loops

* fix: cap test loops

* fix: get_transcript_topics

* chore: remove flow.md docs and claude guidance

* chore: remove claude.md db doc

* chore: remove claude.md db doc

* chore: remove claude.md db doc

* chore: remove claude.md db doc
This commit is contained in:
Igor Loskutov
2025-08-13 10:03:38 -04:00
committed by GitHub
parent a42ed12982
commit 6fb5cb21c2
29 changed files with 3213 additions and 1493 deletions

View File

@@ -0,0 +1,233 @@
"""Search functionality for transcripts and other entities."""
import logging
from datetime import datetime
from io import StringIO
from typing import Annotated, Any, Dict
import sqlalchemy
import webvtt
from pydantic import BaseModel, Field, constr, field_serializer
from reflector.db import database
from reflector.db.transcripts import SourceKind, transcripts
from reflector.db.utils import is_postgresql
logger = logging.getLogger(__name__)
DEFAULT_SEARCH_LIMIT = 20
SNIPPET_CONTEXT_LENGTH = 50 # Characters before/after match to include
DEFAULT_SNIPPET_MAX_LENGTH = 150
DEFAULT_MAX_SNIPPETS = 3
SearchQueryBase = constr(min_length=1, strip_whitespace=True)
SearchLimitBase = Annotated[int, Field(ge=1, le=100)]
SearchOffsetBase = Annotated[int, Field(ge=0)]
SearchTotalBase = Annotated[int, Field(ge=0)]
SearchQuery = Annotated[SearchQueryBase, Field(description="Search query text")]
SearchLimit = Annotated[SearchLimitBase, Field(description="Results per page")]
SearchOffset = Annotated[
SearchOffsetBase, Field(description="Number of results to skip")
]
SearchTotal = Annotated[
SearchTotalBase, Field(description="Total number of search results")
]
class SearchParameters(BaseModel):
"""Validated search parameters for full-text search."""
query_text: SearchQuery
limit: SearchLimit = DEFAULT_SEARCH_LIMIT
offset: SearchOffset = 0
user_id: str | None = None
room_id: str | None = None
class SearchResultDB(BaseModel):
"""Intermediate model for validating raw database results."""
id: str = Field(..., min_length=1)
created_at: datetime
status: str = Field(..., min_length=1)
duration: float | None = Field(None, ge=0)
user_id: str | None = None
title: str | None = None
source_kind: SourceKind
room_id: str | None = None
rank: float = Field(..., ge=0, le=1)
class SearchResult(BaseModel):
"""Public search result model with computed fields."""
id: str = Field(..., min_length=1)
title: str | None = None
user_id: str | None = None
room_id: str | None = None
created_at: datetime
status: str = Field(..., min_length=1)
rank: float = Field(..., ge=0, le=1)
duration: float | None = Field(..., ge=0, description="Duration in seconds")
search_snippets: list[str] = Field(
description="Text snippets around search matches"
)
@field_serializer("created_at", when_used="json")
def serialize_datetime(self, dt: datetime) -> str:
if dt.tzinfo is None:
return dt.isoformat() + "Z"
return dt.isoformat()
class SearchController:
"""Controller for search operations across different entities."""
@staticmethod
def _extract_webvtt_text(webvtt_content: str) -> str:
"""Extract plain text from WebVTT content using webvtt library."""
if not webvtt_content:
return ""
try:
buffer = StringIO(webvtt_content)
vtt = webvtt.read_buffer(buffer)
return " ".join(caption.text for caption in vtt if caption.text)
except (webvtt.errors.MalformedFileError, UnicodeDecodeError, ValueError) as e:
logger.warning(f"Failed to parse WebVTT content: {e}", exc_info=e)
return ""
except AttributeError as e:
logger.warning(f"WebVTT parsing error - unexpected format: {e}", exc_info=e)
return ""
@staticmethod
def _generate_snippets(
text: str,
q: SearchQuery,
max_length: int = DEFAULT_SNIPPET_MAX_LENGTH,
max_snippets: int = DEFAULT_MAX_SNIPPETS,
) -> list[str]:
"""Generate multiple snippets around all occurrences of search term."""
if not text or not q:
return []
snippets = []
lower_text = text.lower()
search_lower = q.lower()
last_snippet_end = 0
start_pos = 0
while len(snippets) < max_snippets:
match_pos = lower_text.find(search_lower, start_pos)
if match_pos == -1:
if not snippets and search_lower.split():
first_word = search_lower.split()[0]
match_pos = lower_text.find(first_word, start_pos)
if match_pos == -1:
break
else:
break
snippet_start = max(0, match_pos - SNIPPET_CONTEXT_LENGTH)
snippet_end = min(
len(text), match_pos + max_length - SNIPPET_CONTEXT_LENGTH
)
if snippet_start < last_snippet_end:
start_pos = match_pos + len(search_lower)
continue
snippet = text[snippet_start:snippet_end]
if snippet_start > 0:
snippet = "..." + snippet
if snippet_end < len(text):
snippet = snippet + "..."
snippet = snippet.strip()
if snippet:
snippets.append(snippet)
last_snippet_end = snippet_end
start_pos = match_pos + len(search_lower)
if start_pos >= len(text):
break
return snippets
@classmethod
async def search_transcripts(
cls, params: SearchParameters
) -> tuple[list[SearchResult], int]:
"""
Full-text search for transcripts using PostgreSQL tsvector.
Returns (results, total_count).
"""
if not is_postgresql():
logger.warning(
"Full-text search requires PostgreSQL. Returning empty results."
)
return [], 0
search_query = sqlalchemy.func.websearch_to_tsquery(
"english", params.query_text
)
base_query = sqlalchemy.select(
[
transcripts.c.id,
transcripts.c.title,
transcripts.c.created_at,
transcripts.c.duration,
transcripts.c.status,
transcripts.c.user_id,
transcripts.c.room_id,
transcripts.c.source_kind,
transcripts.c.webvtt,
sqlalchemy.func.ts_rank(
transcripts.c.search_vector_en,
search_query,
32, # normalization flag: rank/(rank+1) for 0-1 range
).label("rank"),
]
).where(transcripts.c.search_vector_en.op("@@")(search_query))
if params.user_id:
base_query = base_query.where(transcripts.c.user_id == params.user_id)
if params.room_id:
base_query = base_query.where(transcripts.c.room_id == params.room_id)
query = (
base_query.order_by(sqlalchemy.desc(sqlalchemy.text("rank")))
.limit(params.limit)
.offset(params.offset)
)
rs = await database.fetch_all(query)
count_query = sqlalchemy.select([sqlalchemy.func.count()]).select_from(
base_query.alias("search_results")
)
total = await database.fetch_val(count_query)
def _process_result(r) -> SearchResult:
r_dict: Dict[str, Any] = dict(r)
webvtt: str | None = r_dict.pop("webvtt", None)
db_result = SearchResultDB.model_validate(r_dict)
snippets = []
if webvtt:
plain_text = cls._extract_webvtt_text(webvtt)
snippets = cls._generate_snippets(plain_text, params.query_text)
return SearchResult(**db_result.model_dump(), search_snippets=snippets)
results = [_process_result(r) for r in rs]
return results, total
search_controller = SearchController()

View File

@@ -1,9 +1,10 @@
import enum
import json
import logging
import os
import shutil
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Literal
@@ -11,13 +12,19 @@ import sqlalchemy
from fastapi import HTTPException
from pydantic import BaseModel, ConfigDict, Field, field_serializer
from sqlalchemy import Enum
from sqlalchemy.dialects.postgresql import TSVECTOR
from sqlalchemy.sql import false, or_
from reflector.db import database, metadata
from reflector.db.rooms import rooms
from reflector.db.utils import is_postgresql
from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings
from reflector.storage import get_transcripts_storage
from reflector.utils import generate_uuid4
from reflector.utils.webvtt import topics_to_webvtt
logger = logging.getLogger(__name__)
class SourceKind(enum.StrEnum):
@@ -76,6 +83,7 @@ transcripts = sqlalchemy.Table(
# same field could've been in recording/meeting, and it's maybe even ok to dupe it at need
sqlalchemy.Column("audio_deleted", sqlalchemy.Boolean),
sqlalchemy.Column("room_id", sqlalchemy.String),
sqlalchemy.Column("webvtt", sqlalchemy.Text),
sqlalchemy.Index("idx_transcript_recording_id", "recording_id"),
sqlalchemy.Index("idx_transcript_user_id", "user_id"),
sqlalchemy.Index("idx_transcript_created_at", "created_at"),
@@ -83,6 +91,29 @@ transcripts = sqlalchemy.Table(
sqlalchemy.Index("idx_transcript_room_id", "room_id"),
)
# Add PostgreSQL-specific full-text search column
# This matches the migration in migrations/versions/116b2f287eab_add_full_text_search.py
if is_postgresql():
transcripts.append_column(
sqlalchemy.Column(
"search_vector_en",
TSVECTOR,
sqlalchemy.Computed(
"setweight(to_tsvector('english', coalesce(title, '')), 'A') || "
"setweight(to_tsvector('english', coalesce(webvtt, '')), 'B')",
persisted=True,
),
)
)
# Add GIN index for the search vector
transcripts.append_constraint(
sqlalchemy.Index(
"idx_transcript_search_vector_en",
"search_vector_en",
postgresql_using="gin",
)
)
def generate_transcript_name() -> str:
now = datetime.now(timezone.utc)
@@ -147,14 +178,18 @@ class TranscriptParticipant(BaseModel):
class Transcript(BaseModel):
"""Full transcript model with all fields."""
id: str = Field(default_factory=generate_uuid4)
user_id: str | None = None
name: str = Field(default_factory=generate_transcript_name)
status: str = "idle"
locked: bool = False
duration: float = 0
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
title: str | None = None
source_kind: SourceKind
room_id: str | None = None
locked: bool = False
short_summary: str | None = None
long_summary: str | None = None
topics: list[TranscriptTopic] = []
@@ -168,9 +203,8 @@ class Transcript(BaseModel):
meeting_id: str | None = None
recording_id: str | None = None
zulip_message_id: int | None = None
source_kind: SourceKind
audio_deleted: bool | None = None
room_id: str | None = None
webvtt: str | None = None
@field_serializer("created_at", when_used="json")
def serialize_datetime(self, dt: datetime) -> str:
@@ -271,10 +305,12 @@ class Transcript(BaseModel):
# we need to create an url to be used for diarization
# we can't use the audio_mp3_filename because it's not accessible
# from the diarization processor
from datetime import timedelta
from reflector.app import app
from reflector.views.transcripts import create_access_token
# TODO don't import app in db
from reflector.app import app # noqa: PLC0415
# TODO a util + don''t import views in db
from reflector.views.transcripts import create_access_token # noqa: PLC0415
path = app.url_path_for(
"transcript_get_audio_mp3",
@@ -335,7 +371,6 @@ class TranscriptController:
- `room_id`: filter transcripts by room ID
- `search_term`: filter transcripts by search term
"""
from reflector.db.rooms import rooms
query = transcripts.select().join(
rooms, transcripts.c.room_id == rooms.c.id, isouter=True
@@ -502,10 +537,17 @@ class TranscriptController:
await database.execute(query)
return transcript
async def update(self, transcript: Transcript, values: dict, mutate=True):
# TODO investigate why mutate= is used. it's used in one place currently, maybe because of ORM field updates.
# using mutate=True is discouraged
async def update(
self, transcript: Transcript, values: dict, mutate=False
) -> Transcript:
"""
Update a transcript fields with key/values in values
Update a transcript fields with key/values in values.
Returns a copy of the transcript with updated values.
"""
values = TranscriptController._handle_topics_update(values)
query = (
transcripts.update()
.where(transcripts.c.id == transcript.id)
@@ -516,6 +558,28 @@ class TranscriptController:
for key, value in values.items():
setattr(transcript, key, value)
updated_transcript = transcript.model_copy(update=values)
return updated_transcript
@staticmethod
def _handle_topics_update(values: dict) -> dict:
"""Auto-update WebVTT when topics are updated."""
if values.get("webvtt") is not None:
logger.warn("trying to update read-only webvtt column")
pass
topics_data = values.get("topics")
if topics_data is None:
return values
return {
**values,
"webvtt": topics_to_webvtt(
[TranscriptTopic(**topic_dict) for topic_dict in topics_data]
),
}
async def remove_by_id(
self,
transcript_id: str,
@@ -558,11 +622,7 @@ class TranscriptController:
Append an event to a transcript
"""
resp = transcript.add_event(event=event, data=data)
await self.update(
transcript,
{"events": transcript.events_dump()},
mutate=False,
)
await self.update(transcript, {"events": transcript.events_dump()})
return resp
async def upsert_topic(
@@ -574,11 +634,7 @@ class TranscriptController:
Upsert topics to a transcript
"""
transcript.upsert_topic(topic)
await self.update(
transcript,
{"topics": transcript.topics_dump()},
mutate=False,
)
await self.update(transcript, {"topics": transcript.topics_dump()})
async def move_mp3_to_storage(self, transcript: Transcript):
"""
@@ -603,7 +659,8 @@ class TranscriptController:
)
# indicate on the transcript that the audio is now on storage
await self.update(transcript, {"audio_location": "storage"})
# mutates transcript argument
await self.update(transcript, {"audio_location": "storage"}, mutate=True)
# unlink the local file
transcript.audio_mp3_filename.unlink(missing_ok=True)
@@ -627,11 +684,7 @@ class TranscriptController:
Add/update a participant to a transcript
"""
result = transcript.upsert_participant(participant)
await self.update(
transcript,
{"participants": transcript.participants_dump()},
mutate=False,
)
await self.update(transcript, {"participants": transcript.participants_dump()})
return result
async def delete_participant(
@@ -643,11 +696,7 @@ class TranscriptController:
Delete a participant from a transcript
"""
transcript.delete_participant(participant_id)
await self.update(
transcript,
{"participants": transcript.participants_dump()},
mutate=False,
)
await self.update(transcript, {"participants": transcript.participants_dump()})
transcripts_controller = TranscriptController()

View File

@@ -0,0 +1,7 @@
"""Database utility functions."""
from reflector.db import database
def is_postgresql() -> bool:
return database.url.scheme and database.url.scheme.startswith('postgresql')