mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-02-04 09:56:47 +00:00
feat: search frontend (#551)
* feat: better highlight * feat(search): add long_summary to search vector for improved search results - Update search vector to include long_summary with weight B (between title A and webvtt C) - Modify SearchController to fetch long_summary and prioritize its snippets - Generate snippets from long_summary first (max 2), then from webvtt for remaining slots - Add comprehensive tests for long_summary search functionality - Create migration to update search_vector_en column in PostgreSQL This improves search quality by including summarized content which often contains key topics and themes that may not be explicitly mentioned in the transcript. * fix: address code review feedback for search enhancements - Fix test file inconsistencies by removing references to non-existent model fields - Comment out tests for unimplemented features (room_ids, status filters, date ranges) - Update tests to only use currently available fields (room_id singular, no room_name/processing_status) - Mark future functionality tests with @pytest.mark.skip - Make snippet counts configurable - Add LONG_SUMMARY_MAX_SNIPPETS constant (default: 2) - Replace hardcoded value with configurable constant - Improve error handling consistency in WebVTT parsing - Use different log levels for different error types (debug for malformed, warning for decode, error for unexpected) - Add catch-all exception handler for unexpected errors - Include stack trace for critical errors All existing tests pass with these changes. * fix: correct datetime test to include required duration field * feat: better highlight * feat: search room names * feat: acknowledge deleted room * feat: search filters fix and rank removal * chore: minor refactoring * feat: better matches frontend * chore: self-review (vibe) * chore: self-review WIP * chore: self-review WIP * chore: self-review WIP * chore: self-review WIP * chore: self-review WIP * chore: self-review WIP * chore: self-review WIP * remove swc (vibe) * search url query sync (vibe) * search url query sync (vibe) * better casts and cap while * PR review + simplify frontend hook * pr: remove search db timeouts * cleanup tests * tests cleanup * frontend cleanup * index declarations * refactor frontend (self-review) * fix search pagination * clear "x" for search input * pagination max pages fix * chore: cleanup * cleanup * cleanup * cleanup * cleanup * cleanup * cleanup * cleanup * lockfile * pr review
This commit is contained in:
3
server/.gitignore
vendored
3
server/.gitignore
vendored
@@ -176,7 +176,8 @@ artefacts/
|
||||
audio_*.wav
|
||||
|
||||
# ignore local database
|
||||
reflector.sqlite3
|
||||
*.sqlite3
|
||||
*.db
|
||||
data/
|
||||
|
||||
dump.rdb
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
"""add_long_summary_to_search_vector
|
||||
|
||||
Revision ID: 0ab2d7ffaa16
|
||||
Revises: b1c33bd09963
|
||||
Create Date: 2025-08-15 13:27:52.680211
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0ab2d7ffaa16"
|
||||
down_revision: Union[str, None] = "b1c33bd09963"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Drop the existing search vector column and index
|
||||
op.drop_index("idx_transcript_search_vector_en", table_name="transcript")
|
||||
op.drop_column("transcript", "search_vector_en")
|
||||
|
||||
# Recreate the search vector column with long_summary included
|
||||
op.execute("""
|
||||
ALTER TABLE transcript ADD COLUMN search_vector_en tsvector
|
||||
GENERATED ALWAYS AS (
|
||||
setweight(to_tsvector('english', coalesce(title, '')), 'A') ||
|
||||
setweight(to_tsvector('english', coalesce(long_summary, '')), 'B') ||
|
||||
setweight(to_tsvector('english', coalesce(webvtt, '')), 'C')
|
||||
) STORED
|
||||
""")
|
||||
|
||||
# Recreate the GIN index for the search vector
|
||||
op.create_index(
|
||||
"idx_transcript_search_vector_en",
|
||||
"transcript",
|
||||
["search_vector_en"],
|
||||
postgresql_using="gin",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the updated search vector column and index
|
||||
op.drop_index("idx_transcript_search_vector_en", table_name="transcript")
|
||||
op.drop_column("transcript", "search_vector_en")
|
||||
|
||||
# Recreate the original search vector column without long_summary
|
||||
op.execute("""
|
||||
ALTER TABLE transcript ADD COLUMN search_vector_en tsvector
|
||||
GENERATED ALWAYS AS (
|
||||
setweight(to_tsvector('english', coalesce(title, '')), 'A') ||
|
||||
setweight(to_tsvector('english', coalesce(webvtt, '')), 'B')
|
||||
) STORED
|
||||
""")
|
||||
|
||||
# Recreate the GIN index for the search vector
|
||||
op.create_index(
|
||||
"idx_transcript_search_vector_en",
|
||||
"transcript",
|
||||
["search_vector_en"],
|
||||
postgresql_using="gin",
|
||||
)
|
||||
@@ -0,0 +1,41 @@
|
||||
"""add_search_optimization_indexes
|
||||
|
||||
Revision ID: b1c33bd09963
|
||||
Revises: 9f5c78d352d6
|
||||
Create Date: 2025-08-14 17:26:02.117408
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "b1c33bd09963"
|
||||
down_revision: Union[str, None] = "9f5c78d352d6"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add indexes for actual search filtering patterns used in frontend
|
||||
# Based on /browse page filters: room_id and source_kind
|
||||
|
||||
# Index for room_id + created_at (for room-specific searches with date ordering)
|
||||
op.create_index(
|
||||
"idx_transcript_room_id_created_at",
|
||||
"transcript",
|
||||
["room_id", "created_at"],
|
||||
if_not_exists=True,
|
||||
)
|
||||
|
||||
# Index for source_kind alone (actively used filter in frontend)
|
||||
op.create_index(
|
||||
"idx_transcript_source_kind", "transcript", ["source_kind"], if_not_exists=True
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove the indexes in reverse order
|
||||
op.drop_index("idx_transcript_source_kind", "transcript", if_exists=True)
|
||||
op.drop_index("idx_transcript_room_id_created_at", "transcript", if_exists=True)
|
||||
@@ -1,24 +1,37 @@
|
||||
"""Search functionality for transcripts and other entities."""
|
||||
|
||||
import itertools
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from io import StringIO
|
||||
from typing import Annotated, Any, Dict
|
||||
from typing import Annotated, Any, Dict, Iterator
|
||||
|
||||
import sqlalchemy
|
||||
import webvtt
|
||||
from pydantic import BaseModel, Field, constr, field_serializer
|
||||
from fastapi import HTTPException
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
NonNegativeFloat,
|
||||
NonNegativeInt,
|
||||
ValidationError,
|
||||
constr,
|
||||
field_serializer,
|
||||
)
|
||||
|
||||
from reflector.db import get_database
|
||||
from reflector.db.rooms import rooms
|
||||
from reflector.db.transcripts import SourceKind, transcripts
|
||||
from reflector.db.utils import is_postgresql
|
||||
from reflector.logger import logger
|
||||
|
||||
DEFAULT_SEARCH_LIMIT = 20
|
||||
SNIPPET_CONTEXT_LENGTH = 50 # Characters before/after match to include
|
||||
DEFAULT_SNIPPET_MAX_LENGTH = 150
|
||||
DEFAULT_MAX_SNIPPETS = 3
|
||||
DEFAULT_SNIPPET_MAX_LENGTH = NonNegativeInt(150)
|
||||
DEFAULT_MAX_SNIPPETS = NonNegativeInt(3)
|
||||
LONG_SUMMARY_MAX_SNIPPETS = 2
|
||||
|
||||
SearchQueryBase = constr(min_length=1, strip_whitespace=True)
|
||||
SearchQueryBase = constr(min_length=0, strip_whitespace=True)
|
||||
SearchLimitBase = Annotated[int, Field(ge=1, le=100)]
|
||||
SearchOffsetBase = Annotated[int, Field(ge=0)]
|
||||
SearchTotalBase = Annotated[int, Field(ge=0)]
|
||||
@@ -32,6 +45,82 @@ SearchTotal = Annotated[
|
||||
SearchTotalBase, Field(description="Total number of search results")
|
||||
]
|
||||
|
||||
WEBVTT_SPEC_HEADER = "WEBVTT\n\n"
|
||||
|
||||
WebVTTContent = Annotated[
|
||||
str,
|
||||
Field(min_length=len(WEBVTT_SPEC_HEADER), description="WebVTT content"),
|
||||
]
|
||||
|
||||
|
||||
class WebVTTProcessor:
|
||||
"""Stateless processor for WebVTT content operations."""
|
||||
|
||||
@staticmethod
|
||||
def parse(raw_content: str) -> WebVTTContent:
|
||||
"""Parse WebVTT content and return it as a string."""
|
||||
if not raw_content.startswith(WEBVTT_SPEC_HEADER):
|
||||
raise ValueError(f"Invalid WebVTT content, no header {WEBVTT_SPEC_HEADER}")
|
||||
return raw_content
|
||||
|
||||
@staticmethod
|
||||
def extract_text(webvtt_content: WebVTTContent) -> str:
|
||||
"""Extract plain text from WebVTT content using webvtt library."""
|
||||
try:
|
||||
buffer = StringIO(webvtt_content)
|
||||
vtt = webvtt.read_buffer(buffer)
|
||||
return " ".join(caption.text for caption in vtt if caption.text)
|
||||
except webvtt.errors.MalformedFileError as e:
|
||||
logger.warning(f"Malformed WebVTT content: {e}")
|
||||
return ""
|
||||
except (UnicodeDecodeError, ValueError) as e:
|
||||
logger.warning(f"Failed to decode WebVTT content: {e}")
|
||||
return ""
|
||||
except AttributeError as e:
|
||||
logger.error(
|
||||
f"WebVTT parsing error - unexpected format: {e}", exc_info=True
|
||||
)
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error parsing WebVTT: {e}", exc_info=True)
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def generate_snippets(
|
||||
webvtt_content: WebVTTContent,
|
||||
query: str,
|
||||
max_snippets: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
|
||||
) -> list[str]:
|
||||
"""Generate snippets from WebVTT content."""
|
||||
return SnippetGenerator.generate(
|
||||
WebVTTProcessor.extract_text(webvtt_content),
|
||||
query,
|
||||
max_snippets=max_snippets,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SnippetCandidate:
|
||||
"""Represents a candidate snippet with its position."""
|
||||
|
||||
_text: str
|
||||
start: NonNegativeInt
|
||||
_original_text_length: int
|
||||
|
||||
@property
|
||||
def end(self) -> NonNegativeInt:
|
||||
"""Calculate end position from start and raw text length."""
|
||||
return self.start + len(self._text)
|
||||
|
||||
def text(self) -> str:
|
||||
"""Get display text with ellipses added if needed."""
|
||||
result = self._text.strip()
|
||||
if self.start > 0:
|
||||
result = "..." + result
|
||||
if self.end < self._original_text_length:
|
||||
result = result + "..."
|
||||
return result
|
||||
|
||||
|
||||
class SearchParameters(BaseModel):
|
||||
"""Validated search parameters for full-text search."""
|
||||
@@ -41,6 +130,7 @@ class SearchParameters(BaseModel):
|
||||
offset: SearchOffset = 0
|
||||
user_id: str | None = None
|
||||
room_id: str | None = None
|
||||
source_kind: SourceKind | None = None
|
||||
|
||||
|
||||
class SearchResultDB(BaseModel):
|
||||
@@ -64,13 +154,18 @@ class SearchResult(BaseModel):
|
||||
title: str | None = None
|
||||
user_id: str | None = None
|
||||
room_id: str | None = None
|
||||
room_name: str | None = None
|
||||
source_kind: SourceKind
|
||||
created_at: datetime
|
||||
status: str = Field(..., min_length=1)
|
||||
rank: float = Field(..., ge=0, le=1)
|
||||
duration: float | None = Field(..., ge=0, description="Duration in seconds")
|
||||
duration: NonNegativeFloat | None = Field(..., description="Duration in seconds")
|
||||
search_snippets: list[str] = Field(
|
||||
description="Text snippets around search matches"
|
||||
)
|
||||
total_match_count: NonNegativeInt = Field(
|
||||
default=0, description="Total number of matches found in the transcript"
|
||||
)
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def serialize_datetime(self, dt: datetime) -> str:
|
||||
@@ -79,84 +174,153 @@ class SearchResult(BaseModel):
|
||||
return dt.isoformat()
|
||||
|
||||
|
||||
class SearchController:
|
||||
"""Controller for search operations across different entities."""
|
||||
class SnippetGenerator:
|
||||
"""Stateless generator for text snippets and match operations."""
|
||||
|
||||
@staticmethod
|
||||
def _extract_webvtt_text(webvtt_content: str) -> str:
|
||||
"""Extract plain text from WebVTT content using webvtt library."""
|
||||
if not webvtt_content:
|
||||
return ""
|
||||
def find_all_matches(text: str, query: str) -> Iterator[int]:
|
||||
"""Generate all match positions for a query in text."""
|
||||
if not text:
|
||||
logger.warning("Empty text for search query in find_all_matches")
|
||||
return
|
||||
if not query:
|
||||
logger.warning("Empty query for search text in find_all_matches")
|
||||
return
|
||||
|
||||
try:
|
||||
buffer = StringIO(webvtt_content)
|
||||
vtt = webvtt.read_buffer(buffer)
|
||||
return " ".join(caption.text for caption in vtt if caption.text)
|
||||
except (webvtt.errors.MalformedFileError, UnicodeDecodeError, ValueError) as e:
|
||||
logger.warning(f"Failed to parse WebVTT content: {e}", exc_info=e)
|
||||
return ""
|
||||
except AttributeError as e:
|
||||
logger.warning(f"WebVTT parsing error - unexpected format: {e}", exc_info=e)
|
||||
return ""
|
||||
text_lower = text.lower()
|
||||
query_lower = query.lower()
|
||||
start = 0
|
||||
prev_start = start
|
||||
while (pos := text_lower.find(query_lower, start)) != -1:
|
||||
yield pos
|
||||
start = pos + len(query_lower)
|
||||
if start <= prev_start:
|
||||
raise ValueError("panic! find_all_matches is not incremental")
|
||||
prev_start = start
|
||||
|
||||
@staticmethod
|
||||
def _generate_snippets(
|
||||
def count_matches(text: str, query: str) -> NonNegativeInt:
|
||||
"""Count total number of matches for a query in text."""
|
||||
ZERO = NonNegativeInt(0)
|
||||
if not text:
|
||||
logger.warning("Empty text for search query in count_matches")
|
||||
return ZERO
|
||||
if not query:
|
||||
logger.warning("Empty query for search text in count_matches")
|
||||
return ZERO
|
||||
return NonNegativeInt(
|
||||
sum(1 for _ in SnippetGenerator.find_all_matches(text, query))
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_snippet(
|
||||
text: str, match_pos: int, max_length: int = DEFAULT_SNIPPET_MAX_LENGTH
|
||||
) -> SnippetCandidate:
|
||||
"""Create a snippet from a match position."""
|
||||
snippet_start = NonNegativeInt(max(0, match_pos - SNIPPET_CONTEXT_LENGTH))
|
||||
snippet_end = min(len(text), match_pos + max_length - SNIPPET_CONTEXT_LENGTH)
|
||||
|
||||
snippet_text = text[snippet_start:snippet_end]
|
||||
|
||||
return SnippetCandidate(
|
||||
_text=snippet_text, start=snippet_start, _original_text_length=len(text)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def filter_non_overlapping(
|
||||
candidates: Iterator[SnippetCandidate],
|
||||
) -> Iterator[str]:
|
||||
"""Filter out overlapping snippets and return only display text."""
|
||||
last_end = 0
|
||||
for candidate in candidates:
|
||||
display_text = candidate.text()
|
||||
# it means that next overlapping snippets simply don't get included
|
||||
# it's fine as simplistic logic and users probably won't care much because they already have their search results just fin
|
||||
if candidate.start >= last_end and display_text:
|
||||
yield display_text
|
||||
last_end = candidate.end
|
||||
|
||||
@staticmethod
|
||||
def generate(
|
||||
text: str,
|
||||
q: SearchQuery,
|
||||
max_length: int = DEFAULT_SNIPPET_MAX_LENGTH,
|
||||
max_snippets: int = DEFAULT_MAX_SNIPPETS,
|
||||
query: str,
|
||||
max_length: NonNegativeInt = DEFAULT_SNIPPET_MAX_LENGTH,
|
||||
max_snippets: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
|
||||
) -> list[str]:
|
||||
"""Generate multiple snippets around all occurrences of search term."""
|
||||
if not text or not q:
|
||||
"""Generate snippets from text."""
|
||||
if not text or not query:
|
||||
logger.warning("Empty text or query for generate_snippets")
|
||||
return []
|
||||
|
||||
snippets = []
|
||||
lower_text = text.lower()
|
||||
search_lower = q.lower()
|
||||
candidates = (
|
||||
SnippetGenerator.create_snippet(text, pos, max_length)
|
||||
for pos in SnippetGenerator.find_all_matches(text, query)
|
||||
)
|
||||
filtered = SnippetGenerator.filter_non_overlapping(candidates)
|
||||
snippets = list(itertools.islice(filtered, max_snippets))
|
||||
|
||||
last_snippet_end = 0
|
||||
start_pos = 0
|
||||
|
||||
while len(snippets) < max_snippets:
|
||||
match_pos = lower_text.find(search_lower, start_pos)
|
||||
|
||||
if match_pos == -1:
|
||||
if not snippets and search_lower.split():
|
||||
first_word = search_lower.split()[0]
|
||||
match_pos = lower_text.find(first_word, start_pos)
|
||||
if match_pos == -1:
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
snippet_start = max(0, match_pos - SNIPPET_CONTEXT_LENGTH)
|
||||
snippet_end = min(
|
||||
len(text), match_pos + max_length - SNIPPET_CONTEXT_LENGTH
|
||||
)
|
||||
|
||||
if snippet_start < last_snippet_end:
|
||||
start_pos = match_pos + len(search_lower)
|
||||
continue
|
||||
|
||||
snippet = text[snippet_start:snippet_end]
|
||||
|
||||
if snippet_start > 0:
|
||||
snippet = "..." + snippet
|
||||
if snippet_end < len(text):
|
||||
snippet = snippet + "..."
|
||||
|
||||
snippet = snippet.strip()
|
||||
|
||||
if snippet:
|
||||
snippets.append(snippet)
|
||||
last_snippet_end = snippet_end
|
||||
|
||||
start_pos = match_pos + len(search_lower)
|
||||
if start_pos >= len(text):
|
||||
break
|
||||
# Fallback to first word search if no full matches
|
||||
# it's another assumption: proper snippet logic generation is quite complicated and tied to db logic, so simplification is used here
|
||||
if not snippets and " " in query:
|
||||
first_word = query.split()[0]
|
||||
return SnippetGenerator.generate(text, first_word, max_length, max_snippets)
|
||||
|
||||
return snippets
|
||||
|
||||
@staticmethod
|
||||
def from_summary(
|
||||
summary: str,
|
||||
query: str,
|
||||
max_snippets: NonNegativeInt = LONG_SUMMARY_MAX_SNIPPETS,
|
||||
) -> list[str]:
|
||||
"""Generate snippets from summary text."""
|
||||
return SnippetGenerator.generate(summary, query, max_snippets=max_snippets)
|
||||
|
||||
@staticmethod
|
||||
def combine_sources(
|
||||
summary: str | None,
|
||||
webvtt: WebVTTContent | None,
|
||||
query: str,
|
||||
max_total: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
|
||||
) -> tuple[list[str], NonNegativeInt]:
|
||||
"""Combine snippets from multiple sources and return total match count.
|
||||
|
||||
Returns (snippets, total_match_count) tuple.
|
||||
|
||||
snippets can be empty for real in case of e.g. title match
|
||||
"""
|
||||
webvtt_matches = 0
|
||||
summary_matches = 0
|
||||
|
||||
if webvtt:
|
||||
webvtt_text = WebVTTProcessor.extract_text(webvtt)
|
||||
webvtt_matches = SnippetGenerator.count_matches(webvtt_text, query)
|
||||
|
||||
if summary:
|
||||
summary_matches = SnippetGenerator.count_matches(summary, query)
|
||||
|
||||
total_matches = NonNegativeInt(webvtt_matches + summary_matches)
|
||||
|
||||
summary_snippets = (
|
||||
SnippetGenerator.from_summary(summary, query) if summary else []
|
||||
)
|
||||
|
||||
if len(summary_snippets) >= max_total:
|
||||
return summary_snippets[:max_total], total_matches
|
||||
|
||||
remaining = max_total - len(summary_snippets)
|
||||
webvtt_snippets = (
|
||||
WebVTTProcessor.generate_snippets(webvtt, query, remaining)
|
||||
if webvtt
|
||||
else []
|
||||
)
|
||||
|
||||
return summary_snippets + webvtt_snippets, total_matches
|
||||
|
||||
|
||||
class SearchController:
|
||||
"""Controller for search operations across different entities."""
|
||||
|
||||
@classmethod
|
||||
async def search_transcripts(
|
||||
cls, params: SearchParameters
|
||||
@@ -172,39 +336,64 @@ class SearchController:
|
||||
)
|
||||
return [], 0
|
||||
|
||||
search_query = sqlalchemy.func.websearch_to_tsquery(
|
||||
"english", params.query_text
|
||||
base_columns = [
|
||||
transcripts.c.id,
|
||||
transcripts.c.title,
|
||||
transcripts.c.created_at,
|
||||
transcripts.c.duration,
|
||||
transcripts.c.status,
|
||||
transcripts.c.user_id,
|
||||
transcripts.c.room_id,
|
||||
transcripts.c.source_kind,
|
||||
transcripts.c.webvtt,
|
||||
transcripts.c.long_summary,
|
||||
sqlalchemy.case(
|
||||
(
|
||||
transcripts.c.room_id.isnot(None) & rooms.c.id.is_(None),
|
||||
"Deleted Room",
|
||||
),
|
||||
else_=rooms.c.name,
|
||||
).label("room_name"),
|
||||
]
|
||||
|
||||
if params.query_text:
|
||||
search_query = sqlalchemy.func.websearch_to_tsquery(
|
||||
"english", params.query_text
|
||||
)
|
||||
rank_column = sqlalchemy.func.ts_rank(
|
||||
transcripts.c.search_vector_en,
|
||||
search_query,
|
||||
32, # normalization flag: rank/(rank+1) for 0-1 range
|
||||
).label("rank")
|
||||
else:
|
||||
rank_column = sqlalchemy.cast(1.0, sqlalchemy.Float).label("rank")
|
||||
|
||||
columns = base_columns + [rank_column]
|
||||
base_query = sqlalchemy.select(columns).select_from(
|
||||
transcripts.join(rooms, transcripts.c.room_id == rooms.c.id, isouter=True)
|
||||
)
|
||||
|
||||
base_query = sqlalchemy.select(
|
||||
[
|
||||
transcripts.c.id,
|
||||
transcripts.c.title,
|
||||
transcripts.c.created_at,
|
||||
transcripts.c.duration,
|
||||
transcripts.c.status,
|
||||
transcripts.c.user_id,
|
||||
transcripts.c.room_id,
|
||||
transcripts.c.source_kind,
|
||||
transcripts.c.webvtt,
|
||||
sqlalchemy.func.ts_rank(
|
||||
transcripts.c.search_vector_en,
|
||||
search_query,
|
||||
32, # normalization flag: rank/(rank+1) for 0-1 range
|
||||
).label("rank"),
|
||||
]
|
||||
).where(transcripts.c.search_vector_en.op("@@")(search_query))
|
||||
if params.query_text:
|
||||
base_query = base_query.where(
|
||||
transcripts.c.search_vector_en.op("@@")(search_query)
|
||||
)
|
||||
|
||||
if params.user_id:
|
||||
base_query = base_query.where(transcripts.c.user_id == params.user_id)
|
||||
if params.room_id:
|
||||
base_query = base_query.where(transcripts.c.room_id == params.room_id)
|
||||
if params.source_kind:
|
||||
base_query = base_query.where(
|
||||
transcripts.c.source_kind == params.source_kind
|
||||
)
|
||||
|
||||
if params.query_text:
|
||||
order_by = sqlalchemy.desc(sqlalchemy.text("rank"))
|
||||
else:
|
||||
order_by = sqlalchemy.desc(transcripts.c.created_at)
|
||||
|
||||
query = base_query.order_by(order_by).limit(params.limit).offset(params.offset)
|
||||
|
||||
query = (
|
||||
base_query.order_by(sqlalchemy.desc(sqlalchemy.text("rank")))
|
||||
.limit(params.limit)
|
||||
.offset(params.offset)
|
||||
)
|
||||
rs = await get_database().fetch_all(query)
|
||||
|
||||
count_query = sqlalchemy.select([sqlalchemy.func.count()]).select_from(
|
||||
@@ -214,18 +403,40 @@ class SearchController:
|
||||
|
||||
def _process_result(r) -> SearchResult:
|
||||
r_dict: Dict[str, Any] = dict(r)
|
||||
webvtt: str | None = r_dict.pop("webvtt", None)
|
||||
webvtt_raw: str | None = r_dict.pop("webvtt", None)
|
||||
if webvtt_raw:
|
||||
webvtt = WebVTTProcessor.parse(webvtt_raw)
|
||||
else:
|
||||
webvtt = None
|
||||
long_summary: str | None = r_dict.pop("long_summary", None)
|
||||
room_name: str | None = r_dict.pop("room_name", None)
|
||||
db_result = SearchResultDB.model_validate(r_dict)
|
||||
|
||||
snippets = []
|
||||
if webvtt:
|
||||
plain_text = cls._extract_webvtt_text(webvtt)
|
||||
snippets = cls._generate_snippets(plain_text, params.query_text)
|
||||
snippets, total_match_count = SnippetGenerator.combine_sources(
|
||||
long_summary, webvtt, params.query_text, DEFAULT_MAX_SNIPPETS
|
||||
)
|
||||
|
||||
return SearchResult(**db_result.model_dump(), search_snippets=snippets)
|
||||
return SearchResult(
|
||||
**db_result.model_dump(),
|
||||
room_name=room_name,
|
||||
search_snippets=snippets,
|
||||
total_match_count=total_match_count,
|
||||
)
|
||||
|
||||
try:
|
||||
results = [_process_result(r) for r in rs]
|
||||
except ValidationError as e:
|
||||
logger.error(f"Invalid search result data: {e}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Internal search result data consistency error"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing search results: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
results = [_process_result(r) for r in rs]
|
||||
return results, total
|
||||
|
||||
|
||||
search_controller = SearchController()
|
||||
webvtt_processor = WebVTTProcessor()
|
||||
snippet_generator = SnippetGenerator()
|
||||
|
||||
@@ -88,6 +88,8 @@ transcripts = sqlalchemy.Table(
|
||||
sqlalchemy.Index("idx_transcript_created_at", "created_at"),
|
||||
sqlalchemy.Index("idx_transcript_user_id_recording_id", "user_id", "recording_id"),
|
||||
sqlalchemy.Index("idx_transcript_room_id", "room_id"),
|
||||
sqlalchemy.Index("idx_transcript_source_kind", "source_kind"),
|
||||
sqlalchemy.Index("idx_transcript_room_id_created_at", "room_id", "created_at"),
|
||||
)
|
||||
|
||||
# Add PostgreSQL-specific full-text search column
|
||||
@@ -99,7 +101,8 @@ if is_postgresql():
|
||||
TSVECTOR,
|
||||
sqlalchemy.Computed(
|
||||
"setweight(to_tsvector('english', coalesce(title, '')), 'A') || "
|
||||
"setweight(to_tsvector('english', coalesce(webvtt, '')), 'B')",
|
||||
"setweight(to_tsvector('english', coalesce(long_summary, '')), 'B') || "
|
||||
"setweight(to_tsvector('english', coalesce(webvtt, '')), 'C')",
|
||||
persisted=True,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -160,6 +160,7 @@ async def transcripts_search(
|
||||
limit: SearchLimitParam = DEFAULT_SEARCH_LIMIT,
|
||||
offset: SearchOffsetParam = 0,
|
||||
room_id: Optional[str] = None,
|
||||
source_kind: Optional[SourceKind] = None,
|
||||
user: Annotated[
|
||||
Optional[auth.UserInfo], Depends(auth.current_user_optional)
|
||||
] = None,
|
||||
@@ -173,7 +174,12 @@ async def transcripts_search(
|
||||
user_id = user["sub"] if user else None
|
||||
|
||||
search_params = SearchParameters(
|
||||
query_text=q, limit=limit, offset=offset, user_id=user_id, room_id=room_id
|
||||
query_text=q,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
user_id=user_id,
|
||||
room_id=room_id,
|
||||
source_kind=source_kind,
|
||||
)
|
||||
|
||||
results, total = await search_controller.search_transcripts(search_params)
|
||||
|
||||
@@ -2,13 +2,18 @@
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from reflector.db import get_database
|
||||
from reflector.db.search import SearchParameters, search_controller
|
||||
from reflector.db.transcripts import transcripts
|
||||
from reflector.db.search import (
|
||||
SearchController,
|
||||
SearchParameters,
|
||||
SearchResult,
|
||||
search_controller,
|
||||
)
|
||||
from reflector.db.transcripts import SourceKind, transcripts
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -18,39 +23,135 @@ async def test_search_postgresql_only():
|
||||
assert results == []
|
||||
assert total == 0
|
||||
|
||||
try:
|
||||
SearchParameters(query_text="")
|
||||
assert False, "Should have raised validation error"
|
||||
except ValidationError:
|
||||
pass # Expected
|
||||
|
||||
# Test that whitespace query raises validation error
|
||||
try:
|
||||
SearchParameters(query_text=" ")
|
||||
assert False, "Should have raised validation error"
|
||||
except ValidationError:
|
||||
pass # Expected
|
||||
params_empty = SearchParameters(query_text="")
|
||||
results_empty, total_empty = await search_controller.search_transcripts(
|
||||
params_empty
|
||||
)
|
||||
assert isinstance(results_empty, list)
|
||||
assert isinstance(total_empty, int)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_input_validation():
|
||||
try:
|
||||
SearchParameters(query_text="")
|
||||
assert False, "Should have raised ValidationError"
|
||||
except ValidationError:
|
||||
pass # Expected
|
||||
async def test_search_with_empty_query():
|
||||
"""Test that empty query returns all transcripts."""
|
||||
params = SearchParameters(query_text="")
|
||||
results, total = await search_controller.search_transcripts(params)
|
||||
|
||||
assert isinstance(results, list)
|
||||
assert isinstance(total, int)
|
||||
if len(results) > 1:
|
||||
for i in range(len(results) - 1):
|
||||
assert results[i].created_at >= results[i + 1].created_at
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_transcript_title_only_match():
|
||||
"""Test that transcripts with title-only matches return empty snippets."""
|
||||
test_id = "test-empty-9b3f2a8d"
|
||||
|
||||
# Test that whitespace query raises validation error
|
||||
try:
|
||||
SearchParameters(query_text=" \t\n ")
|
||||
assert False, "Should have raised ValidationError"
|
||||
except ValidationError:
|
||||
pass # Expected
|
||||
await get_database().execute(
|
||||
transcripts.delete().where(transcripts.c.id == test_id)
|
||||
)
|
||||
|
||||
test_data = {
|
||||
"id": test_id,
|
||||
"name": "Empty Transcript",
|
||||
"title": "Empty Meeting",
|
||||
"status": "completed",
|
||||
"locked": False,
|
||||
"duration": 0.0,
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"short_summary": None,
|
||||
"long_summary": None,
|
||||
"topics": json.dumps([]),
|
||||
"events": json.dumps([]),
|
||||
"participants": json.dumps([]),
|
||||
"source_language": "en",
|
||||
"target_language": "en",
|
||||
"reviewed": False,
|
||||
"audio_location": "local",
|
||||
"share_mode": "private",
|
||||
"source_kind": "room",
|
||||
"webvtt": None,
|
||||
}
|
||||
|
||||
await get_database().execute(transcripts.insert().values(**test_data))
|
||||
|
||||
params = SearchParameters(query_text="empty")
|
||||
results, total = await search_controller.search_transcripts(params)
|
||||
|
||||
assert total >= 1
|
||||
found = next((r for r in results if r.id == test_id), None)
|
||||
assert found is not None, "Should find transcript by title match"
|
||||
assert found.search_snippets == []
|
||||
assert found.total_match_count == 0
|
||||
|
||||
finally:
|
||||
await get_database().execute(
|
||||
transcripts.delete().where(transcripts.c.id == test_id)
|
||||
)
|
||||
await get_database().disconnect()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_with_long_summary():
|
||||
"""Test that long_summary content is searchable."""
|
||||
test_id = "test-long-summary-8a9f3c2d"
|
||||
|
||||
try:
|
||||
await get_database().execute(
|
||||
transcripts.delete().where(transcripts.c.id == test_id)
|
||||
)
|
||||
|
||||
test_data = {
|
||||
"id": test_id,
|
||||
"name": "Test Long Summary",
|
||||
"title": "Regular Meeting",
|
||||
"status": "completed",
|
||||
"locked": False,
|
||||
"duration": 1800.0,
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"short_summary": "Brief overview",
|
||||
"long_summary": "Detailed discussion about quantum computing applications and blockchain technology integration",
|
||||
"topics": json.dumps([]),
|
||||
"events": json.dumps([]),
|
||||
"participants": json.dumps([]),
|
||||
"source_language": "en",
|
||||
"target_language": "en",
|
||||
"reviewed": False,
|
||||
"audio_location": "local",
|
||||
"share_mode": "private",
|
||||
"source_kind": "room",
|
||||
"webvtt": """WEBVTT
|
||||
|
||||
00:00:00.000 --> 00:00:10.000
|
||||
Basic meeting content without special keywords.""",
|
||||
}
|
||||
|
||||
await get_database().execute(transcripts.insert().values(**test_data))
|
||||
|
||||
params = SearchParameters(query_text="quantum computing")
|
||||
results, total = await search_controller.search_transcripts(params)
|
||||
|
||||
assert total >= 1
|
||||
found = any(r.id == test_id for r in results)
|
||||
assert found, "Should find transcript by long_summary content"
|
||||
|
||||
test_result = next((r for r in results if r.id == test_id), None)
|
||||
assert test_result
|
||||
assert len(test_result.search_snippets) > 0
|
||||
assert "quantum computing" in test_result.search_snippets[0].lower()
|
||||
|
||||
finally:
|
||||
await get_database().execute(
|
||||
transcripts.delete().where(transcripts.c.id == test_id)
|
||||
)
|
||||
await get_database().disconnect()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_postgresql_search_with_data():
|
||||
# collision is improbable
|
||||
test_id = "test-search-e2e-7f3a9b2c"
|
||||
|
||||
try:
|
||||
@@ -94,28 +195,24 @@ We need to implement PostgreSQL tsvector for better performance.""",
|
||||
|
||||
await get_database().execute(transcripts.insert().values(**test_data))
|
||||
|
||||
# Test 1: Search for a word in title
|
||||
params = SearchParameters(query_text="planning")
|
||||
results, total = await search_controller.search_transcripts(params)
|
||||
assert total >= 1
|
||||
found = any(r.id == test_id for r in results)
|
||||
assert found, "Should find test transcript by title word"
|
||||
|
||||
# Test 2: Search for a word in webvtt content
|
||||
params = SearchParameters(query_text="tsvector")
|
||||
results, total = await search_controller.search_transcripts(params)
|
||||
assert total >= 1
|
||||
found = any(r.id == test_id for r in results)
|
||||
assert found, "Should find test transcript by webvtt content"
|
||||
|
||||
# Test 3: Search with multiple words
|
||||
params = SearchParameters(query_text="engineering planning")
|
||||
results, total = await search_controller.search_transcripts(params)
|
||||
assert total >= 1
|
||||
found = any(r.id == test_id for r in results)
|
||||
assert found, "Should find test transcript by multiple words"
|
||||
|
||||
# Test 4: Verify SearchResult structure
|
||||
test_result = next((r for r in results if r.id == test_id), None)
|
||||
if test_result:
|
||||
assert test_result.title == "Engineering Planning Meeting Q4 2024"
|
||||
@@ -123,14 +220,12 @@ We need to implement PostgreSQL tsvector for better performance.""",
|
||||
assert test_result.duration == 1800.0
|
||||
assert 0 <= test_result.rank <= 1, "Rank should be normalized to 0-1"
|
||||
|
||||
# Test 5: Search with OR operator
|
||||
params = SearchParameters(query_text="tsvector OR nosuchword")
|
||||
results, total = await search_controller.search_transcripts(params)
|
||||
assert total >= 1
|
||||
found = any(r.id == test_id for r in results)
|
||||
assert found, "Should find test transcript with OR query"
|
||||
|
||||
# Test 6: Quoted phrase search
|
||||
params = SearchParameters(query_text='"full-text search"')
|
||||
results, total = await search_controller.search_transcripts(params)
|
||||
assert total >= 1
|
||||
@@ -142,3 +237,240 @@ We need to implement PostgreSQL tsvector for better performance.""",
|
||||
transcripts.delete().where(transcripts.c.id == test_id)
|
||||
)
|
||||
await get_database().disconnect()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_search_params():
|
||||
"""Create sample search parameters for testing."""
|
||||
return SearchParameters(
|
||||
query_text="test query",
|
||||
limit=20,
|
||||
offset=0,
|
||||
user_id="test-user",
|
||||
room_id="room1",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_result():
|
||||
"""Create a mock database result."""
|
||||
return {
|
||||
"id": "test-transcript-id",
|
||||
"title": "Test Transcript",
|
||||
"created_at": datetime(2024, 6, 15, tzinfo=timezone.utc),
|
||||
"duration": 3600.0,
|
||||
"status": "completed",
|
||||
"user_id": "test-user",
|
||||
"room_id": "room1",
|
||||
"source_kind": SourceKind.LIVE,
|
||||
"webvtt": "WEBVTT\n\n00:00:00.000 --> 00:00:05.000\nThis is a test transcript",
|
||||
"rank": 0.95,
|
||||
}
|
||||
|
||||
|
||||
class TestSearchParameters:
|
||||
"""Test SearchParameters model validation and functionality."""
|
||||
|
||||
def test_search_parameters_with_available_filters(self):
|
||||
"""Test creating SearchParameters with currently available filter options."""
|
||||
params = SearchParameters(
|
||||
query_text="search term",
|
||||
limit=50,
|
||||
offset=10,
|
||||
user_id="user123",
|
||||
room_id="room1",
|
||||
)
|
||||
|
||||
assert params.query_text == "search term"
|
||||
assert params.limit == 50
|
||||
assert params.offset == 10
|
||||
assert params.user_id == "user123"
|
||||
assert params.room_id == "room1"
|
||||
|
||||
def test_search_parameters_defaults(self):
|
||||
"""Test SearchParameters with default values."""
|
||||
params = SearchParameters(query_text="test")
|
||||
|
||||
assert params.query_text == "test"
|
||||
assert params.limit == 20
|
||||
assert params.offset == 0
|
||||
assert params.user_id is None
|
||||
assert params.room_id is None
|
||||
|
||||
|
||||
class TestSearchControllerFilters:
|
||||
"""Test SearchController functionality with various filters."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_with_source_kind_filter(self):
|
||||
"""Test search filtering by source_kind."""
|
||||
controller = SearchController()
|
||||
with (
|
||||
patch("reflector.db.search.is_postgresql", return_value=True),
|
||||
patch("reflector.db.search.get_database") as mock_db,
|
||||
):
|
||||
mock_db.return_value.fetch_all = AsyncMock(return_value=[])
|
||||
mock_db.return_value.fetch_val = AsyncMock(return_value=0)
|
||||
|
||||
params = SearchParameters(query_text="test", source_kind=SourceKind.LIVE)
|
||||
|
||||
results, total = await controller.search_transcripts(params)
|
||||
|
||||
assert results == []
|
||||
assert total == 0
|
||||
|
||||
mock_db.return_value.fetch_all.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_with_single_room_id(self):
|
||||
"""Test search filtering by single room ID (currently supported)."""
|
||||
controller = SearchController()
|
||||
with (
|
||||
patch("reflector.db.search.is_postgresql", return_value=True),
|
||||
patch("reflector.db.search.get_database") as mock_db,
|
||||
):
|
||||
mock_db.return_value.fetch_all = AsyncMock(return_value=[])
|
||||
mock_db.return_value.fetch_val = AsyncMock(return_value=0)
|
||||
|
||||
params = SearchParameters(
|
||||
query_text="test",
|
||||
room_id="room1",
|
||||
)
|
||||
|
||||
results, total = await controller.search_transcripts(params)
|
||||
|
||||
assert results == []
|
||||
assert total == 0
|
||||
mock_db.return_value.fetch_all.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_result_includes_available_fields(self, mock_db_result):
|
||||
"""Test that search results include available fields like source_kind."""
|
||||
controller = SearchController()
|
||||
with (
|
||||
patch("reflector.db.search.is_postgresql", return_value=True),
|
||||
patch("reflector.db.search.get_database") as mock_db,
|
||||
):
|
||||
|
||||
class MockRow:
|
||||
def __init__(self, data):
|
||||
self._data = data
|
||||
self._mapping = data
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._data.items())
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._data[key]
|
||||
|
||||
def keys(self):
|
||||
return self._data.keys()
|
||||
|
||||
mock_row = MockRow(mock_db_result)
|
||||
|
||||
mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row])
|
||||
mock_db.return_value.fetch_val = AsyncMock(return_value=1)
|
||||
|
||||
params = SearchParameters(query_text="test")
|
||||
|
||||
results, total = await controller.search_transcripts(params)
|
||||
|
||||
assert total == 1
|
||||
assert len(results) == 1
|
||||
|
||||
result = results[0]
|
||||
assert isinstance(result, SearchResult)
|
||||
assert result.id == "test-transcript-id"
|
||||
assert result.title == "Test Transcript"
|
||||
assert result.rank == 0.95
|
||||
|
||||
|
||||
class TestSearchEndpointParsing:
|
||||
"""Test parameter parsing in the search endpoint."""
|
||||
|
||||
def test_parse_comma_separated_room_ids(self):
|
||||
"""Test parsing comma-separated room IDs."""
|
||||
room_ids_str = "room1,room2,room3"
|
||||
parsed = [rid.strip() for rid in room_ids_str.split(",") if rid.strip()]
|
||||
assert parsed == ["room1", "room2", "room3"]
|
||||
|
||||
room_ids_str = "room1, room2 , room3"
|
||||
parsed = [rid.strip() for rid in room_ids_str.split(",") if rid.strip()]
|
||||
assert parsed == ["room1", "room2", "room3"]
|
||||
|
||||
room_ids_str = "room1,,room3,"
|
||||
parsed = [rid.strip() for rid in room_ids_str.split(",") if rid.strip()]
|
||||
assert parsed == ["room1", "room3"]
|
||||
|
||||
def test_parse_source_kind(self):
|
||||
"""Test parsing source_kind values."""
|
||||
for kind_str in ["live", "file", "room"]:
|
||||
parsed = SourceKind(kind_str)
|
||||
assert parsed == SourceKind(kind_str)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
SourceKind("invalid_kind")
|
||||
|
||||
|
||||
class TestSearchResultModel:
|
||||
"""Test SearchResult model and serialization."""
|
||||
|
||||
def test_search_result_with_available_fields(self):
|
||||
"""Test SearchResult model with currently available fields populated."""
|
||||
result = SearchResult(
|
||||
id="test-id",
|
||||
title="Test Title",
|
||||
user_id="user-123",
|
||||
room_id="room-456",
|
||||
source_kind=SourceKind.ROOM,
|
||||
created_at=datetime(2024, 6, 15, tzinfo=timezone.utc),
|
||||
status="completed",
|
||||
rank=0.85,
|
||||
duration=1800.5,
|
||||
search_snippets=["snippet 1", "snippet 2"],
|
||||
)
|
||||
|
||||
assert result.id == "test-id"
|
||||
assert result.title == "Test Title"
|
||||
assert result.user_id == "user-123"
|
||||
assert result.room_id == "room-456"
|
||||
assert result.status == "completed"
|
||||
assert result.rank == 0.85
|
||||
assert result.duration == 1800.5
|
||||
assert len(result.search_snippets) == 2
|
||||
|
||||
def test_search_result_with_optional_fields_none(self):
|
||||
"""Test SearchResult model with optional fields as None."""
|
||||
result = SearchResult(
|
||||
id="test-id",
|
||||
source_kind=SourceKind.FILE,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
status="processing",
|
||||
rank=0.5,
|
||||
search_snippets=[],
|
||||
title=None,
|
||||
user_id=None,
|
||||
room_id=None,
|
||||
duration=None,
|
||||
)
|
||||
|
||||
assert result.title is None
|
||||
assert result.user_id is None
|
||||
assert result.room_id is None
|
||||
assert result.duration is None
|
||||
|
||||
def test_search_result_datetime_field(self):
|
||||
"""Test that SearchResult accepts datetime field."""
|
||||
result = SearchResult(
|
||||
id="test-id",
|
||||
source_kind=SourceKind.LIVE,
|
||||
created_at=datetime(2024, 6, 15, 12, 30, 45, tzinfo=timezone.utc),
|
||||
status="completed",
|
||||
rank=0.9,
|
||||
duration=None,
|
||||
search_snippets=[],
|
||||
)
|
||||
|
||||
assert result.created_at == datetime(
|
||||
2024, 6, 15, 12, 30, 45, tzinfo=timezone.utc
|
||||
)
|
||||
|
||||
164
server/tests/test_search_long_summary.py
Normal file
164
server/tests/test_search_long_summary.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""Tests for long_summary in search functionality."""
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from reflector.db import get_database
|
||||
from reflector.db.search import SearchParameters, search_controller
|
||||
from reflector.db.transcripts import transcripts
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_long_summary_snippet_prioritization():
|
||||
"""Test that snippets from long_summary are prioritized over webvtt content."""
|
||||
test_id = "test-snippet-priority-3f9a2b8c"
|
||||
|
||||
try:
|
||||
# Clean up any existing test data
|
||||
await get_database().execute(
|
||||
transcripts.delete().where(transcripts.c.id == test_id)
|
||||
)
|
||||
|
||||
test_data = {
|
||||
"id": test_id,
|
||||
"name": "Test Snippet Priority",
|
||||
"title": "Meeting About Projects",
|
||||
"status": "completed",
|
||||
"locked": False,
|
||||
"duration": 1800.0,
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"short_summary": "Project discussion",
|
||||
"long_summary": (
|
||||
"The team discussed advanced robotics applications including "
|
||||
"autonomous navigation systems and sensor fusion techniques. "
|
||||
"Robotics development will focus on real-time processing."
|
||||
),
|
||||
"topics": json.dumps([]),
|
||||
"events": json.dumps([]),
|
||||
"participants": json.dumps([]),
|
||||
"source_language": "en",
|
||||
"target_language": "en",
|
||||
"reviewed": False,
|
||||
"audio_location": "local",
|
||||
"share_mode": "private",
|
||||
"source_kind": "room",
|
||||
"webvtt": """WEBVTT
|
||||
|
||||
00:00:00.000 --> 00:00:10.000
|
||||
We talked about many different topics today.
|
||||
|
||||
00:00:10.000 --> 00:00:20.000
|
||||
The robotics project is making good progress.
|
||||
|
||||
00:00:20.000 --> 00:00:30.000
|
||||
We need to consider various implementation approaches.""",
|
||||
}
|
||||
|
||||
await get_database().execute(transcripts.insert().values(**test_data))
|
||||
|
||||
# Search for "robotics" which appears in both long_summary and webvtt
|
||||
params = SearchParameters(query_text="robotics")
|
||||
results, total = await search_controller.search_transcripts(params)
|
||||
|
||||
assert total >= 1
|
||||
test_result = next((r for r in results if r.id == test_id), None)
|
||||
assert test_result, "Should find the test transcript"
|
||||
|
||||
snippets = test_result.search_snippets
|
||||
assert len(snippets) > 0, "Should have at least one snippet"
|
||||
|
||||
# The first snippets should be from long_summary (more detailed content)
|
||||
first_snippet = snippets[0].lower()
|
||||
assert (
|
||||
"advanced robotics" in first_snippet or "autonomous" in first_snippet
|
||||
), f"First snippet should be from long_summary with detailed content. Got: {snippets[0]}"
|
||||
|
||||
# With max 3 snippets, we should get both from long_summary and webvtt
|
||||
assert len(snippets) <= 3, "Should respect max snippets limit"
|
||||
|
||||
# All snippets should contain the search term
|
||||
for snippet in snippets:
|
||||
assert (
|
||||
"robotics" in snippet.lower()
|
||||
), f"Snippet should contain search term: {snippet}"
|
||||
|
||||
finally:
|
||||
await get_database().execute(
|
||||
transcripts.delete().where(transcripts.c.id == test_id)
|
||||
)
|
||||
await get_database().disconnect()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_long_summary_only_search():
|
||||
"""Test searching for content that only exists in long_summary."""
|
||||
test_id = "test-long-only-8b3c9f2a"
|
||||
|
||||
try:
|
||||
await get_database().execute(
|
||||
transcripts.delete().where(transcripts.c.id == test_id)
|
||||
)
|
||||
|
||||
test_data = {
|
||||
"id": test_id,
|
||||
"name": "Test Long Only",
|
||||
"title": "Standard Meeting",
|
||||
"status": "completed",
|
||||
"locked": False,
|
||||
"duration": 1800.0,
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"short_summary": "Team sync",
|
||||
"long_summary": (
|
||||
"Detailed analysis of cryptocurrency market trends and "
|
||||
"decentralized finance protocols. Discussion included "
|
||||
"yield farming strategies and liquidity pool mechanics."
|
||||
),
|
||||
"topics": json.dumps([]),
|
||||
"events": json.dumps([]),
|
||||
"participants": json.dumps([]),
|
||||
"source_language": "en",
|
||||
"target_language": "en",
|
||||
"reviewed": False,
|
||||
"audio_location": "local",
|
||||
"share_mode": "private",
|
||||
"source_kind": "room",
|
||||
"webvtt": """WEBVTT
|
||||
|
||||
00:00:00.000 --> 00:00:10.000
|
||||
Team meeting about general project updates.
|
||||
|
||||
00:00:10.000 --> 00:00:20.000
|
||||
Discussion of timeline and deliverables.""",
|
||||
}
|
||||
|
||||
await get_database().execute(transcripts.insert().values(**test_data))
|
||||
|
||||
# Search for terms only in long_summary
|
||||
params = SearchParameters(query_text="cryptocurrency")
|
||||
results, total = await search_controller.search_transcripts(params)
|
||||
|
||||
found = any(r.id == test_id for r in results)
|
||||
assert found, "Should find transcript by long_summary-only content"
|
||||
|
||||
test_result = next((r for r in results if r.id == test_id), None)
|
||||
assert test_result
|
||||
assert len(test_result.search_snippets) > 0
|
||||
|
||||
# Verify the snippet is about cryptocurrency
|
||||
snippet = test_result.search_snippets[0].lower()
|
||||
assert "cryptocurrency" in snippet, "Snippet should contain the search term"
|
||||
|
||||
# Search for "yield farming" - a more specific term
|
||||
params2 = SearchParameters(query_text="yield farming")
|
||||
results2, total2 = await search_controller.search_transcripts(params2)
|
||||
|
||||
found2 = any(r.id == test_id for r in results2)
|
||||
assert found2, "Should find transcript by specific long_summary phrase"
|
||||
|
||||
finally:
|
||||
await get_database().execute(
|
||||
transcripts.delete().where(transcripts.c.id == test_id)
|
||||
)
|
||||
await get_database().disconnect()
|
||||
@@ -1,6 +1,10 @@
|
||||
"""Unit tests for search snippet generation."""
|
||||
|
||||
from reflector.db.search import SearchController
|
||||
from reflector.db.search import (
|
||||
SnippetCandidate,
|
||||
SnippetGenerator,
|
||||
WebVTTProcessor,
|
||||
)
|
||||
|
||||
|
||||
class TestExtractWebVTT:
|
||||
@@ -16,7 +20,7 @@ class TestExtractWebVTT:
|
||||
00:00:10.000 --> 00:00:20.000
|
||||
<v Speaker1>Indeed it is a test of WebVTT parsing.
|
||||
"""
|
||||
result = SearchController._extract_webvtt_text(webvtt)
|
||||
result = WebVTTProcessor.extract_text(webvtt)
|
||||
assert "Hello world, this is a test" in result
|
||||
assert "Indeed it is a test" in result
|
||||
assert "<v Speaker" not in result
|
||||
@@ -25,12 +29,11 @@ class TestExtractWebVTT:
|
||||
|
||||
def test_extract_empty_webvtt(self):
|
||||
"""Test empty WebVTT returns empty string."""
|
||||
assert SearchController._extract_webvtt_text("") == ""
|
||||
assert SearchController._extract_webvtt_text(None) == ""
|
||||
assert WebVTTProcessor.extract_text("") == ""
|
||||
|
||||
def test_extract_malformed_webvtt(self):
|
||||
"""Test malformed WebVTT returns empty string."""
|
||||
result = SearchController._extract_webvtt_text("Not a valid WebVTT")
|
||||
result = WebVTTProcessor.extract_text("Not a valid WebVTT")
|
||||
assert result == ""
|
||||
|
||||
|
||||
@@ -39,8 +42,7 @@ class TestGenerateSnippets:
|
||||
|
||||
def test_multiple_matches(self):
|
||||
"""Test finding multiple occurrences of search term in long text."""
|
||||
# Create text with Python mentions far apart to get separate snippets
|
||||
separator = " This is filler text. " * 20 # ~400 chars of padding
|
||||
separator = " This is filler text. " * 20
|
||||
text = (
|
||||
"Python is great for machine learning."
|
||||
+ separator
|
||||
@@ -51,18 +53,16 @@ class TestGenerateSnippets:
|
||||
+ "The Python community is very supportive."
|
||||
)
|
||||
|
||||
snippets = SearchController._generate_snippets(text, "Python")
|
||||
# With enough separation, we should get multiple snippets
|
||||
assert len(snippets) >= 2 # At least 2 distinct snippets
|
||||
snippets = SnippetGenerator.generate(text, "Python")
|
||||
assert len(snippets) >= 2
|
||||
|
||||
# Each snippet should contain "Python"
|
||||
for snippet in snippets:
|
||||
assert "python" in snippet.lower()
|
||||
|
||||
def test_single_match(self):
|
||||
"""Test single occurrence returns one snippet."""
|
||||
text = "This document discusses artificial intelligence and its applications."
|
||||
snippets = SearchController._generate_snippets(text, "artificial intelligence")
|
||||
snippets = SnippetGenerator.generate(text, "artificial intelligence")
|
||||
|
||||
assert len(snippets) == 1
|
||||
assert "artificial intelligence" in snippets[0].lower()
|
||||
@@ -70,24 +70,22 @@ class TestGenerateSnippets:
|
||||
def test_no_matches(self):
|
||||
"""Test no matches returns empty list."""
|
||||
text = "This is some random text without the search term."
|
||||
snippets = SearchController._generate_snippets(text, "machine learning")
|
||||
snippets = SnippetGenerator.generate(text, "machine learning")
|
||||
|
||||
assert snippets == []
|
||||
|
||||
def test_case_insensitive_search(self):
|
||||
"""Test search is case insensitive."""
|
||||
# Add enough text between matches to get separate snippets
|
||||
text = (
|
||||
"MACHINE LEARNING is important for modern applications. "
|
||||
+ "It requires lots of data and computational resources. " * 5 # Padding
|
||||
+ "It requires lots of data and computational resources. " * 5
|
||||
+ "Machine Learning rocks and transforms industries. "
|
||||
+ "Deep learning is a subset of it. " * 5 # More padding
|
||||
+ "Deep learning is a subset of it. " * 5
|
||||
+ "Finally, machine learning will shape our future."
|
||||
)
|
||||
|
||||
snippets = SearchController._generate_snippets(text, "machine learning")
|
||||
snippets = SnippetGenerator.generate(text, "machine learning")
|
||||
|
||||
# Should find at least 2 (might be 3 if text is long enough)
|
||||
assert len(snippets) >= 2
|
||||
for snippet in snippets:
|
||||
assert "machine learning" in snippet.lower()
|
||||
@@ -95,61 +93,55 @@ class TestGenerateSnippets:
|
||||
def test_partial_match_fallback(self):
|
||||
"""Test fallback to first word when exact phrase not found."""
|
||||
text = "We use machine intelligence for processing."
|
||||
snippets = SearchController._generate_snippets(text, "machine learning")
|
||||
snippets = SnippetGenerator.generate(text, "machine learning")
|
||||
|
||||
# Should fall back to finding "machine"
|
||||
assert len(snippets) == 1
|
||||
assert "machine" in snippets[0].lower()
|
||||
|
||||
def test_snippet_ellipsis(self):
|
||||
"""Test ellipsis added for truncated snippets."""
|
||||
# Long text where match is in the middle
|
||||
text = "a " * 100 + "TARGET_WORD special content here" + " b" * 100
|
||||
snippets = SearchController._generate_snippets(text, "TARGET_WORD")
|
||||
snippets = SnippetGenerator.generate(text, "TARGET_WORD")
|
||||
|
||||
assert len(snippets) == 1
|
||||
assert "..." in snippets[0] # Should have ellipsis
|
||||
assert "..." in snippets[0]
|
||||
assert "TARGET_WORD" in snippets[0]
|
||||
|
||||
def test_overlapping_snippets_deduplicated(self):
|
||||
"""Test overlapping matches don't create duplicate snippets."""
|
||||
text = "test test test word" * 10 # Repeated pattern
|
||||
snippets = SearchController._generate_snippets(text, "test")
|
||||
text = "test test test word" * 10
|
||||
snippets = SnippetGenerator.generate(text, "test")
|
||||
|
||||
# Should get unique snippets, not duplicates
|
||||
assert len(snippets) <= 3
|
||||
assert len(snippets) == len(set(snippets)) # All unique
|
||||
assert len(snippets) == len(set(snippets))
|
||||
|
||||
def test_empty_inputs(self):
|
||||
"""Test empty text or search term returns empty list."""
|
||||
assert SearchController._generate_snippets("", "search") == []
|
||||
assert SearchController._generate_snippets("text", "") == []
|
||||
assert SearchController._generate_snippets("", "") == []
|
||||
assert SnippetGenerator.generate("", "search") == []
|
||||
assert SnippetGenerator.generate("text", "") == []
|
||||
assert SnippetGenerator.generate("", "") == []
|
||||
|
||||
def test_max_snippets_limit(self):
|
||||
"""Test respects max_snippets parameter."""
|
||||
# Create text with well-separated occurrences
|
||||
separator = " filler " * 50 # Ensure snippets don't overlap
|
||||
text = ("Python is amazing" + separator) * 10 # 10 occurrences
|
||||
separator = " filler " * 50
|
||||
text = ("Python is amazing" + separator) * 10
|
||||
|
||||
# Test with different limits
|
||||
snippets_1 = SearchController._generate_snippets(text, "Python", max_snippets=1)
|
||||
snippets_1 = SnippetGenerator.generate(text, "Python", max_snippets=1)
|
||||
assert len(snippets_1) == 1
|
||||
|
||||
snippets_2 = SearchController._generate_snippets(text, "Python", max_snippets=2)
|
||||
snippets_2 = SnippetGenerator.generate(text, "Python", max_snippets=2)
|
||||
assert len(snippets_2) == 2
|
||||
|
||||
snippets_5 = SearchController._generate_snippets(text, "Python", max_snippets=5)
|
||||
assert len(snippets_5) == 5 # Should get exactly 5 with enough separation
|
||||
snippets_5 = SnippetGenerator.generate(text, "Python", max_snippets=5)
|
||||
assert len(snippets_5) == 5
|
||||
|
||||
def test_snippet_length(self):
|
||||
"""Test snippet length is reasonable."""
|
||||
text = "word " * 200 # Long text
|
||||
snippets = SearchController._generate_snippets(text, "word")
|
||||
text = "word " * 200
|
||||
snippets = SnippetGenerator.generate(text, "word")
|
||||
|
||||
for snippet in snippets:
|
||||
# Default max_length is 150 + some context
|
||||
assert len(snippet) <= 200 # Some buffer for ellipsis
|
||||
assert len(snippet) <= 200
|
||||
|
||||
|
||||
class TestFullPipeline:
|
||||
@@ -157,7 +149,6 @@ class TestFullPipeline:
|
||||
|
||||
def test_webvtt_to_snippets_integration(self):
|
||||
"""Test full pipeline from WebVTT to search snippets."""
|
||||
# Create WebVTT with well-separated content for multiple snippets
|
||||
webvtt = (
|
||||
"""WEBVTT
|
||||
|
||||
@@ -182,17 +173,362 @@ class TestFullPipeline:
|
||||
"""
|
||||
)
|
||||
|
||||
# Extract and generate snippets
|
||||
plain_text = SearchController._extract_webvtt_text(webvtt)
|
||||
snippets = SearchController._generate_snippets(plain_text, "machine learning")
|
||||
plain_text = WebVTTProcessor.extract_text(webvtt)
|
||||
snippets = SnippetGenerator.generate(plain_text, "machine learning")
|
||||
|
||||
# Should find at least 2 snippets (text might still be close together)
|
||||
assert len(snippets) >= 1 # At minimum one snippet containing matches
|
||||
assert len(snippets) <= 3 # At most 3 by default
|
||||
assert len(snippets) >= 1
|
||||
assert len(snippets) <= 3
|
||||
|
||||
# No WebVTT artifacts in snippets
|
||||
for snippet in snippets:
|
||||
assert "machine learning" in snippet.lower()
|
||||
assert "<v Speaker" not in snippet
|
||||
assert "00:00" not in snippet
|
||||
assert "-->" not in snippet
|
||||
|
||||
|
||||
class TestMultiWordQueryBehavior:
|
||||
"""Tests for multi-word query behavior and exact phrase matching."""
|
||||
|
||||
def test_multi_word_query_snippet_behavior(self):
|
||||
"""Test that multi-word queries generate snippets based on exact phrase matching."""
|
||||
sample_text = """This is a sample transcript where user Alice is talking.
|
||||
Later in the conversation, jordan mentions something important.
|
||||
The user jordan collaboration was successful.
|
||||
Another user named Bob joins the discussion."""
|
||||
|
||||
user_snippets = SnippetGenerator.generate(sample_text, "user")
|
||||
assert len(user_snippets) == 2, "Should find 2 snippets for 'user'"
|
||||
|
||||
jordan_snippets = SnippetGenerator.generate(sample_text, "jordan")
|
||||
assert len(jordan_snippets) >= 1, "Should find at least 1 snippet for 'jordan'"
|
||||
|
||||
multi_word_snippets = SnippetGenerator.generate(sample_text, "user jordan")
|
||||
assert len(multi_word_snippets) == 1, (
|
||||
"Should return exactly 1 snippet for 'user jordan' "
|
||||
"(only the exact phrase match, not individual word occurrences)"
|
||||
)
|
||||
|
||||
snippet = multi_word_snippets[0]
|
||||
assert (
|
||||
"user jordan" in snippet.lower()
|
||||
), "The snippet should contain the exact phrase 'user jordan'"
|
||||
|
||||
assert (
|
||||
"alice" not in snippet.lower()
|
||||
), "The snippet should not include the first standalone 'user' with Alice"
|
||||
|
||||
def test_multi_word_query_without_exact_match(self):
|
||||
"""Test snippet generation when exact phrase is not found."""
|
||||
sample_text = """User Alice is here. Bob and jordan are talking.
|
||||
Later jordan mentions something. The user is happy."""
|
||||
|
||||
snippets = SnippetGenerator.generate(sample_text, "user jordan")
|
||||
|
||||
assert (
|
||||
len(snippets) >= 1
|
||||
), "Should find at least 1 snippet when falling back to first word"
|
||||
|
||||
all_snippets_text = " ".join(snippets).lower()
|
||||
assert (
|
||||
"user" in all_snippets_text
|
||||
), "Snippets should contain 'user' (the first word)"
|
||||
|
||||
def test_exact_phrase_at_text_boundaries(self):
|
||||
"""Test snippet generation when exact phrase appears at text boundaries."""
|
||||
|
||||
text_start = "user jordan started the meeting. Other content here."
|
||||
snippets = SnippetGenerator.generate(text_start, "user jordan")
|
||||
assert len(snippets) == 1
|
||||
assert "user jordan" in snippets[0].lower()
|
||||
|
||||
text_end = "Other content here. The meeting ended with user jordan"
|
||||
snippets = SnippetGenerator.generate(text_end, "user jordan")
|
||||
assert len(snippets) == 1
|
||||
assert "user jordan" in snippets[0].lower()
|
||||
|
||||
def test_multi_word_query_matches_words_appearing_separately_and_together(self):
|
||||
"""Test that multi-word queries prioritize exact phrase matches over individual word occurrences."""
|
||||
sample_text = """This is a sample transcript where user Alice is talking.
|
||||
Later in the conversation, jordan mentions something important.
|
||||
The user jordan collaboration was successful.
|
||||
Another user named Bob joins the discussion."""
|
||||
|
||||
search_query = "user jordan"
|
||||
snippets = SnippetGenerator.generate(sample_text, search_query)
|
||||
|
||||
assert len(snippets) == 1, (
|
||||
f"Expected exactly 1 snippet for '{search_query}' when exact phrase exists, "
|
||||
f"got {len(snippets)}. Should ignore individual word occurrences."
|
||||
)
|
||||
|
||||
snippet = snippets[0]
|
||||
|
||||
assert (
|
||||
search_query in snippet.lower()
|
||||
), f"Snippet should contain the exact phrase '{search_query}'. Got: {snippet}"
|
||||
|
||||
assert (
|
||||
"jordan mentions" in snippet.lower()
|
||||
), f"Snippet should include context before the exact phrase match. Got: {snippet}"
|
||||
|
||||
assert (
|
||||
"alice" not in snippet.lower()
|
||||
), f"Snippet should not include separate occurrences of individual words. Got: {snippet}"
|
||||
|
||||
text_2 = """The alpha version was released.
|
||||
Beta testing started yesterday.
|
||||
The alpha beta integration is complete."""
|
||||
|
||||
snippets_2 = SnippetGenerator.generate(text_2, "alpha beta")
|
||||
assert len(snippets_2) == 1, "Should return 1 snippet for exact phrase match"
|
||||
assert "alpha beta" in snippets_2[0].lower(), "Should contain exact phrase"
|
||||
assert (
|
||||
"version" not in snippets_2[0].lower()
|
||||
), "Should not include first separate occurrence"
|
||||
|
||||
|
||||
class TestSnippetGenerationEnhanced:
|
||||
"""Additional snippet generation tests from test_search_enhancements.py."""
|
||||
|
||||
def test_snippet_generation_from_webvtt(self):
|
||||
"""Test snippet generation from WebVTT content."""
|
||||
webvtt_content = """WEBVTT
|
||||
|
||||
00:00:00.000 --> 00:00:05.000
|
||||
This is the beginning of the transcript
|
||||
|
||||
00:00:05.000 --> 00:00:10.000
|
||||
The search term appears here in the middle
|
||||
|
||||
00:00:10.000 --> 00:00:15.000
|
||||
And this is the end of the content"""
|
||||
|
||||
plain_text = WebVTTProcessor.extract_text(webvtt_content)
|
||||
snippets = SnippetGenerator.generate(plain_text, "search term")
|
||||
|
||||
assert len(snippets) > 0
|
||||
assert any("search term" in snippet.lower() for snippet in snippets)
|
||||
|
||||
def test_extract_webvtt_text_with_malformed_variations(self):
|
||||
"""Test WebVTT extraction with various malformed content."""
|
||||
malformed_vtt = "This is not valid WebVTT content"
|
||||
result = WebVTTProcessor.extract_text(malformed_vtt)
|
||||
assert result == ""
|
||||
|
||||
partial_vtt = "WEBVTT\nNo timestamps here"
|
||||
result = WebVTTProcessor.extract_text(partial_vtt)
|
||||
assert result == "" or "No timestamps" not in result
|
||||
|
||||
|
||||
class TestPureFunctions:
|
||||
"""Test the pure functions extracted for functional programming."""
|
||||
|
||||
def test_find_all_matches(self):
|
||||
"""Test finding all match positions in text."""
|
||||
text = "Python is great. Python is powerful. I love Python."
|
||||
matches = list(SnippetGenerator.find_all_matches(text, "Python"))
|
||||
assert matches == [0, 17, 44]
|
||||
|
||||
matches = list(SnippetGenerator.find_all_matches(text, "python"))
|
||||
assert matches == [0, 17, 44]
|
||||
|
||||
matches = list(SnippetGenerator.find_all_matches(text, "Ruby"))
|
||||
assert matches == []
|
||||
|
||||
matches = list(SnippetGenerator.find_all_matches("", "test"))
|
||||
assert matches == []
|
||||
matches = list(SnippetGenerator.find_all_matches("test", ""))
|
||||
assert matches == []
|
||||
|
||||
def test_create_snippet(self):
|
||||
"""Test creating a snippet from a match position."""
|
||||
text = "This is a long text with the word Python in the middle and more text after."
|
||||
|
||||
snippet = SnippetGenerator.create_snippet(text, 35, max_length=150)
|
||||
assert "Python" in snippet.text()
|
||||
assert snippet.start >= 0
|
||||
assert snippet.end <= len(text)
|
||||
assert isinstance(snippet, SnippetCandidate)
|
||||
|
||||
assert len(snippet.text()) > 0
|
||||
assert snippet.start <= snippet.end
|
||||
|
||||
long_text = "A" * 200
|
||||
snippet = SnippetGenerator.create_snippet(long_text, 100, max_length=50)
|
||||
assert snippet.text().startswith("...")
|
||||
assert snippet.text().endswith("...")
|
||||
|
||||
snippet = SnippetGenerator.create_snippet("short text", 0, max_length=100)
|
||||
assert snippet.start == 0
|
||||
assert "short text" in snippet.text()
|
||||
|
||||
def test_filter_non_overlapping(self):
|
||||
"""Test filtering overlapping snippets."""
|
||||
candidates = [
|
||||
SnippetCandidate(_text="First snippet", start=0, _original_text_length=100),
|
||||
SnippetCandidate(_text="Overlapping", start=10, _original_text_length=100),
|
||||
SnippetCandidate(
|
||||
_text="Third snippet", start=40, _original_text_length=100
|
||||
),
|
||||
SnippetCandidate(
|
||||
_text="Fourth snippet", start=65, _original_text_length=100
|
||||
),
|
||||
]
|
||||
|
||||
filtered = list(SnippetGenerator.filter_non_overlapping(iter(candidates)))
|
||||
assert filtered == [
|
||||
"First snippet...",
|
||||
"...Third snippet...",
|
||||
"...Fourth snippet...",
|
||||
]
|
||||
|
||||
filtered = list(SnippetGenerator.filter_non_overlapping(iter([])))
|
||||
assert filtered == []
|
||||
|
||||
def test_generate_integration(self):
|
||||
"""Test the main SnippetGenerator.generate function."""
|
||||
text = "Machine learning is amazing. Machine learning transforms data. Learn machine learning today."
|
||||
|
||||
snippets = SnippetGenerator.generate(text, "machine learning")
|
||||
assert len(snippets) <= 3
|
||||
assert all("machine learning" in s.lower() for s in snippets)
|
||||
|
||||
snippets = SnippetGenerator.generate(text, "machine learning", max_snippets=2)
|
||||
assert len(snippets) <= 2
|
||||
|
||||
snippets = SnippetGenerator.generate(text, "machine vision")
|
||||
assert len(snippets) > 0
|
||||
assert any("machine" in s.lower() for s in snippets)
|
||||
|
||||
def test_extract_webvtt_text_basic(self):
|
||||
"""Test WebVTT text extraction (basic test, full tests exist elsewhere)."""
|
||||
webvtt = """WEBVTT
|
||||
|
||||
00:00:00.000 --> 00:00:02.000
|
||||
Hello world
|
||||
|
||||
00:00:02.000 --> 00:00:04.000
|
||||
This is a test"""
|
||||
|
||||
result = WebVTTProcessor.extract_text(webvtt)
|
||||
assert "Hello world" in result
|
||||
assert "This is a test" in result
|
||||
|
||||
# Test empty input
|
||||
assert WebVTTProcessor.extract_text("") == ""
|
||||
assert WebVTTProcessor.extract_text(None) == ""
|
||||
|
||||
def test_generate_webvtt_snippets(self):
|
||||
"""Test generating snippets from WebVTT content."""
|
||||
webvtt = """WEBVTT
|
||||
|
||||
00:00:00.000 --> 00:00:02.000
|
||||
Python programming is great
|
||||
|
||||
00:00:02.000 --> 00:00:04.000
|
||||
Learn Python today"""
|
||||
|
||||
snippets = WebVTTProcessor.generate_snippets(webvtt, "Python")
|
||||
assert len(snippets) > 0
|
||||
assert any("Python" in s for s in snippets)
|
||||
|
||||
snippets = WebVTTProcessor.generate_snippets("", "Python")
|
||||
assert snippets == []
|
||||
|
||||
def test_from_summary(self):
|
||||
"""Test generating snippets from summary text."""
|
||||
summary = "This meeting discussed Python development and machine learning applications."
|
||||
|
||||
snippets = SnippetGenerator.from_summary(summary, "Python")
|
||||
assert len(snippets) > 0
|
||||
assert any("Python" in s for s in snippets)
|
||||
|
||||
long_summary = "Python " * 20
|
||||
snippets = SnippetGenerator.from_summary(long_summary, "Python")
|
||||
assert len(snippets) <= 2
|
||||
|
||||
def test_combine_sources(self):
|
||||
"""Test combining snippets from multiple sources."""
|
||||
summary = "Python is a great programming language."
|
||||
webvtt = """WEBVTT
|
||||
|
||||
00:00:00.000 --> 00:00:02.000
|
||||
Learn Python programming
|
||||
|
||||
00:00:02.000 --> 00:00:04.000
|
||||
Python is powerful"""
|
||||
|
||||
snippets, total_count = SnippetGenerator.combine_sources(
|
||||
summary, webvtt, "Python", max_total=3
|
||||
)
|
||||
assert len(snippets) <= 3
|
||||
assert len(snippets) > 0
|
||||
assert total_count > 0
|
||||
|
||||
snippets, total_count = SnippetGenerator.combine_sources(
|
||||
summary, None, "Python", max_total=3
|
||||
)
|
||||
assert len(snippets) > 0
|
||||
assert all("Python" in s for s in snippets)
|
||||
assert total_count == 1
|
||||
|
||||
snippets, total_count = SnippetGenerator.combine_sources(
|
||||
None, webvtt, "Python", max_total=3
|
||||
)
|
||||
assert len(snippets) > 0
|
||||
assert total_count == 2
|
||||
|
||||
long_summary = "Python " * 10
|
||||
snippets, total_count = SnippetGenerator.combine_sources(
|
||||
long_summary, webvtt, "Python", max_total=2
|
||||
)
|
||||
assert len(snippets) == 2
|
||||
assert total_count >= 10
|
||||
|
||||
def test_match_counting_sum_logic(self):
|
||||
"""Test that match counting correctly sums matches from both sources."""
|
||||
summary = "data science uses data analysis and data mining techniques"
|
||||
webvtt = """WEBVTT
|
||||
|
||||
00:00:00.000 --> 00:00:02.000
|
||||
Big data processing
|
||||
|
||||
00:00:02.000 --> 00:00:04.000
|
||||
data visualization and data storage"""
|
||||
|
||||
snippets, total_count = SnippetGenerator.combine_sources(
|
||||
summary, webvtt, "data", max_total=3
|
||||
)
|
||||
assert total_count == 6
|
||||
assert len(snippets) <= 3
|
||||
|
||||
summary_snippets, summary_count = SnippetGenerator.combine_sources(
|
||||
summary, None, "data", max_total=3
|
||||
)
|
||||
assert summary_count == 3
|
||||
|
||||
webvtt_snippets, webvtt_count = SnippetGenerator.combine_sources(
|
||||
None, webvtt, "data", max_total=3
|
||||
)
|
||||
assert webvtt_count == 3
|
||||
|
||||
snippets_empty, count_empty = SnippetGenerator.combine_sources(
|
||||
None, None, "data", max_total=3
|
||||
)
|
||||
assert snippets_empty == []
|
||||
assert count_empty == 0
|
||||
|
||||
def test_edge_cases(self):
|
||||
"""Test edge cases for the pure functions."""
|
||||
text = "Test with special: @#$%^&*() characters"
|
||||
snippets = SnippetGenerator.generate(text, "@#$%")
|
||||
assert len(snippets) > 0
|
||||
|
||||
long_query = "a" * 100
|
||||
snippets = SnippetGenerator.generate("Some text", long_query)
|
||||
assert snippets == []
|
||||
|
||||
text = "Unicode test: café, naïve, 日本語"
|
||||
snippets = SnippetGenerator.generate(text, "café")
|
||||
assert len(snippets) > 0
|
||||
assert "café" in snippets[0]
|
||||
|
||||
Reference in New Issue
Block a user