Files
reflector/server/reflector/db/search.py
Igor Loskutov d70beee51b fix: include shared rooms to search (#558)
* include shared rooms to search

* tests vibe

* tests vibe

* tests vibe

* tests vibe

* tests vibe

* tests vibe

* tests vibe

* remove tests, thats too much
2025-08-21 14:52:29 -04:00

449 lines
15 KiB
Python

"""Search functionality for transcripts and other entities."""
import itertools
from dataclasses import dataclass
from datetime import datetime
from io import StringIO
from typing import Annotated, Any, Dict, Iterator
import sqlalchemy
import webvtt
from fastapi import HTTPException
from pydantic import (
BaseModel,
Field,
NonNegativeFloat,
NonNegativeInt,
ValidationError,
constr,
field_serializer,
)
from reflector.db import get_database
from reflector.db.rooms import rooms
from reflector.db.transcripts import SourceKind, transcripts
from reflector.db.utils import is_postgresql
from reflector.logger import logger
# --- Paging and snippet defaults --------------------------------------------
DEFAULT_SEARCH_LIMIT = 20
# Leading context kept in front of a match when a snippet is cut out of the
# text; the rest of the snippet budget goes after the match position.
SNIPPET_CONTEXT_LENGTH = 50
DEFAULT_SNIPPET_MAX_LENGTH = NonNegativeInt(150)
DEFAULT_MAX_SNIPPETS = NonNegativeInt(3)
LONG_SUMMARY_MAX_SNIPPETS = 2

# --- Constrained base types (validation rules only) -------------------------
SearchQueryBase = constr(min_length=0, strip_whitespace=True)
SearchLimitBase = Annotated[int, Field(ge=1, le=100)]
SearchOffsetBase = Annotated[int, Field(ge=0)]
SearchTotalBase = Annotated[int, Field(ge=0)]

# --- Public aliases: the base constraints plus a schema description ---------
SearchQuery = Annotated[SearchQueryBase, Field(description="Search query text")]
SearchLimit = Annotated[SearchLimitBase, Field(description="Results per page")]
SearchOffset = Annotated[
    SearchOffsetBase,
    Field(description="Number of results to skip"),
]
SearchTotal = Annotated[
    SearchTotalBase,
    Field(description="Total number of search results"),
]

# --- WebVTT -----------------------------------------------------------------
# Header string that WebVTT content is expected to start with.
WEBVTT_SPEC_HEADER = "WEBVTT"
# A string long enough to at least hold the header.
WebVTTContent = Annotated[
    str,
    Field(min_length=len(WEBVTT_SPEC_HEADER), description="WebVTT content"),
]
class WebVTTProcessor:
    """Namespace of stateless helpers for handling WebVTT transcript content."""

    @staticmethod
    def parse(raw_content: str) -> WebVTTContent:
        """Check that ``raw_content`` starts with the WebVTT header.

        Only the header prefix is checked; the content is returned unchanged.

        Raises:
            ValueError: if the content does not start with the WebVTT header.
        """
        if raw_content.startswith(WEBVTT_SPEC_HEADER):
            return raw_content
        raise ValueError(f"Invalid WebVTT content, no header {WEBVTT_SPEC_HEADER}")

    @staticmethod
    def extract_text(webvtt_content: WebVTTContent) -> str:
        """Return the caption texts of ``webvtt_content`` joined by single spaces.

        Extraction is best-effort: any failure is logged and an empty string
        is returned instead of propagating the exception.
        """
        try:
            captions = webvtt.read_buffer(StringIO(webvtt_content))
            # Captions without text are skipped.
            spoken_parts = [cue.text for cue in captions if cue.text]
            return " ".join(spoken_parts)
        except webvtt.errors.MalformedFileError as e:
            logger.warning(f"Malformed WebVTT content: {e}")
            return ""
        except (UnicodeDecodeError, ValueError) as e:
            logger.warning(f"Failed to decode WebVTT content: {e}")
            return ""
        except AttributeError as e:
            logger.error(
                f"WebVTT parsing error - unexpected format: {e}", exc_info=True
            )
            return ""
        except Exception as e:
            logger.error(f"Unexpected error parsing WebVTT: {e}", exc_info=True)
            return ""

    @staticmethod
    def generate_snippets(
        webvtt_content: WebVTTContent,
        query: str,
        max_snippets: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
    ) -> list[str]:
        """Build up to ``max_snippets`` snippets for ``query`` from caption text.

        The WebVTT content is first reduced to plain text with
        ``extract_text`` and then handed to ``SnippetGenerator.generate``.
        """
        plain_text = WebVTTProcessor.extract_text(webvtt_content)
        return SnippetGenerator.generate(plain_text, query, max_snippets=max_snippets)
@dataclass(frozen=True)
class SnippetCandidate:
    """A slice cut out of a larger text, together with where that slice begins."""

    # Raw slice exactly as cut from the source text (not stripped).
    _text: str
    # Index in the source text at which the slice begins.
    start: NonNegativeInt
    # Length of the source text the slice was cut from.
    _original_text_length: int

    @property
    def end(self) -> NonNegativeInt:
        """Index just past the slice: ``start`` plus the raw slice length."""
        return len(self._text) + self.start

    def text(self) -> str:
        """Return the stripped slice, with "..." on each side that was cut.

        A leading ellipsis is added when the slice does not begin at the start
        of the source text; a trailing one when it stops before the end.
        """
        prefix = "..." if self.start > 0 else ""
        suffix = "..." if self.end < self._original_text_length else ""
        return f"{prefix}{self._text.strip()}{suffix}"
class SearchParameters(BaseModel):
    """Validated inputs for a transcript search request."""

    # Query text; whitespace is stripped and an empty string is accepted.
    query_text: SearchQuery

    # Paging window.
    limit: SearchLimit = DEFAULT_SEARCH_LIMIT
    offset: SearchOffset = 0

    # Optional filters, all unset by default.
    user_id: str | None = None
    room_id: str | None = None
    source_kind: SourceKind | None = None
class SearchResultDB(BaseModel):
    """Validation model for one raw row returned by the search query."""

    id: str = Field(min_length=1)
    created_at: datetime
    status: str = Field(min_length=1)
    # Duration may be missing; when present it must not be negative.
    duration: float | None = Field(default=None, ge=0)
    user_id: str | None = None
    title: str | None = None
    source_kind: SourceKind
    room_id: str | None = None
    # Relevance score, constrained to the closed interval [0, 1].
    rank: float = Field(ge=0, le=1)
class SearchResult(BaseModel):
    """Search hit exposed to callers: row fields plus snippets and match count."""

    id: str = Field(min_length=1)
    title: str | None = None
    user_id: str | None = None
    room_id: str | None = None
    room_name: str | None = None
    source_kind: SourceKind
    created_at: datetime
    status: str = Field(min_length=1)
    rank: float = Field(ge=0, le=1)
    # Required field, but an explicit None is allowed.
    duration: NonNegativeFloat | None = Field(description="Duration in seconds")
    search_snippets: list[str] = Field(
        description="Text snippets around search matches",
    )
    total_match_count: NonNegativeInt = Field(
        default=0,
        description="Total number of matches found in the transcript",
    )

    @field_serializer("created_at", when_used="json")
    def serialize_datetime(self, dt: datetime) -> str:
        """Serialize ``created_at`` as ISO 8601 for JSON output.

        Naive datetimes get a trailing "Z"; timezone-aware datetimes are
        emitted with whatever offset ``isoformat()`` produces.
        """
        suffix = "Z" if dt.tzinfo is None else ""
        return dt.isoformat() + suffix
class SnippetGenerator:
    """Stateless generator for text snippets and match operations."""

    @staticmethod
    def find_all_matches(text: str, query: str) -> Iterator[int]:
        """Yield the start index of each non-overlapping match of ``query``.

        Matching is case-insensitive (both sides are lower-cased). Nothing is
        yielded, and a warning is logged, when ``text`` or ``query`` is empty.

        NOTE: indices refer to the lower-cased text. For the rare characters
        whose lower-case form has a different length, they can be shifted
        relative to ``text``.

        Raises:
            ValueError: if the scan position fails to advance (safety net
                against an infinite loop).
        """
        if not text:
            logger.warning("Empty text for search query in find_all_matches")
            return
        if not query:
            logger.warning("Empty query for search text in find_all_matches")
            return

        text_lower = text.lower()
        query_lower = query.lower()
        start = 0
        prev_start = start
        while (pos := text_lower.find(query_lower, start)) != -1:
            yield pos
            # Resume after the whole match so matches never overlap.
            start = pos + len(query_lower)
            if start <= prev_start:
                raise ValueError("panic! find_all_matches is not incremental")
            prev_start = start

    @staticmethod
    def count_matches(text: str, query: str) -> NonNegativeInt:
        """Count non-overlapping, case-insensitive matches of ``query`` in ``text``.

        Returns 0 (and logs a warning) when either argument is empty.
        """
        ZERO = NonNegativeInt(0)
        if not text:
            logger.warning("Empty text for search query in count_matches")
            return ZERO
        if not query:
            logger.warning("Empty query for search text in count_matches")
            return ZERO
        return NonNegativeInt(
            sum(1 for _ in SnippetGenerator.find_all_matches(text, query))
        )

    @staticmethod
    def create_snippet(
        text: str, match_pos: int, max_length: int = DEFAULT_SNIPPET_MAX_LENGTH
    ) -> SnippetCandidate:
        """Cut a snippet window around ``match_pos`` out of ``text``.

        The window starts ``SNIPPET_CONTEXT_LENGTH`` characters before the
        match and ends at ``match_pos + max_length - SNIPPET_CONTEXT_LENGTH``;
        both bounds are clamped to the text.
        """
        snippet_start = NonNegativeInt(max(0, match_pos - SNIPPET_CONTEXT_LENGTH))
        snippet_end = min(len(text), match_pos + max_length - SNIPPET_CONTEXT_LENGTH)
        snippet_text = text[snippet_start:snippet_end]
        return SnippetCandidate(
            _text=snippet_text, start=snippet_start, _original_text_length=len(text)
        )

    @staticmethod
    def filter_non_overlapping(
        candidates: Iterator[SnippetCandidate],
    ) -> Iterator[str]:
        """Yield display text of candidates that start at or after the last kept end.

        Candidates are consumed in the order given. One that overlaps the
        previously kept snippet, or whose display text is empty, is dropped.
        """
        last_end = 0
        for candidate in candidates:
            display_text = candidate.text()
            # Overlapping snippets are simply skipped. This is deliberately
            # simplistic: users already have their search results, so losing
            # an overlapping snippet is fine.
            if candidate.start >= last_end and display_text:
                yield display_text
                last_end = candidate.end

    @staticmethod
    def generate(
        text: str,
        query: str,
        max_length: NonNegativeInt = DEFAULT_SNIPPET_MAX_LENGTH,
        max_snippets: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
    ) -> list[str]:
        """Return up to ``max_snippets`` non-overlapping snippets for ``query``.

        If the full query has no literal match and contains a space, the
        search is retried with the first word of the query only. Returns an
        empty list (and logs a warning) when ``text`` or ``query`` is empty.
        """
        if not text or not query:
            logger.warning("Empty text or query for generate_snippets")
            return []

        candidates = (
            SnippetGenerator.create_snippet(text, pos, max_length)
            for pos in SnippetGenerator.find_all_matches(text, query)
        )
        filtered = SnippetGenerator.filter_non_overlapping(candidates)
        snippets = list(itertools.islice(filtered, max_snippets))

        # Fallback to a first-word search if the full query did not match.
        # Proper snippet generation is tied to the database's matching logic,
        # so this simplification is used instead.
        if not snippets and " " in query:
            words = query.split()
            # A whitespace-only query has no first word; without this guard
            # ``query.split()[0]`` would raise IndexError.
            if words:
                return SnippetGenerator.generate(
                    text, words[0], max_length, max_snippets
                )
        return snippets

    @staticmethod
    def from_summary(
        summary: str,
        query: str,
        max_snippets: NonNegativeInt = LONG_SUMMARY_MAX_SNIPPETS,
    ) -> list[str]:
        """Generate snippets from summary text."""
        return SnippetGenerator.generate(summary, query, max_snippets=max_snippets)

    @staticmethod
    def combine_sources(
        summary: str | None,
        webvtt: WebVTTContent | None,
        query: str,
        max_total: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
    ) -> tuple[list[str], NonNegativeInt]:
        """Combine snippets from summary and WebVTT and return the match count.

        Summary snippets come first; WebVTT snippets fill whatever is left of
        ``max_total``. The count is the sum of matches in both sources.

        Returns:
            ``(snippets, total_match_count)``. ``snippets`` can legitimately
            be empty, e.g. when only the title matched.
        """
        if not query:
            # Nothing can match an empty query. Returning early gives the same
            # ([], 0) result without parsing WebVTT or logging a warning per
            # source for every row of an unfiltered listing.
            return [], NonNegativeInt(0)

        # Parse the WebVTT once and reuse the text for counting and snippets.
        webvtt_text = WebVTTProcessor.extract_text(webvtt) if webvtt else ""

        webvtt_matches = (
            SnippetGenerator.count_matches(webvtt_text, query) if webvtt_text else 0
        )
        summary_matches = (
            SnippetGenerator.count_matches(summary, query) if summary else 0
        )
        total_matches = NonNegativeInt(webvtt_matches + summary_matches)

        summary_snippets = (
            SnippetGenerator.from_summary(summary, query) if summary else []
        )
        if len(summary_snippets) >= max_total:
            return summary_snippets[:max_total], total_matches

        remaining = max_total - len(summary_snippets)
        webvtt_snippets = (
            SnippetGenerator.generate(webvtt_text, query, max_snippets=remaining)
            if webvtt_text
            else []
        )
        return summary_snippets + webvtt_snippets, total_matches
class SearchController:
    """Controller for search operations across different entities."""

    @classmethod
    async def search_transcripts(
        cls, params: SearchParameters
    ) -> tuple[list[SearchResult], int]:
        """Full-text search for transcripts using PostgreSQL tsvector.

        Transcripts are left-joined with rooms. With ``params.user_id`` set,
        rows owned by that user or belonging to a shared room are returned;
        without it, only rows from shared rooms. ``room_id`` and
        ``source_kind`` narrow the result further when given. With an empty
        query text no text filter is applied, every row gets rank 1.0 and
        rows are ordered newest first.

        Returns:
            ``(results, total_count)`` where ``total_count`` counts all rows
            matching the filters, ignoring limit/offset. On a non-PostgreSQL
            database a warning is logged and ``([], 0)`` is returned.

        Raises:
            HTTPException: status 500 when a row fails model validation.
            Exception: any other error raised while processing rows (for
                example ``ValueError`` from ``WebVTTProcessor.parse``) is
                logged and re-raised.
        """
        if not is_postgresql():
            logger.warning(
                "Full-text search requires PostgreSQL. Returning empty results."
            )
            return [], 0

        base_columns = [
            transcripts.c.id,
            transcripts.c.title,
            transcripts.c.created_at,
            transcripts.c.duration,
            transcripts.c.status,
            transcripts.c.user_id,
            transcripts.c.room_id,
            transcripts.c.source_kind,
            transcripts.c.webvtt,
            transcripts.c.long_summary,
            # A transcript that references a room with no matching row in the
            # outer join is labelled as belonging to a deleted room.
            sqlalchemy.case(
                (
                    transcripts.c.room_id.isnot(None) & rooms.c.id.is_(None),
                    "Deleted Room",
                ),
                else_=rooms.c.name,
            ).label("room_name"),
        ]

        if params.query_text:
            search_query = sqlalchemy.func.websearch_to_tsquery(
                "english", params.query_text
            )
            rank_column = sqlalchemy.func.ts_rank(
                transcripts.c.search_vector_en,
                search_query,
                32,  # normalization flag: rank/(rank+1) for 0-1 range
            ).label("rank")
        else:
            rank_column = sqlalchemy.cast(1.0, sqlalchemy.Float).label("rank")

        columns = base_columns + [rank_column]
        base_query = sqlalchemy.select(columns).select_from(
            transcripts.join(rooms, transcripts.c.room_id == rooms.c.id, isouter=True)
        )

        if params.query_text:
            base_query = base_query.where(
                transcripts.c.search_vector_en.op("@@")(search_query)
            )

        if params.user_id:
            base_query = base_query.where(
                sqlalchemy.or_(
                    transcripts.c.user_id == params.user_id, rooms.c.is_shared
                )
            )
        else:
            base_query = base_query.where(rooms.c.is_shared)

        if params.room_id:
            base_query = base_query.where(transcripts.c.room_id == params.room_id)
        if params.source_kind:
            base_query = base_query.where(
                transcripts.c.source_kind == params.source_kind
            )

        # LIMIT/OFFSET paging is only stable under a total ordering. Rows with
        # equal rank (or equal created_at) could otherwise move between pages,
        # so tie-breakers are appended.
        # NOTE(review): assumes transcripts.id is unique - confirm.
        if params.query_text:
            ordering = [
                sqlalchemy.desc(sqlalchemy.text("rank")),
                sqlalchemy.desc(transcripts.c.created_at),
                transcripts.c.id,
            ]
        else:
            ordering = [
                sqlalchemy.desc(transcripts.c.created_at),
                transcripts.c.id,
            ]

        query = (
            base_query.order_by(*ordering).limit(params.limit).offset(params.offset)
        )
        rs = await get_database().fetch_all(query)

        # Count over the filtered but unpaged query.
        count_query = sqlalchemy.select([sqlalchemy.func.count()]).select_from(
            base_query.alias("search_results")
        )
        total = await get_database().fetch_val(count_query)

        def _process_result(r) -> SearchResult:
            """Turn one database row into a ``SearchResult`` with snippets."""
            r_dict: Dict[str, Any] = dict(r)

            # Pop the columns that are not fields of SearchResultDB before
            # validating the rest of the row.
            webvtt_raw: str | None = r_dict.pop("webvtt", None)
            if webvtt_raw:
                webvtt = WebVTTProcessor.parse(webvtt_raw)
            else:
                webvtt = None
            long_summary: str | None = r_dict.pop("long_summary", None)
            room_name: str | None = r_dict.pop("room_name", None)

            db_result = SearchResultDB.model_validate(r_dict)
            snippets, total_match_count = SnippetGenerator.combine_sources(
                long_summary, webvtt, params.query_text, DEFAULT_MAX_SNIPPETS
            )
            return SearchResult(
                **db_result.model_dump(),
                room_name=room_name,
                search_snippets=snippets,
                total_match_count=total_match_count,
            )

        try:
            results = [_process_result(r) for r in rs]
        except ValidationError as e:
            logger.error(f"Invalid search result data: {e}", exc_info=True)
            # Chain the cause so the original validation error stays in the
            # traceback.
            raise HTTPException(
                status_code=500, detail="Internal search result data consistency error"
            ) from e
        except Exception as e:
            logger.error(f"Error processing search results: {e}", exc_info=True)
            raise

        return results, total
# Shared module-level instances. The classes above expose only static and
# class methods, so these instances hold no state of their own.
snippet_generator = SnippetGenerator()
webvtt_processor = WebVTTProcessor()
search_controller = SearchController()