Mirror of https://github.com/Monadical-SAS/reflector.git (synced 2025-12-20 20:29:06 +00:00)
fix: search-logspam (#593)
* fix: search-logspam
* llm comment
* fix tests

Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
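In short (inferred from the diff below, not an official summary of the change): an empty `?q=` parameter used to reach `SearchParameters` as `""`, and the snippet generator then logged "Empty query ..." warnings for every result row. The empty parameter is now normalized to `None` at the HTTP boundary, `SearchQuery` itself rejects empty strings, and snippet generation is simply skipped when there is no query or no source text. A minimal sketch of the new contract, assuming the `reflector` package from this repo is importable:

    from pydantic import ValidationError

    from reflector.db.search import SearchParameters, search_query_adapter

    # "q=" now means "no query": no tsquery is built and no per-row warnings are logged.
    params = SearchParameters(query_text=None)

    # SearchQuery is now constr(min_length=1, strip_whitespace=True), so an empty string is rejected outright...
    try:
        search_query_adapter.validate_python("")
    except ValidationError:
        print("empty string is no longer a valid SearchQuery")

    # ...while ordinary queries are stripped and accepted.
    assert search_query_adapter.validate_python("  data  ") == "data"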
@@ -8,12 +8,14 @@ from typing import Annotated, Any, Dict, Iterator
 
 import sqlalchemy
 import webvtt
+from databases.interfaces import Record as DbRecord
 from fastapi import HTTPException
 from pydantic import (
     BaseModel,
     Field,
     NonNegativeFloat,
     NonNegativeInt,
+    TypeAdapter,
     ValidationError,
     constr,
     field_serializer,
@@ -24,6 +26,7 @@ from reflector.db.rooms import rooms
 from reflector.db.transcripts import SourceKind, transcripts
 from reflector.db.utils import is_postgresql
 from reflector.logger import logger
+from reflector.utils.string import NonEmptyString, try_parse_non_empty_string
 
 DEFAULT_SEARCH_LIMIT = 20
 SNIPPET_CONTEXT_LENGTH = 50  # Characters before/after match to include
@@ -31,12 +34,13 @@ DEFAULT_SNIPPET_MAX_LENGTH = NonNegativeInt(150)
 DEFAULT_MAX_SNIPPETS = NonNegativeInt(3)
 LONG_SUMMARY_MAX_SNIPPETS = 2
 
-SearchQueryBase = constr(min_length=0, strip_whitespace=True)
+SearchQueryBase = constr(min_length=1, strip_whitespace=True)
 SearchLimitBase = Annotated[int, Field(ge=1, le=100)]
 SearchOffsetBase = Annotated[int, Field(ge=0)]
 SearchTotalBase = Annotated[int, Field(ge=0)]
 
 SearchQuery = Annotated[SearchQueryBase, Field(description="Search query text")]
+search_query_adapter = TypeAdapter(SearchQuery)
 SearchLimit = Annotated[SearchLimitBase, Field(description="Results per page")]
 SearchOffset = Annotated[
     SearchOffsetBase, Field(description="Number of results to skip")
@@ -88,7 +92,7 @@ class WebVTTProcessor:
     @staticmethod
     def generate_snippets(
         webvtt_content: WebVTTContent,
-        query: str,
+        query: SearchQuery,
         max_snippets: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
     ) -> list[str]:
         """Generate snippets from WebVTT content."""
@@ -125,7 +129,7 @@ class SnippetCandidate:
 class SearchParameters(BaseModel):
     """Validated search parameters for full-text search."""
 
-    query_text: SearchQuery
+    query_text: SearchQuery | None = None
     limit: SearchLimit = DEFAULT_SEARCH_LIMIT
     offset: SearchOffset = 0
     user_id: str | None = None
@@ -199,15 +203,13 @@ class SnippetGenerator:
             prev_start = start
 
     @staticmethod
-    def count_matches(text: str, query: str) -> NonNegativeInt:
+    def count_matches(text: str, query: SearchQuery) -> NonNegativeInt:
         """Count total number of matches for a query in text."""
         ZERO = NonNegativeInt(0)
         if not text:
             logger.warning("Empty text for search query in count_matches")
             return ZERO
-        if not query:
-            logger.warning("Empty query for search text in count_matches")
-            return ZERO
+        assert query is not None
         return NonNegativeInt(
             sum(1 for _ in SnippetGenerator.find_all_matches(text, query))
         )
@@ -243,13 +245,14 @@ class SnippetGenerator:
     @staticmethod
     def generate(
         text: str,
-        query: str,
+        query: SearchQuery,
         max_length: NonNegativeInt = DEFAULT_SNIPPET_MAX_LENGTH,
         max_snippets: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
     ) -> list[str]:
         """Generate snippets from text."""
-        if not text or not query:
-            logger.warning("Empty text or query for generate_snippets")
+        assert query is not None
+        if not text:
+            logger.warning("Empty text for generate_snippets")
             return []
 
         candidates = (
@@ -270,7 +273,7 @@ class SnippetGenerator:
     @staticmethod
     def from_summary(
         summary: str,
-        query: str,
+        query: SearchQuery,
         max_snippets: NonNegativeInt = LONG_SUMMARY_MAX_SNIPPETS,
     ) -> list[str]:
         """Generate snippets from summary text."""
@@ -278,9 +281,9 @@ class SnippetGenerator:
 
     @staticmethod
     def combine_sources(
-        summary: str | None,
+        summary: NonEmptyString | None,
         webvtt: WebVTTContent | None,
-        query: str,
+        query: SearchQuery,
         max_total: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
     ) -> tuple[list[str], NonNegativeInt]:
         """Combine snippets from multiple sources and return total match count.
@@ -289,6 +292,11 @@ class SnippetGenerator:
 
         snippets can be empty for real in case of e.g. title match
         """
+
+        assert (
+            summary is not None or webvtt is not None
+        ), "At least one source must be present"
+
         webvtt_matches = 0
         summary_matches = 0
 
@@ -355,8 +363,8 @@ class SearchController:
                 else_=rooms.c.name,
             ).label("room_name"),
         ]
-
-        if params.query_text:
+        search_query = None
+        if params.query_text is not None:
             search_query = sqlalchemy.func.websearch_to_tsquery(
                 "english", params.query_text
             )
@@ -373,7 +381,9 @@ class SearchController:
             transcripts.join(rooms, transcripts.c.room_id == rooms.c.id, isouter=True)
         )
 
-        if params.query_text:
+        if params.query_text is not None:
+            # because already initialized based on params.query_text presence above
+            assert search_query is not None
             base_query = base_query.where(
                 transcripts.c.search_vector_en.op("@@")(search_query)
             )
@@ -393,7 +403,7 @@ class SearchController:
                 transcripts.c.source_kind == params.source_kind
             )
 
-        if params.query_text:
+        if params.query_text is not None:
             order_by = sqlalchemy.desc(sqlalchemy.text("rank"))
         else:
             order_by = sqlalchemy.desc(transcripts.c.created_at)
@@ -407,19 +417,29 @@ class SearchController:
         )
         total = await get_database().fetch_val(count_query)
 
-        def _process_result(r) -> SearchResult:
+        def _process_result(r: DbRecord) -> SearchResult:
             r_dict: Dict[str, Any] = dict(r)
 
             webvtt_raw: str | None = r_dict.pop("webvtt", None)
+            webvtt: WebVTTContent | None
             if webvtt_raw:
                 webvtt = WebVTTProcessor.parse(webvtt_raw)
             else:
                 webvtt = None
-            long_summary: str | None = r_dict.pop("long_summary", None)
+
+            long_summary_r: str | None = r_dict.pop("long_summary", None)
+            long_summary: NonEmptyString = try_parse_non_empty_string(long_summary_r)
             room_name: str | None = r_dict.pop("room_name", None)
             db_result = SearchResultDB.model_validate(r_dict)
-            snippets, total_match_count = SnippetGenerator.combine_sources(
-                long_summary, webvtt, params.query_text, DEFAULT_MAX_SNIPPETS
+
+            at_least_one_source = webvtt is not None or long_summary is not None
+            has_query = params.query_text is not None
+            snippets, total_match_count = (
+                SnippetGenerator.combine_sources(
+                    long_summary, webvtt, params.query_text, DEFAULT_MAX_SNIPPETS
+                )
+                if has_query and at_least_one_source
+                else ([], 0)
             )
 
             return SearchResult(
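For context on the hunk above: `combine_sources` no longer tolerates "no query" or "no sources" by logging a warning per row; the caller decides up front whether there is anything to snippet. A rough caller-side sketch (illustrative only; `snippets_for` is not a function in the repo):

    from reflector.db.search import DEFAULT_MAX_SNIPPETS, SnippetGenerator


    def snippets_for(long_summary, webvtt, query_text):
        # Mirrors the new _process_result logic: only call combine_sources when
        # there is both a real query and at least one source to search in.
        has_query = query_text is not None
        at_least_one_source = webvtt is not None or long_summary is not None
        if has_query and at_least_one_source:
            return SnippetGenerator.combine_sources(
                long_summary, webvtt, query_text, DEFAULT_MAX_SNIPPETS
            )
        # Nothing to search: return empty results instead of logging a warning.
        return [], 0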
server/reflector/utils/string.py (new file, 20 lines)
@@ -0,0 +1,20 @@
+from typing import Annotated
+
+from pydantic import Field, TypeAdapter, constr
+
+NonEmptyStringBase = constr(min_length=1, strip_whitespace=False)
+NonEmptyString = Annotated[
+    NonEmptyStringBase,
+    Field(description="A non-empty string", min_length=1),
+]
+non_empty_string_adapter = TypeAdapter(NonEmptyString)
+
+
+def parse_non_empty_string(s: str) -> NonEmptyString:
+    return non_empty_string_adapter.validate_python(s)
+
+
+def try_parse_non_empty_string(s: str) -> NonEmptyString | None:
+    if not s:
+        return None
+    return parse_non_empty_string(s)
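A quick usage sketch of the new helpers (the behaviour follows directly from the file above):

    from pydantic import ValidationError

    from reflector.utils.string import parse_non_empty_string, try_parse_non_empty_string

    assert parse_non_empty_string("summary text") == "summary text"
    assert try_parse_non_empty_string("") is None  # falsy input short-circuits to None
    assert try_parse_non_empty_string(None) is None  # search.py passes str | None here despite the str annotation
    assert try_parse_non_empty_string("summary text") == "summary text"

    # Whitespace is not stripped (strip_whitespace=False), so " " counts as non-empty.
    assert parse_non_empty_string(" ") == " "

    try:
        parse_non_empty_string("")  # min_length=1 -> ValidationError
    except ValidationError:
        print("empty string rejected")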
@@ -5,7 +5,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query
 from fastapi_pagination import Page
 from fastapi_pagination.ext.databases import apaginate
 from jose import jwt
-from pydantic import BaseModel, Field, field_serializer
+from pydantic import BaseModel, Field, constr, field_serializer
 
 import reflector.auth as auth
 from reflector.db import get_database
@@ -19,10 +19,10 @@ from reflector.db.search import (
     SearchOffsetBase,
     SearchParameters,
     SearchQuery,
-    SearchQueryBase,
     SearchResult,
     SearchTotal,
     search_controller,
+    search_query_adapter,
 )
 from reflector.db.transcripts import (
     SourceKind,
@@ -114,7 +114,19 @@ class DeletionStatus(BaseModel):
     status: str
 
 
-SearchQueryParam = Annotated[SearchQueryBase, Query(description="Search query text")]
+SearchQueryParamBase = constr(min_length=0, strip_whitespace=True)
+SearchQueryParam = Annotated[
+    SearchQueryParamBase, Query(description="Search query text")
+]
+
+
+# http and api standards accept "q="; we would like to handle it as the absence of query, not as "empty string query"
+def parse_search_query_param(q: SearchQueryParam) -> SearchQuery | None:
+    if q == "":
+        return None
+    return search_query_adapter.validate_python(q)
+
+
 SearchLimitParam = Annotated[SearchLimitBase, Query(description="Results per page")]
 SearchOffsetParam = Annotated[
     SearchOffsetBase, Query(description="Number of results to skip")
@@ -124,7 +136,7 @@ SearchOffsetParam = Annotated[
 class SearchResponse(BaseModel):
     results: list[SearchResult]
     total: SearchTotal
-    query: SearchQuery
+    query: SearchQuery | None = None
     limit: SearchLimit
     offset: SearchOffset
 
@@ -174,7 +186,7 @@ async def transcripts_search(
     user_id = user["sub"] if user else None
 
     search_params = SearchParameters(
-        query_text=q,
+        query_text=parse_search_query_param(q),
         limit=limit,
         offset=offset,
         user_id=user_id,
@@ -23,7 +23,7 @@ async def test_search_postgresql_only():
     assert results == []
     assert total == 0
 
-    params_empty = SearchParameters(query_text="")
+    params_empty = SearchParameters(query_text=None)
    results_empty, total_empty = await search_controller.search_transcripts(
         params_empty
     )
@@ -34,7 +34,7 @@ async def test_search_postgresql_only():
 @pytest.mark.asyncio
 async def test_search_with_empty_query():
     """Test that empty query returns all transcripts."""
-    params = SearchParameters(query_text="")
+    params = SearchParameters(query_text=None)
     results, total = await search_controller.search_transcripts(params)
 
     assert isinstance(results, list)
@@ -1,5 +1,7 @@
 """Unit tests for search snippet generation."""
 
+import pytest
+
 from reflector.db.search import (
     SnippetCandidate,
     SnippetGenerator,
@@ -512,11 +514,9 @@ data visualization and data storage"""
         )
         assert webvtt_count == 3
 
-        snippets_empty, count_empty = SnippetGenerator.combine_sources(
-            None, None, "data", max_total=3
-        )
-        assert snippets_empty == []
-        assert count_empty == 0
+        # combine_sources requires at least one source to be present
+        with pytest.raises(AssertionError, match="At least one source must be present"):
+            SnippetGenerator.combine_sources(None, None, "data", max_total=3)
 
     def test_edge_cases(self):
         """Test edge cases for the pure functions."""