From 695d1a957d4cd862753049f9beed88836cabd5ab Mon Sep 17 00:00:00 2001 From: Igor Monadical Date: Fri, 29 Aug 2025 18:55:51 -0400 Subject: [PATCH] fix: search-logspam (#593) * fix: search-logspam * llm comment * fix tests --------- Co-authored-by: Igor Loskutov --- server/reflector/db/search.py | 62 ++++++++++++++++++--------- server/reflector/utils/string.py | 20 +++++++++ server/reflector/views/transcripts.py | 22 +++++++--- server/tests/test_search.py | 4 +- server/tests/test_search_snippets.py | 10 ++--- 5 files changed, 85 insertions(+), 33 deletions(-) create mode 100644 server/reflector/utils/string.py diff --git a/server/reflector/db/search.py b/server/reflector/db/search.py index 8ac25212..66a25ccf 100644 --- a/server/reflector/db/search.py +++ b/server/reflector/db/search.py @@ -8,12 +8,14 @@ from typing import Annotated, Any, Dict, Iterator import sqlalchemy import webvtt +from databases.interfaces import Record as DbRecord from fastapi import HTTPException from pydantic import ( BaseModel, Field, NonNegativeFloat, NonNegativeInt, + TypeAdapter, ValidationError, constr, field_serializer, @@ -24,6 +26,7 @@ from reflector.db.rooms import rooms from reflector.db.transcripts import SourceKind, transcripts from reflector.db.utils import is_postgresql from reflector.logger import logger +from reflector.utils.string import NonEmptyString, try_parse_non_empty_string DEFAULT_SEARCH_LIMIT = 20 SNIPPET_CONTEXT_LENGTH = 50 # Characters before/after match to include @@ -31,12 +34,13 @@ DEFAULT_SNIPPET_MAX_LENGTH = NonNegativeInt(150) DEFAULT_MAX_SNIPPETS = NonNegativeInt(3) LONG_SUMMARY_MAX_SNIPPETS = 2 -SearchQueryBase = constr(min_length=0, strip_whitespace=True) +SearchQueryBase = constr(min_length=1, strip_whitespace=True) SearchLimitBase = Annotated[int, Field(ge=1, le=100)] SearchOffsetBase = Annotated[int, Field(ge=0)] SearchTotalBase = Annotated[int, Field(ge=0)] SearchQuery = Annotated[SearchQueryBase, Field(description="Search query text")] +search_query_adapter = TypeAdapter(SearchQuery) SearchLimit = Annotated[SearchLimitBase, Field(description="Results per page")] SearchOffset = Annotated[ SearchOffsetBase, Field(description="Number of results to skip") @@ -88,7 +92,7 @@ class WebVTTProcessor: @staticmethod def generate_snippets( webvtt_content: WebVTTContent, - query: str, + query: SearchQuery, max_snippets: NonNegativeInt = DEFAULT_MAX_SNIPPETS, ) -> list[str]: """Generate snippets from WebVTT content.""" @@ -125,7 +129,7 @@ class SnippetCandidate: class SearchParameters(BaseModel): """Validated search parameters for full-text search.""" - query_text: SearchQuery + query_text: SearchQuery | None = None limit: SearchLimit = DEFAULT_SEARCH_LIMIT offset: SearchOffset = 0 user_id: str | None = None @@ -199,15 +203,13 @@ class SnippetGenerator: prev_start = start @staticmethod - def count_matches(text: str, query: str) -> NonNegativeInt: + def count_matches(text: str, query: SearchQuery) -> NonNegativeInt: """Count total number of matches for a query in text.""" ZERO = NonNegativeInt(0) if not text: logger.warning("Empty text for search query in count_matches") return ZERO - if not query: - logger.warning("Empty query for search text in count_matches") - return ZERO + assert query is not None return NonNegativeInt( sum(1 for _ in SnippetGenerator.find_all_matches(text, query)) ) @@ -243,13 +245,14 @@ class SnippetGenerator: @staticmethod def generate( text: str, - query: str, + query: SearchQuery, max_length: NonNegativeInt = DEFAULT_SNIPPET_MAX_LENGTH, max_snippets: NonNegativeInt = DEFAULT_MAX_SNIPPETS, ) -> list[str]: """Generate snippets from text.""" - if not text or not query: - logger.warning("Empty text or query for generate_snippets") + assert query is not None + if not text: + logger.warning("Empty text for generate_snippets") return [] candidates = ( @@ -270,7 +273,7 @@ class SnippetGenerator: @staticmethod def from_summary( summary: str, - query: str, + query: SearchQuery, max_snippets: NonNegativeInt = LONG_SUMMARY_MAX_SNIPPETS, ) -> list[str]: """Generate snippets from summary text.""" @@ -278,9 +281,9 @@ class SnippetGenerator: @staticmethod def combine_sources( - summary: str | None, + summary: NonEmptyString | None, webvtt: WebVTTContent | None, - query: str, + query: SearchQuery, max_total: NonNegativeInt = DEFAULT_MAX_SNIPPETS, ) -> tuple[list[str], NonNegativeInt]: """Combine snippets from multiple sources and return total match count. @@ -289,6 +292,11 @@ class SnippetGenerator: snippets can be empty for real in case of e.g. title match """ + + assert ( + summary is not None or webvtt is not None + ), "At least one source must be present" + webvtt_matches = 0 summary_matches = 0 @@ -355,8 +363,8 @@ class SearchController: else_=rooms.c.name, ).label("room_name"), ] - - if params.query_text: + search_query = None + if params.query_text is not None: search_query = sqlalchemy.func.websearch_to_tsquery( "english", params.query_text ) @@ -373,7 +381,9 @@ class SearchController: transcripts.join(rooms, transcripts.c.room_id == rooms.c.id, isouter=True) ) - if params.query_text: + if params.query_text is not None: + # because already initialized based on params.query_text presence above + assert search_query is not None base_query = base_query.where( transcripts.c.search_vector_en.op("@@")(search_query) ) @@ -393,7 +403,7 @@ class SearchController: transcripts.c.source_kind == params.source_kind ) - if params.query_text: + if params.query_text is not None: order_by = sqlalchemy.desc(sqlalchemy.text("rank")) else: order_by = sqlalchemy.desc(transcripts.c.created_at) @@ -407,19 +417,29 @@ class SearchController: ) total = await get_database().fetch_val(count_query) - def _process_result(r) -> SearchResult: + def _process_result(r: DbRecord) -> SearchResult: r_dict: Dict[str, Any] = dict(r) + webvtt_raw: str | None = r_dict.pop("webvtt", None) + webvtt: WebVTTContent | None if webvtt_raw: webvtt = WebVTTProcessor.parse(webvtt_raw) else: webvtt = None - long_summary: str | None = r_dict.pop("long_summary", None) + + long_summary_r: str | None = r_dict.pop("long_summary", None) + long_summary: NonEmptyString = try_parse_non_empty_string(long_summary_r) room_name: str | None = r_dict.pop("room_name", None) db_result = SearchResultDB.model_validate(r_dict) - snippets, total_match_count = SnippetGenerator.combine_sources( - long_summary, webvtt, params.query_text, DEFAULT_MAX_SNIPPETS + at_least_one_source = webvtt is not None or long_summary is not None + has_query = params.query_text is not None + snippets, total_match_count = ( + SnippetGenerator.combine_sources( + long_summary, webvtt, params.query_text, DEFAULT_MAX_SNIPPETS + ) + if has_query and at_least_one_source + else ([], 0) ) return SearchResult( diff --git a/server/reflector/utils/string.py b/server/reflector/utils/string.py new file mode 100644 index 00000000..08a9de78 --- /dev/null +++ b/server/reflector/utils/string.py @@ -0,0 +1,20 @@ +from typing import Annotated + +from pydantic import Field, TypeAdapter, constr + +NonEmptyStringBase = constr(min_length=1, strip_whitespace=False) +NonEmptyString = Annotated[ + NonEmptyStringBase, + Field(description="A non-empty string", min_length=1), +] +non_empty_string_adapter = TypeAdapter(NonEmptyString) + + +def parse_non_empty_string(s: str) -> NonEmptyString: + return non_empty_string_adapter.validate_python(s) + + +def try_parse_non_empty_string(s: str) -> NonEmptyString | None: + if not s: + return None + return parse_non_empty_string(s) diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 594dd711..b64ecf11 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -5,7 +5,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query from fastapi_pagination import Page from fastapi_pagination.ext.databases import apaginate from jose import jwt -from pydantic import BaseModel, Field, field_serializer +from pydantic import BaseModel, Field, constr, field_serializer import reflector.auth as auth from reflector.db import get_database @@ -19,10 +19,10 @@ from reflector.db.search import ( SearchOffsetBase, SearchParameters, SearchQuery, - SearchQueryBase, SearchResult, SearchTotal, search_controller, + search_query_adapter, ) from reflector.db.transcripts import ( SourceKind, @@ -114,7 +114,19 @@ class DeletionStatus(BaseModel): status: str -SearchQueryParam = Annotated[SearchQueryBase, Query(description="Search query text")] +SearchQueryParamBase = constr(min_length=0, strip_whitespace=True) +SearchQueryParam = Annotated[ + SearchQueryParamBase, Query(description="Search query text") +] + + +# http and api standards accept "q="; we would like to handle it as the absence of query, not as "empty string query" +def parse_search_query_param(q: SearchQueryParam) -> SearchQuery | None: + if q == "": + return None + return search_query_adapter.validate_python(q) + + SearchLimitParam = Annotated[SearchLimitBase, Query(description="Results per page")] SearchOffsetParam = Annotated[ SearchOffsetBase, Query(description="Number of results to skip") @@ -124,7 +136,7 @@ SearchOffsetParam = Annotated[ class SearchResponse(BaseModel): results: list[SearchResult] total: SearchTotal - query: SearchQuery + query: SearchQuery | None = None limit: SearchLimit offset: SearchOffset @@ -174,7 +186,7 @@ async def transcripts_search( user_id = user["sub"] if user else None search_params = SearchParameters( - query_text=q, + query_text=parse_search_query_param(q), limit=limit, offset=offset, user_id=user_id, diff --git a/server/tests/test_search.py b/server/tests/test_search.py index 61145bf9..0f5c8923 100644 --- a/server/tests/test_search.py +++ b/server/tests/test_search.py @@ -23,7 +23,7 @@ async def test_search_postgresql_only(): assert results == [] assert total == 0 - params_empty = SearchParameters(query_text="") + params_empty = SearchParameters(query_text=None) results_empty, total_empty = await search_controller.search_transcripts( params_empty ) @@ -34,7 +34,7 @@ async def test_search_postgresql_only(): @pytest.mark.asyncio async def test_search_with_empty_query(): """Test that empty query returns all transcripts.""" - params = SearchParameters(query_text="") + params = SearchParameters(query_text=None) results, total = await search_controller.search_transcripts(params) assert isinstance(results, list) diff --git a/server/tests/test_search_snippets.py b/server/tests/test_search_snippets.py index 72267a1b..f9abd03c 100644 --- a/server/tests/test_search_snippets.py +++ b/server/tests/test_search_snippets.py @@ -1,5 +1,7 @@ """Unit tests for search snippet generation.""" +import pytest + from reflector.db.search import ( SnippetCandidate, SnippetGenerator, @@ -512,11 +514,9 @@ data visualization and data storage""" ) assert webvtt_count == 3 - snippets_empty, count_empty = SnippetGenerator.combine_sources( - None, None, "data", max_total=3 - ) - assert snippets_empty == [] - assert count_empty == 0 + # combine_sources requires at least one source to be present + with pytest.raises(AssertionError, match="At least one source must be present"): + SnippetGenerator.combine_sources(None, None, "data", max_total=3) def test_edge_cases(self): """Test edge cases for the pure functions."""