fix: search-logspam (#593)

* fix: search-logspam

* llm comment

* fix tests

---------

Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
This commit is contained in:
Igor Monadical
2025-08-29 18:55:51 -04:00
committed by GitHub
parent ccffdba75b
commit 695d1a957d
5 changed files with 85 additions and 33 deletions

View File

@@ -8,12 +8,14 @@ from typing import Annotated, Any, Dict, Iterator
import sqlalchemy import sqlalchemy
import webvtt import webvtt
from databases.interfaces import Record as DbRecord
from fastapi import HTTPException from fastapi import HTTPException
from pydantic import ( from pydantic import (
BaseModel, BaseModel,
Field, Field,
NonNegativeFloat, NonNegativeFloat,
NonNegativeInt, NonNegativeInt,
TypeAdapter,
ValidationError, ValidationError,
constr, constr,
field_serializer, field_serializer,
@@ -24,6 +26,7 @@ from reflector.db.rooms import rooms
from reflector.db.transcripts import SourceKind, transcripts from reflector.db.transcripts import SourceKind, transcripts
from reflector.db.utils import is_postgresql from reflector.db.utils import is_postgresql
from reflector.logger import logger from reflector.logger import logger
from reflector.utils.string import NonEmptyString, try_parse_non_empty_string
DEFAULT_SEARCH_LIMIT = 20 DEFAULT_SEARCH_LIMIT = 20
SNIPPET_CONTEXT_LENGTH = 50 # Characters before/after match to include SNIPPET_CONTEXT_LENGTH = 50 # Characters before/after match to include
@@ -31,12 +34,13 @@ DEFAULT_SNIPPET_MAX_LENGTH = NonNegativeInt(150)
DEFAULT_MAX_SNIPPETS = NonNegativeInt(3) DEFAULT_MAX_SNIPPETS = NonNegativeInt(3)
LONG_SUMMARY_MAX_SNIPPETS = 2 LONG_SUMMARY_MAX_SNIPPETS = 2
SearchQueryBase = constr(min_length=0, strip_whitespace=True) SearchQueryBase = constr(min_length=1, strip_whitespace=True)
SearchLimitBase = Annotated[int, Field(ge=1, le=100)] SearchLimitBase = Annotated[int, Field(ge=1, le=100)]
SearchOffsetBase = Annotated[int, Field(ge=0)] SearchOffsetBase = Annotated[int, Field(ge=0)]
SearchTotalBase = Annotated[int, Field(ge=0)] SearchTotalBase = Annotated[int, Field(ge=0)]
SearchQuery = Annotated[SearchQueryBase, Field(description="Search query text")] SearchQuery = Annotated[SearchQueryBase, Field(description="Search query text")]
search_query_adapter = TypeAdapter(SearchQuery)
SearchLimit = Annotated[SearchLimitBase, Field(description="Results per page")] SearchLimit = Annotated[SearchLimitBase, Field(description="Results per page")]
SearchOffset = Annotated[ SearchOffset = Annotated[
SearchOffsetBase, Field(description="Number of results to skip") SearchOffsetBase, Field(description="Number of results to skip")
@@ -88,7 +92,7 @@ class WebVTTProcessor:
@staticmethod @staticmethod
def generate_snippets( def generate_snippets(
webvtt_content: WebVTTContent, webvtt_content: WebVTTContent,
query: str, query: SearchQuery,
max_snippets: NonNegativeInt = DEFAULT_MAX_SNIPPETS, max_snippets: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
) -> list[str]: ) -> list[str]:
"""Generate snippets from WebVTT content.""" """Generate snippets from WebVTT content."""
@@ -125,7 +129,7 @@ class SnippetCandidate:
class SearchParameters(BaseModel): class SearchParameters(BaseModel):
"""Validated search parameters for full-text search.""" """Validated search parameters for full-text search."""
query_text: SearchQuery query_text: SearchQuery | None = None
limit: SearchLimit = DEFAULT_SEARCH_LIMIT limit: SearchLimit = DEFAULT_SEARCH_LIMIT
offset: SearchOffset = 0 offset: SearchOffset = 0
user_id: str | None = None user_id: str | None = None
@@ -199,15 +203,13 @@ class SnippetGenerator:
prev_start = start prev_start = start
@staticmethod @staticmethod
def count_matches(text: str, query: str) -> NonNegativeInt: def count_matches(text: str, query: SearchQuery) -> NonNegativeInt:
"""Count total number of matches for a query in text.""" """Count total number of matches for a query in text."""
ZERO = NonNegativeInt(0) ZERO = NonNegativeInt(0)
if not text: if not text:
logger.warning("Empty text for search query in count_matches") logger.warning("Empty text for search query in count_matches")
return ZERO return ZERO
if not query: assert query is not None
logger.warning("Empty query for search text in count_matches")
return ZERO
return NonNegativeInt( return NonNegativeInt(
sum(1 for _ in SnippetGenerator.find_all_matches(text, query)) sum(1 for _ in SnippetGenerator.find_all_matches(text, query))
) )
@@ -243,13 +245,14 @@ class SnippetGenerator:
@staticmethod @staticmethod
def generate( def generate(
text: str, text: str,
query: str, query: SearchQuery,
max_length: NonNegativeInt = DEFAULT_SNIPPET_MAX_LENGTH, max_length: NonNegativeInt = DEFAULT_SNIPPET_MAX_LENGTH,
max_snippets: NonNegativeInt = DEFAULT_MAX_SNIPPETS, max_snippets: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
) -> list[str]: ) -> list[str]:
"""Generate snippets from text.""" """Generate snippets from text."""
if not text or not query: assert query is not None
logger.warning("Empty text or query for generate_snippets") if not text:
logger.warning("Empty text for generate_snippets")
return [] return []
candidates = ( candidates = (
@@ -270,7 +273,7 @@ class SnippetGenerator:
@staticmethod @staticmethod
def from_summary( def from_summary(
summary: str, summary: str,
query: str, query: SearchQuery,
max_snippets: NonNegativeInt = LONG_SUMMARY_MAX_SNIPPETS, max_snippets: NonNegativeInt = LONG_SUMMARY_MAX_SNIPPETS,
) -> list[str]: ) -> list[str]:
"""Generate snippets from summary text.""" """Generate snippets from summary text."""
@@ -278,9 +281,9 @@ class SnippetGenerator:
@staticmethod @staticmethod
def combine_sources( def combine_sources(
summary: str | None, summary: NonEmptyString | None,
webvtt: WebVTTContent | None, webvtt: WebVTTContent | None,
query: str, query: SearchQuery,
max_total: NonNegativeInt = DEFAULT_MAX_SNIPPETS, max_total: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
) -> tuple[list[str], NonNegativeInt]: ) -> tuple[list[str], NonNegativeInt]:
"""Combine snippets from multiple sources and return total match count. """Combine snippets from multiple sources and return total match count.
@@ -289,6 +292,11 @@ class SnippetGenerator:
snippets can be empty for real in case of e.g. title match snippets can be empty for real in case of e.g. title match
""" """
assert (
summary is not None or webvtt is not None
), "At least one source must be present"
webvtt_matches = 0 webvtt_matches = 0
summary_matches = 0 summary_matches = 0
@@ -355,8 +363,8 @@ class SearchController:
else_=rooms.c.name, else_=rooms.c.name,
).label("room_name"), ).label("room_name"),
] ]
search_query = None
if params.query_text: if params.query_text is not None:
search_query = sqlalchemy.func.websearch_to_tsquery( search_query = sqlalchemy.func.websearch_to_tsquery(
"english", params.query_text "english", params.query_text
) )
@@ -373,7 +381,9 @@ class SearchController:
transcripts.join(rooms, transcripts.c.room_id == rooms.c.id, isouter=True) transcripts.join(rooms, transcripts.c.room_id == rooms.c.id, isouter=True)
) )
if params.query_text: if params.query_text is not None:
# because already initialized based on params.query_text presence above
assert search_query is not None
base_query = base_query.where( base_query = base_query.where(
transcripts.c.search_vector_en.op("@@")(search_query) transcripts.c.search_vector_en.op("@@")(search_query)
) )
@@ -393,7 +403,7 @@ class SearchController:
transcripts.c.source_kind == params.source_kind transcripts.c.source_kind == params.source_kind
) )
if params.query_text: if params.query_text is not None:
order_by = sqlalchemy.desc(sqlalchemy.text("rank")) order_by = sqlalchemy.desc(sqlalchemy.text("rank"))
else: else:
order_by = sqlalchemy.desc(transcripts.c.created_at) order_by = sqlalchemy.desc(transcripts.c.created_at)
@@ -407,20 +417,30 @@ class SearchController:
) )
total = await get_database().fetch_val(count_query) total = await get_database().fetch_val(count_query)
def _process_result(r) -> SearchResult: def _process_result(r: DbRecord) -> SearchResult:
r_dict: Dict[str, Any] = dict(r) r_dict: Dict[str, Any] = dict(r)
webvtt_raw: str | None = r_dict.pop("webvtt", None) webvtt_raw: str | None = r_dict.pop("webvtt", None)
webvtt: WebVTTContent | None
if webvtt_raw: if webvtt_raw:
webvtt = WebVTTProcessor.parse(webvtt_raw) webvtt = WebVTTProcessor.parse(webvtt_raw)
else: else:
webvtt = None webvtt = None
long_summary: str | None = r_dict.pop("long_summary", None)
long_summary_r: str | None = r_dict.pop("long_summary", None)
long_summary: NonEmptyString = try_parse_non_empty_string(long_summary_r)
room_name: str | None = r_dict.pop("room_name", None) room_name: str | None = r_dict.pop("room_name", None)
db_result = SearchResultDB.model_validate(r_dict) db_result = SearchResultDB.model_validate(r_dict)
snippets, total_match_count = SnippetGenerator.combine_sources( at_least_one_source = webvtt is not None or long_summary is not None
has_query = params.query_text is not None
snippets, total_match_count = (
SnippetGenerator.combine_sources(
long_summary, webvtt, params.query_text, DEFAULT_MAX_SNIPPETS long_summary, webvtt, params.query_text, DEFAULT_MAX_SNIPPETS
) )
if has_query and at_least_one_source
else ([], 0)
)
return SearchResult( return SearchResult(
**db_result.model_dump(), **db_result.model_dump(),

View File

@@ -0,0 +1,20 @@
from typing import Annotated
from pydantic import Field, TypeAdapter, constr
NonEmptyStringBase = constr(min_length=1, strip_whitespace=False)
NonEmptyString = Annotated[
NonEmptyStringBase,
Field(description="A non-empty string", min_length=1),
]
non_empty_string_adapter = TypeAdapter(NonEmptyString)
def parse_non_empty_string(s: str) -> NonEmptyString:
return non_empty_string_adapter.validate_python(s)
def try_parse_non_empty_string(s: str) -> NonEmptyString | None:
if not s:
return None
return parse_non_empty_string(s)

View File

@@ -5,7 +5,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi_pagination import Page from fastapi_pagination import Page
from fastapi_pagination.ext.databases import apaginate from fastapi_pagination.ext.databases import apaginate
from jose import jwt from jose import jwt
from pydantic import BaseModel, Field, field_serializer from pydantic import BaseModel, Field, constr, field_serializer
import reflector.auth as auth import reflector.auth as auth
from reflector.db import get_database from reflector.db import get_database
@@ -19,10 +19,10 @@ from reflector.db.search import (
SearchOffsetBase, SearchOffsetBase,
SearchParameters, SearchParameters,
SearchQuery, SearchQuery,
SearchQueryBase,
SearchResult, SearchResult,
SearchTotal, SearchTotal,
search_controller, search_controller,
search_query_adapter,
) )
from reflector.db.transcripts import ( from reflector.db.transcripts import (
SourceKind, SourceKind,
@@ -114,7 +114,19 @@ class DeletionStatus(BaseModel):
status: str status: str
SearchQueryParam = Annotated[SearchQueryBase, Query(description="Search query text")] SearchQueryParamBase = constr(min_length=0, strip_whitespace=True)
SearchQueryParam = Annotated[
SearchQueryParamBase, Query(description="Search query text")
]
# http and api standards accept "q="; we would like to handle it as the absence of query, not as "empty string query"
def parse_search_query_param(q: SearchQueryParam) -> SearchQuery | None:
if q == "":
return None
return search_query_adapter.validate_python(q)
SearchLimitParam = Annotated[SearchLimitBase, Query(description="Results per page")] SearchLimitParam = Annotated[SearchLimitBase, Query(description="Results per page")]
SearchOffsetParam = Annotated[ SearchOffsetParam = Annotated[
SearchOffsetBase, Query(description="Number of results to skip") SearchOffsetBase, Query(description="Number of results to skip")
@@ -124,7 +136,7 @@ SearchOffsetParam = Annotated[
class SearchResponse(BaseModel): class SearchResponse(BaseModel):
results: list[SearchResult] results: list[SearchResult]
total: SearchTotal total: SearchTotal
query: SearchQuery query: SearchQuery | None = None
limit: SearchLimit limit: SearchLimit
offset: SearchOffset offset: SearchOffset
@@ -174,7 +186,7 @@ async def transcripts_search(
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
search_params = SearchParameters( search_params = SearchParameters(
query_text=q, query_text=parse_search_query_param(q),
limit=limit, limit=limit,
offset=offset, offset=offset,
user_id=user_id, user_id=user_id,

View File

@@ -23,7 +23,7 @@ async def test_search_postgresql_only():
assert results == [] assert results == []
assert total == 0 assert total == 0
params_empty = SearchParameters(query_text="") params_empty = SearchParameters(query_text=None)
results_empty, total_empty = await search_controller.search_transcripts( results_empty, total_empty = await search_controller.search_transcripts(
params_empty params_empty
) )
@@ -34,7 +34,7 @@ async def test_search_postgresql_only():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_search_with_empty_query(): async def test_search_with_empty_query():
"""Test that empty query returns all transcripts.""" """Test that empty query returns all transcripts."""
params = SearchParameters(query_text="") params = SearchParameters(query_text=None)
results, total = await search_controller.search_transcripts(params) results, total = await search_controller.search_transcripts(params)
assert isinstance(results, list) assert isinstance(results, list)

View File

@@ -1,5 +1,7 @@
"""Unit tests for search snippet generation.""" """Unit tests for search snippet generation."""
import pytest
from reflector.db.search import ( from reflector.db.search import (
SnippetCandidate, SnippetCandidate,
SnippetGenerator, SnippetGenerator,
@@ -512,11 +514,9 @@ data visualization and data storage"""
) )
assert webvtt_count == 3 assert webvtt_count == 3
snippets_empty, count_empty = SnippetGenerator.combine_sources( # combine_sources requires at least one source to be present
None, None, "data", max_total=3 with pytest.raises(AssertionError, match="At least one source must be present"):
) SnippetGenerator.combine_sources(None, None, "data", max_total=3)
assert snippets_empty == []
assert count_empty == 0
def test_edge_cases(self): def test_edge_cases(self):
"""Test edge cases for the pure functions.""" """Test edge cases for the pure functions."""