feat: search frontend (#551)

* feat: better highlight

* feat(search): add long_summary to search vector for improved search results

- Update search vector to include long_summary with weight B (between title A and webvtt C)
- Modify SearchController to fetch long_summary and prioritize its snippets
- Generate snippets from long_summary first (max 2), then from webvtt for remaining slots
- Add comprehensive tests for long_summary search functionality
- Create migration to update search_vector_en column in PostgreSQL

This improves search quality by including summarized content which often contains
key topics and themes that may not be explicitly mentioned in the transcript.

* fix: address code review feedback for search enhancements

- Fix test file inconsistencies by removing references to non-existent model fields
  - Comment out tests for unimplemented features (room_ids, status filters, date ranges)
  - Update tests to only use currently available fields (room_id singular, no room_name/processing_status)
  - Mark future functionality tests with @pytest.mark.skip

- Make snippet counts configurable
  - Add LONG_SUMMARY_MAX_SNIPPETS constant (default: 2)
  - Replace hardcoded value with configurable constant

- Improve error handling consistency in WebVTT parsing
  - Use different log levels for different error types (debug for malformed, warning for decode, error for unexpected)
  - Add catch-all exception handler for unexpected errors
  - Include stack trace for critical errors

All existing tests pass with these changes.

* fix: correct datetime test to include required duration field

* feat: better highlight

* feat: search room names

* feat: acknowledge deleted room

* feat: search filters fix and rank removal

* chore: minor refactoring

* feat: better matches frontend

* chore: self-review (vibe)

* chore: self-review WIP

* chore: self-review WIP

* chore: self-review WIP

* chore: self-review WIP

* chore: self-review WIP

* chore: self-review WIP

* chore: self-review WIP

* remove swc (vibe)

* search url query sync (vibe)

* search url query sync (vibe)

* better casts and cap while

* PR review + simplify frontend hook

* pr: remove search db timeouts

* cleanup tests

* tests cleanup

* frontend cleanup

* index declarations

* refactor frontend (self-review)

* fix search pagination

* clear "x" for search input

* pagination max pages fix

* chore: cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* lockfile

* pr review
This commit is contained in:
Igor Loskutov
2025-08-20 20:56:45 -04:00
committed by GitHub
parent fe5d344cff
commit 009590c080
32 changed files with 2311 additions and 618 deletions

View File

@@ -2,13 +2,18 @@
import json
from datetime import datetime, timezone
from unittest.mock import AsyncMock, patch
import pytest
from pydantic import ValidationError
from reflector.db import get_database
from reflector.db.search import SearchParameters, search_controller
from reflector.db.transcripts import transcripts
from reflector.db.search import (
SearchController,
SearchParameters,
SearchResult,
search_controller,
)
from reflector.db.transcripts import SourceKind, transcripts
@pytest.mark.asyncio
@@ -18,39 +23,135 @@ async def test_search_postgresql_only():
assert results == []
assert total == 0
try:
SearchParameters(query_text="")
assert False, "Should have raised validation error"
except ValidationError:
pass # Expected
# Test that whitespace query raises validation error
try:
SearchParameters(query_text=" ")
assert False, "Should have raised validation error"
except ValidationError:
pass # Expected
params_empty = SearchParameters(query_text="")
results_empty, total_empty = await search_controller.search_transcripts(
params_empty
)
assert isinstance(results_empty, list)
assert isinstance(total_empty, int)
@pytest.mark.asyncio
async def test_search_input_validation():
try:
SearchParameters(query_text="")
assert False, "Should have raised ValidationError"
except ValidationError:
pass # Expected
async def test_search_with_empty_query():
"""Test that empty query returns all transcripts."""
params = SearchParameters(query_text="")
results, total = await search_controller.search_transcripts(params)
assert isinstance(results, list)
assert isinstance(total, int)
if len(results) > 1:
for i in range(len(results) - 1):
assert results[i].created_at >= results[i + 1].created_at
@pytest.mark.asyncio
async def test_empty_transcript_title_only_match():
"""Test that transcripts with title-only matches return empty snippets."""
test_id = "test-empty-9b3f2a8d"
# Test that whitespace query raises validation error
try:
SearchParameters(query_text=" \t\n ")
assert False, "Should have raised ValidationError"
except ValidationError:
pass # Expected
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
)
test_data = {
"id": test_id,
"name": "Empty Transcript",
"title": "Empty Meeting",
"status": "completed",
"locked": False,
"duration": 0.0,
"created_at": datetime.now(timezone.utc),
"short_summary": None,
"long_summary": None,
"topics": json.dumps([]),
"events": json.dumps([]),
"participants": json.dumps([]),
"source_language": "en",
"target_language": "en",
"reviewed": False,
"audio_location": "local",
"share_mode": "private",
"source_kind": "room",
"webvtt": None,
}
await get_database().execute(transcripts.insert().values(**test_data))
params = SearchParameters(query_text="empty")
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = next((r for r in results if r.id == test_id), None)
assert found is not None, "Should find transcript by title match"
assert found.search_snippets == []
assert found.total_match_count == 0
finally:
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
)
await get_database().disconnect()
@pytest.mark.asyncio
async def test_search_with_long_summary():
"""Test that long_summary content is searchable."""
test_id = "test-long-summary-8a9f3c2d"
try:
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
)
test_data = {
"id": test_id,
"name": "Test Long Summary",
"title": "Regular Meeting",
"status": "completed",
"locked": False,
"duration": 1800.0,
"created_at": datetime.now(timezone.utc),
"short_summary": "Brief overview",
"long_summary": "Detailed discussion about quantum computing applications and blockchain technology integration",
"topics": json.dumps([]),
"events": json.dumps([]),
"participants": json.dumps([]),
"source_language": "en",
"target_language": "en",
"reviewed": False,
"audio_location": "local",
"share_mode": "private",
"source_kind": "room",
"webvtt": """WEBVTT
00:00:00.000 --> 00:00:10.000
Basic meeting content without special keywords.""",
}
await get_database().execute(transcripts.insert().values(**test_data))
params = SearchParameters(query_text="quantum computing")
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find transcript by long_summary content"
test_result = next((r for r in results if r.id == test_id), None)
assert test_result
assert len(test_result.search_snippets) > 0
assert "quantum computing" in test_result.search_snippets[0].lower()
finally:
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
)
await get_database().disconnect()
@pytest.mark.asyncio
async def test_postgresql_search_with_data():
# collision is improbable
test_id = "test-search-e2e-7f3a9b2c"
try:
@@ -94,28 +195,24 @@ We need to implement PostgreSQL tsvector for better performance.""",
await get_database().execute(transcripts.insert().values(**test_data))
# Test 1: Search for a word in title
params = SearchParameters(query_text="planning")
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by title word"
# Test 2: Search for a word in webvtt content
params = SearchParameters(query_text="tsvector")
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by webvtt content"
# Test 3: Search with multiple words
params = SearchParameters(query_text="engineering planning")
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by multiple words"
# Test 4: Verify SearchResult structure
test_result = next((r for r in results if r.id == test_id), None)
if test_result:
assert test_result.title == "Engineering Planning Meeting Q4 2024"
@@ -123,14 +220,12 @@ We need to implement PostgreSQL tsvector for better performance.""",
assert test_result.duration == 1800.0
assert 0 <= test_result.rank <= 1, "Rank should be normalized to 0-1"
# Test 5: Search with OR operator
params = SearchParameters(query_text="tsvector OR nosuchword")
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript with OR query"
# Test 6: Quoted phrase search
params = SearchParameters(query_text='"full-text search"')
results, total = await search_controller.search_transcripts(params)
assert total >= 1
@@ -142,3 +237,240 @@ We need to implement PostgreSQL tsvector for better performance.""",
transcripts.delete().where(transcripts.c.id == test_id)
)
await get_database().disconnect()
@pytest.fixture
def sample_search_params():
"""Create sample search parameters for testing."""
return SearchParameters(
query_text="test query",
limit=20,
offset=0,
user_id="test-user",
room_id="room1",
)
@pytest.fixture
def mock_db_result():
"""Create a mock database result."""
return {
"id": "test-transcript-id",
"title": "Test Transcript",
"created_at": datetime(2024, 6, 15, tzinfo=timezone.utc),
"duration": 3600.0,
"status": "completed",
"user_id": "test-user",
"room_id": "room1",
"source_kind": SourceKind.LIVE,
"webvtt": "WEBVTT\n\n00:00:00.000 --> 00:00:05.000\nThis is a test transcript",
"rank": 0.95,
}
class TestSearchParameters:
"""Test SearchParameters model validation and functionality."""
def test_search_parameters_with_available_filters(self):
"""Test creating SearchParameters with currently available filter options."""
params = SearchParameters(
query_text="search term",
limit=50,
offset=10,
user_id="user123",
room_id="room1",
)
assert params.query_text == "search term"
assert params.limit == 50
assert params.offset == 10
assert params.user_id == "user123"
assert params.room_id == "room1"
def test_search_parameters_defaults(self):
"""Test SearchParameters with default values."""
params = SearchParameters(query_text="test")
assert params.query_text == "test"
assert params.limit == 20
assert params.offset == 0
assert params.user_id is None
assert params.room_id is None
class TestSearchControllerFilters:
"""Test SearchController functionality with various filters."""
@pytest.mark.asyncio
async def test_search_with_source_kind_filter(self):
"""Test search filtering by source_kind."""
controller = SearchController()
with (
patch("reflector.db.search.is_postgresql", return_value=True),
patch("reflector.db.search.get_database") as mock_db,
):
mock_db.return_value.fetch_all = AsyncMock(return_value=[])
mock_db.return_value.fetch_val = AsyncMock(return_value=0)
params = SearchParameters(query_text="test", source_kind=SourceKind.LIVE)
results, total = await controller.search_transcripts(params)
assert results == []
assert total == 0
mock_db.return_value.fetch_all.assert_called_once()
@pytest.mark.asyncio
async def test_search_with_single_room_id(self):
"""Test search filtering by single room ID (currently supported)."""
controller = SearchController()
with (
patch("reflector.db.search.is_postgresql", return_value=True),
patch("reflector.db.search.get_database") as mock_db,
):
mock_db.return_value.fetch_all = AsyncMock(return_value=[])
mock_db.return_value.fetch_val = AsyncMock(return_value=0)
params = SearchParameters(
query_text="test",
room_id="room1",
)
results, total = await controller.search_transcripts(params)
assert results == []
assert total == 0
mock_db.return_value.fetch_all.assert_called_once()
@pytest.mark.asyncio
async def test_search_result_includes_available_fields(self, mock_db_result):
"""Test that search results include available fields like source_kind."""
controller = SearchController()
with (
patch("reflector.db.search.is_postgresql", return_value=True),
patch("reflector.db.search.get_database") as mock_db,
):
class MockRow:
def __init__(self, data):
self._data = data
self._mapping = data
def __iter__(self):
return iter(self._data.items())
def __getitem__(self, key):
return self._data[key]
def keys(self):
return self._data.keys()
mock_row = MockRow(mock_db_result)
mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row])
mock_db.return_value.fetch_val = AsyncMock(return_value=1)
params = SearchParameters(query_text="test")
results, total = await controller.search_transcripts(params)
assert total == 1
assert len(results) == 1
result = results[0]
assert isinstance(result, SearchResult)
assert result.id == "test-transcript-id"
assert result.title == "Test Transcript"
assert result.rank == 0.95
class TestSearchEndpointParsing:
"""Test parameter parsing in the search endpoint."""
def test_parse_comma_separated_room_ids(self):
"""Test parsing comma-separated room IDs."""
room_ids_str = "room1,room2,room3"
parsed = [rid.strip() for rid in room_ids_str.split(",") if rid.strip()]
assert parsed == ["room1", "room2", "room3"]
room_ids_str = "room1, room2 , room3"
parsed = [rid.strip() for rid in room_ids_str.split(",") if rid.strip()]
assert parsed == ["room1", "room2", "room3"]
room_ids_str = "room1,,room3,"
parsed = [rid.strip() for rid in room_ids_str.split(",") if rid.strip()]
assert parsed == ["room1", "room3"]
def test_parse_source_kind(self):
"""Test parsing source_kind values."""
for kind_str in ["live", "file", "room"]:
parsed = SourceKind(kind_str)
assert parsed == SourceKind(kind_str)
with pytest.raises(ValueError):
SourceKind("invalid_kind")
class TestSearchResultModel:
"""Test SearchResult model and serialization."""
def test_search_result_with_available_fields(self):
"""Test SearchResult model with currently available fields populated."""
result = SearchResult(
id="test-id",
title="Test Title",
user_id="user-123",
room_id="room-456",
source_kind=SourceKind.ROOM,
created_at=datetime(2024, 6, 15, tzinfo=timezone.utc),
status="completed",
rank=0.85,
duration=1800.5,
search_snippets=["snippet 1", "snippet 2"],
)
assert result.id == "test-id"
assert result.title == "Test Title"
assert result.user_id == "user-123"
assert result.room_id == "room-456"
assert result.status == "completed"
assert result.rank == 0.85
assert result.duration == 1800.5
assert len(result.search_snippets) == 2
def test_search_result_with_optional_fields_none(self):
"""Test SearchResult model with optional fields as None."""
result = SearchResult(
id="test-id",
source_kind=SourceKind.FILE,
created_at=datetime.now(timezone.utc),
status="processing",
rank=0.5,
search_snippets=[],
title=None,
user_id=None,
room_id=None,
duration=None,
)
assert result.title is None
assert result.user_id is None
assert result.room_id is None
assert result.duration is None
def test_search_result_datetime_field(self):
"""Test that SearchResult accepts datetime field."""
result = SearchResult(
id="test-id",
source_kind=SourceKind.LIVE,
created_at=datetime(2024, 6, 15, 12, 30, 45, tzinfo=timezone.utc),
status="completed",
rank=0.9,
duration=None,
search_snippets=[],
)
assert result.created_at == datetime(
2024, 6, 15, 12, 30, 45, tzinfo=timezone.utc
)

View File

@@ -0,0 +1,164 @@
"""Tests for long_summary in search functionality."""
import json
from datetime import datetime, timezone
import pytest
from reflector.db import get_database
from reflector.db.search import SearchParameters, search_controller
from reflector.db.transcripts import transcripts
@pytest.mark.asyncio
async def test_long_summary_snippet_prioritization():
"""Test that snippets from long_summary are prioritized over webvtt content."""
test_id = "test-snippet-priority-3f9a2b8c"
try:
# Clean up any existing test data
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
)
test_data = {
"id": test_id,
"name": "Test Snippet Priority",
"title": "Meeting About Projects",
"status": "completed",
"locked": False,
"duration": 1800.0,
"created_at": datetime.now(timezone.utc),
"short_summary": "Project discussion",
"long_summary": (
"The team discussed advanced robotics applications including "
"autonomous navigation systems and sensor fusion techniques. "
"Robotics development will focus on real-time processing."
),
"topics": json.dumps([]),
"events": json.dumps([]),
"participants": json.dumps([]),
"source_language": "en",
"target_language": "en",
"reviewed": False,
"audio_location": "local",
"share_mode": "private",
"source_kind": "room",
"webvtt": """WEBVTT
00:00:00.000 --> 00:00:10.000
We talked about many different topics today.
00:00:10.000 --> 00:00:20.000
The robotics project is making good progress.
00:00:20.000 --> 00:00:30.000
We need to consider various implementation approaches.""",
}
await get_database().execute(transcripts.insert().values(**test_data))
# Search for "robotics" which appears in both long_summary and webvtt
params = SearchParameters(query_text="robotics")
results, total = await search_controller.search_transcripts(params)
assert total >= 1
test_result = next((r for r in results if r.id == test_id), None)
assert test_result, "Should find the test transcript"
snippets = test_result.search_snippets
assert len(snippets) > 0, "Should have at least one snippet"
# The first snippets should be from long_summary (more detailed content)
first_snippet = snippets[0].lower()
assert (
"advanced robotics" in first_snippet or "autonomous" in first_snippet
), f"First snippet should be from long_summary with detailed content. Got: {snippets[0]}"
# With max 3 snippets, we should get both from long_summary and webvtt
assert len(snippets) <= 3, "Should respect max snippets limit"
# All snippets should contain the search term
for snippet in snippets:
assert (
"robotics" in snippet.lower()
), f"Snippet should contain search term: {snippet}"
finally:
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
)
await get_database().disconnect()
@pytest.mark.asyncio
async def test_long_summary_only_search():
"""Test searching for content that only exists in long_summary."""
test_id = "test-long-only-8b3c9f2a"
try:
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
)
test_data = {
"id": test_id,
"name": "Test Long Only",
"title": "Standard Meeting",
"status": "completed",
"locked": False,
"duration": 1800.0,
"created_at": datetime.now(timezone.utc),
"short_summary": "Team sync",
"long_summary": (
"Detailed analysis of cryptocurrency market trends and "
"decentralized finance protocols. Discussion included "
"yield farming strategies and liquidity pool mechanics."
),
"topics": json.dumps([]),
"events": json.dumps([]),
"participants": json.dumps([]),
"source_language": "en",
"target_language": "en",
"reviewed": False,
"audio_location": "local",
"share_mode": "private",
"source_kind": "room",
"webvtt": """WEBVTT
00:00:00.000 --> 00:00:10.000
Team meeting about general project updates.
00:00:10.000 --> 00:00:20.000
Discussion of timeline and deliverables.""",
}
await get_database().execute(transcripts.insert().values(**test_data))
# Search for terms only in long_summary
params = SearchParameters(query_text="cryptocurrency")
results, total = await search_controller.search_transcripts(params)
found = any(r.id == test_id for r in results)
assert found, "Should find transcript by long_summary-only content"
test_result = next((r for r in results if r.id == test_id), None)
assert test_result
assert len(test_result.search_snippets) > 0
# Verify the snippet is about cryptocurrency
snippet = test_result.search_snippets[0].lower()
assert "cryptocurrency" in snippet, "Snippet should contain the search term"
# Search for "yield farming" - a more specific term
params2 = SearchParameters(query_text="yield farming")
results2, total2 = await search_controller.search_transcripts(params2)
found2 = any(r.id == test_id for r in results2)
assert found2, "Should find transcript by specific long_summary phrase"
finally:
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
)
await get_database().disconnect()

View File

@@ -1,6 +1,10 @@
"""Unit tests for search snippet generation."""
from reflector.db.search import SearchController
from reflector.db.search import (
SnippetCandidate,
SnippetGenerator,
WebVTTProcessor,
)
class TestExtractWebVTT:
@@ -16,7 +20,7 @@ class TestExtractWebVTT:
00:00:10.000 --> 00:00:20.000
<v Speaker1>Indeed it is a test of WebVTT parsing.
"""
result = SearchController._extract_webvtt_text(webvtt)
result = WebVTTProcessor.extract_text(webvtt)
assert "Hello world, this is a test" in result
assert "Indeed it is a test" in result
assert "<v Speaker" not in result
@@ -25,12 +29,11 @@ class TestExtractWebVTT:
def test_extract_empty_webvtt(self):
"""Test empty WebVTT returns empty string."""
assert SearchController._extract_webvtt_text("") == ""
assert SearchController._extract_webvtt_text(None) == ""
assert WebVTTProcessor.extract_text("") == ""
def test_extract_malformed_webvtt(self):
"""Test malformed WebVTT returns empty string."""
result = SearchController._extract_webvtt_text("Not a valid WebVTT")
result = WebVTTProcessor.extract_text("Not a valid WebVTT")
assert result == ""
@@ -39,8 +42,7 @@ class TestGenerateSnippets:
def test_multiple_matches(self):
"""Test finding multiple occurrences of search term in long text."""
# Create text with Python mentions far apart to get separate snippets
separator = " This is filler text. " * 20 # ~400 chars of padding
separator = " This is filler text. " * 20
text = (
"Python is great for machine learning."
+ separator
@@ -51,18 +53,16 @@ class TestGenerateSnippets:
+ "The Python community is very supportive."
)
snippets = SearchController._generate_snippets(text, "Python")
# With enough separation, we should get multiple snippets
assert len(snippets) >= 2 # At least 2 distinct snippets
snippets = SnippetGenerator.generate(text, "Python")
assert len(snippets) >= 2
# Each snippet should contain "Python"
for snippet in snippets:
assert "python" in snippet.lower()
def test_single_match(self):
"""Test single occurrence returns one snippet."""
text = "This document discusses artificial intelligence and its applications."
snippets = SearchController._generate_snippets(text, "artificial intelligence")
snippets = SnippetGenerator.generate(text, "artificial intelligence")
assert len(snippets) == 1
assert "artificial intelligence" in snippets[0].lower()
@@ -70,24 +70,22 @@ class TestGenerateSnippets:
def test_no_matches(self):
"""Test no matches returns empty list."""
text = "This is some random text without the search term."
snippets = SearchController._generate_snippets(text, "machine learning")
snippets = SnippetGenerator.generate(text, "machine learning")
assert snippets == []
def test_case_insensitive_search(self):
"""Test search is case insensitive."""
# Add enough text between matches to get separate snippets
text = (
"MACHINE LEARNING is important for modern applications. "
+ "It requires lots of data and computational resources. " * 5 # Padding
+ "It requires lots of data and computational resources. " * 5
+ "Machine Learning rocks and transforms industries. "
+ "Deep learning is a subset of it. " * 5 # More padding
+ "Deep learning is a subset of it. " * 5
+ "Finally, machine learning will shape our future."
)
snippets = SearchController._generate_snippets(text, "machine learning")
snippets = SnippetGenerator.generate(text, "machine learning")
# Should find at least 2 (might be 3 if text is long enough)
assert len(snippets) >= 2
for snippet in snippets:
assert "machine learning" in snippet.lower()
@@ -95,61 +93,55 @@ class TestGenerateSnippets:
def test_partial_match_fallback(self):
"""Test fallback to first word when exact phrase not found."""
text = "We use machine intelligence for processing."
snippets = SearchController._generate_snippets(text, "machine learning")
snippets = SnippetGenerator.generate(text, "machine learning")
# Should fall back to finding "machine"
assert len(snippets) == 1
assert "machine" in snippets[0].lower()
def test_snippet_ellipsis(self):
"""Test ellipsis added for truncated snippets."""
# Long text where match is in the middle
text = "a " * 100 + "TARGET_WORD special content here" + " b" * 100
snippets = SearchController._generate_snippets(text, "TARGET_WORD")
snippets = SnippetGenerator.generate(text, "TARGET_WORD")
assert len(snippets) == 1
assert "..." in snippets[0] # Should have ellipsis
assert "..." in snippets[0]
assert "TARGET_WORD" in snippets[0]
def test_overlapping_snippets_deduplicated(self):
"""Test overlapping matches don't create duplicate snippets."""
text = "test test test word" * 10 # Repeated pattern
snippets = SearchController._generate_snippets(text, "test")
text = "test test test word" * 10
snippets = SnippetGenerator.generate(text, "test")
# Should get unique snippets, not duplicates
assert len(snippets) <= 3
assert len(snippets) == len(set(snippets)) # All unique
assert len(snippets) == len(set(snippets))
def test_empty_inputs(self):
"""Test empty text or search term returns empty list."""
assert SearchController._generate_snippets("", "search") == []
assert SearchController._generate_snippets("text", "") == []
assert SearchController._generate_snippets("", "") == []
assert SnippetGenerator.generate("", "search") == []
assert SnippetGenerator.generate("text", "") == []
assert SnippetGenerator.generate("", "") == []
def test_max_snippets_limit(self):
"""Test respects max_snippets parameter."""
# Create text with well-separated occurrences
separator = " filler " * 50 # Ensure snippets don't overlap
text = ("Python is amazing" + separator) * 10 # 10 occurrences
separator = " filler " * 50
text = ("Python is amazing" + separator) * 10
# Test with different limits
snippets_1 = SearchController._generate_snippets(text, "Python", max_snippets=1)
snippets_1 = SnippetGenerator.generate(text, "Python", max_snippets=1)
assert len(snippets_1) == 1
snippets_2 = SearchController._generate_snippets(text, "Python", max_snippets=2)
snippets_2 = SnippetGenerator.generate(text, "Python", max_snippets=2)
assert len(snippets_2) == 2
snippets_5 = SearchController._generate_snippets(text, "Python", max_snippets=5)
assert len(snippets_5) == 5 # Should get exactly 5 with enough separation
snippets_5 = SnippetGenerator.generate(text, "Python", max_snippets=5)
assert len(snippets_5) == 5
def test_snippet_length(self):
"""Test snippet length is reasonable."""
text = "word " * 200 # Long text
snippets = SearchController._generate_snippets(text, "word")
text = "word " * 200
snippets = SnippetGenerator.generate(text, "word")
for snippet in snippets:
# Default max_length is 150 + some context
assert len(snippet) <= 200 # Some buffer for ellipsis
assert len(snippet) <= 200
class TestFullPipeline:
@@ -157,7 +149,6 @@ class TestFullPipeline:
def test_webvtt_to_snippets_integration(self):
"""Test full pipeline from WebVTT to search snippets."""
# Create WebVTT with well-separated content for multiple snippets
webvtt = (
"""WEBVTT
@@ -182,17 +173,362 @@ class TestFullPipeline:
"""
)
# Extract and generate snippets
plain_text = SearchController._extract_webvtt_text(webvtt)
snippets = SearchController._generate_snippets(plain_text, "machine learning")
plain_text = WebVTTProcessor.extract_text(webvtt)
snippets = SnippetGenerator.generate(plain_text, "machine learning")
# Should find at least 2 snippets (text might still be close together)
assert len(snippets) >= 1 # At minimum one snippet containing matches
assert len(snippets) <= 3 # At most 3 by default
assert len(snippets) >= 1
assert len(snippets) <= 3
# No WebVTT artifacts in snippets
for snippet in snippets:
assert "machine learning" in snippet.lower()
assert "<v Speaker" not in snippet
assert "00:00" not in snippet
assert "-->" not in snippet
class TestMultiWordQueryBehavior:
"""Tests for multi-word query behavior and exact phrase matching."""
def test_multi_word_query_snippet_behavior(self):
"""Test that multi-word queries generate snippets based on exact phrase matching."""
sample_text = """This is a sample transcript where user Alice is talking.
Later in the conversation, jordan mentions something important.
The user jordan collaboration was successful.
Another user named Bob joins the discussion."""
user_snippets = SnippetGenerator.generate(sample_text, "user")
assert len(user_snippets) == 2, "Should find 2 snippets for 'user'"
jordan_snippets = SnippetGenerator.generate(sample_text, "jordan")
assert len(jordan_snippets) >= 1, "Should find at least 1 snippet for 'jordan'"
multi_word_snippets = SnippetGenerator.generate(sample_text, "user jordan")
assert len(multi_word_snippets) == 1, (
"Should return exactly 1 snippet for 'user jordan' "
"(only the exact phrase match, not individual word occurrences)"
)
snippet = multi_word_snippets[0]
assert (
"user jordan" in snippet.lower()
), "The snippet should contain the exact phrase 'user jordan'"
assert (
"alice" not in snippet.lower()
), "The snippet should not include the first standalone 'user' with Alice"
def test_multi_word_query_without_exact_match(self):
"""Test snippet generation when exact phrase is not found."""
sample_text = """User Alice is here. Bob and jordan are talking.
Later jordan mentions something. The user is happy."""
snippets = SnippetGenerator.generate(sample_text, "user jordan")
assert (
len(snippets) >= 1
), "Should find at least 1 snippet when falling back to first word"
all_snippets_text = " ".join(snippets).lower()
assert (
"user" in all_snippets_text
), "Snippets should contain 'user' (the first word)"
def test_exact_phrase_at_text_boundaries(self):
"""Test snippet generation when exact phrase appears at text boundaries."""
text_start = "user jordan started the meeting. Other content here."
snippets = SnippetGenerator.generate(text_start, "user jordan")
assert len(snippets) == 1
assert "user jordan" in snippets[0].lower()
text_end = "Other content here. The meeting ended with user jordan"
snippets = SnippetGenerator.generate(text_end, "user jordan")
assert len(snippets) == 1
assert "user jordan" in snippets[0].lower()
def test_multi_word_query_matches_words_appearing_separately_and_together(self):
"""Test that multi-word queries prioritize exact phrase matches over individual word occurrences."""
sample_text = """This is a sample transcript where user Alice is talking.
Later in the conversation, jordan mentions something important.
The user jordan collaboration was successful.
Another user named Bob joins the discussion."""
search_query = "user jordan"
snippets = SnippetGenerator.generate(sample_text, search_query)
assert len(snippets) == 1, (
f"Expected exactly 1 snippet for '{search_query}' when exact phrase exists, "
f"got {len(snippets)}. Should ignore individual word occurrences."
)
snippet = snippets[0]
assert (
search_query in snippet.lower()
), f"Snippet should contain the exact phrase '{search_query}'. Got: {snippet}"
assert (
"jordan mentions" in snippet.lower()
), f"Snippet should include context before the exact phrase match. Got: {snippet}"
assert (
"alice" not in snippet.lower()
), f"Snippet should not include separate occurrences of individual words. Got: {snippet}"
text_2 = """The alpha version was released.
Beta testing started yesterday.
The alpha beta integration is complete."""
snippets_2 = SnippetGenerator.generate(text_2, "alpha beta")
assert len(snippets_2) == 1, "Should return 1 snippet for exact phrase match"
assert "alpha beta" in snippets_2[0].lower(), "Should contain exact phrase"
assert (
"version" not in snippets_2[0].lower()
), "Should not include first separate occurrence"
class TestSnippetGenerationEnhanced:
"""Additional snippet generation tests from test_search_enhancements.py."""
def test_snippet_generation_from_webvtt(self):
"""Test snippet generation from WebVTT content."""
webvtt_content = """WEBVTT
00:00:00.000 --> 00:00:05.000
This is the beginning of the transcript
00:00:05.000 --> 00:00:10.000
The search term appears here in the middle
00:00:10.000 --> 00:00:15.000
And this is the end of the content"""
plain_text = WebVTTProcessor.extract_text(webvtt_content)
snippets = SnippetGenerator.generate(plain_text, "search term")
assert len(snippets) > 0
assert any("search term" in snippet.lower() for snippet in snippets)
def test_extract_webvtt_text_with_malformed_variations(self):
"""Test WebVTT extraction with various malformed content."""
malformed_vtt = "This is not valid WebVTT content"
result = WebVTTProcessor.extract_text(malformed_vtt)
assert result == ""
partial_vtt = "WEBVTT\nNo timestamps here"
result = WebVTTProcessor.extract_text(partial_vtt)
assert result == "" or "No timestamps" not in result
class TestPureFunctions:
"""Test the pure functions extracted for functional programming."""
def test_find_all_matches(self):
"""Test finding all match positions in text."""
text = "Python is great. Python is powerful. I love Python."
matches = list(SnippetGenerator.find_all_matches(text, "Python"))
assert matches == [0, 17, 44]
matches = list(SnippetGenerator.find_all_matches(text, "python"))
assert matches == [0, 17, 44]
matches = list(SnippetGenerator.find_all_matches(text, "Ruby"))
assert matches == []
matches = list(SnippetGenerator.find_all_matches("", "test"))
assert matches == []
matches = list(SnippetGenerator.find_all_matches("test", ""))
assert matches == []
def test_create_snippet(self):
"""Test creating a snippet from a match position."""
text = "This is a long text with the word Python in the middle and more text after."
snippet = SnippetGenerator.create_snippet(text, 35, max_length=150)
assert "Python" in snippet.text()
assert snippet.start >= 0
assert snippet.end <= len(text)
assert isinstance(snippet, SnippetCandidate)
assert len(snippet.text()) > 0
assert snippet.start <= snippet.end
long_text = "A" * 200
snippet = SnippetGenerator.create_snippet(long_text, 100, max_length=50)
assert snippet.text().startswith("...")
assert snippet.text().endswith("...")
snippet = SnippetGenerator.create_snippet("short text", 0, max_length=100)
assert snippet.start == 0
assert "short text" in snippet.text()
def test_filter_non_overlapping(self):
"""Test filtering overlapping snippets."""
candidates = [
SnippetCandidate(_text="First snippet", start=0, _original_text_length=100),
SnippetCandidate(_text="Overlapping", start=10, _original_text_length=100),
SnippetCandidate(
_text="Third snippet", start=40, _original_text_length=100
),
SnippetCandidate(
_text="Fourth snippet", start=65, _original_text_length=100
),
]
filtered = list(SnippetGenerator.filter_non_overlapping(iter(candidates)))
assert filtered == [
"First snippet...",
"...Third snippet...",
"...Fourth snippet...",
]
filtered = list(SnippetGenerator.filter_non_overlapping(iter([])))
assert filtered == []
def test_generate_integration(self):
"""Test the main SnippetGenerator.generate function."""
text = "Machine learning is amazing. Machine learning transforms data. Learn machine learning today."
snippets = SnippetGenerator.generate(text, "machine learning")
assert len(snippets) <= 3
assert all("machine learning" in s.lower() for s in snippets)
snippets = SnippetGenerator.generate(text, "machine learning", max_snippets=2)
assert len(snippets) <= 2
snippets = SnippetGenerator.generate(text, "machine vision")
assert len(snippets) > 0
assert any("machine" in s.lower() for s in snippets)
def test_extract_webvtt_text_basic(self):
"""Test WebVTT text extraction (basic test, full tests exist elsewhere)."""
webvtt = """WEBVTT
00:00:00.000 --> 00:00:02.000
Hello world
00:00:02.000 --> 00:00:04.000
This is a test"""
result = WebVTTProcessor.extract_text(webvtt)
assert "Hello world" in result
assert "This is a test" in result
# Test empty input
assert WebVTTProcessor.extract_text("") == ""
assert WebVTTProcessor.extract_text(None) == ""
def test_generate_webvtt_snippets(self):
"""Test generating snippets from WebVTT content."""
webvtt = """WEBVTT
00:00:00.000 --> 00:00:02.000
Python programming is great
00:00:02.000 --> 00:00:04.000
Learn Python today"""
snippets = WebVTTProcessor.generate_snippets(webvtt, "Python")
assert len(snippets) > 0
assert any("Python" in s for s in snippets)
snippets = WebVTTProcessor.generate_snippets("", "Python")
assert snippets == []
def test_from_summary(self):
"""Test generating snippets from summary text."""
summary = "This meeting discussed Python development and machine learning applications."
snippets = SnippetGenerator.from_summary(summary, "Python")
assert len(snippets) > 0
assert any("Python" in s for s in snippets)
long_summary = "Python " * 20
snippets = SnippetGenerator.from_summary(long_summary, "Python")
assert len(snippets) <= 2
def test_combine_sources(self):
"""Test combining snippets from multiple sources."""
summary = "Python is a great programming language."
webvtt = """WEBVTT
00:00:00.000 --> 00:00:02.000
Learn Python programming
00:00:02.000 --> 00:00:04.000
Python is powerful"""
snippets, total_count = SnippetGenerator.combine_sources(
summary, webvtt, "Python", max_total=3
)
assert len(snippets) <= 3
assert len(snippets) > 0
assert total_count > 0
snippets, total_count = SnippetGenerator.combine_sources(
summary, None, "Python", max_total=3
)
assert len(snippets) > 0
assert all("Python" in s for s in snippets)
assert total_count == 1
snippets, total_count = SnippetGenerator.combine_sources(
None, webvtt, "Python", max_total=3
)
assert len(snippets) > 0
assert total_count == 2
long_summary = "Python " * 10
snippets, total_count = SnippetGenerator.combine_sources(
long_summary, webvtt, "Python", max_total=2
)
assert len(snippets) == 2
assert total_count >= 10
def test_match_counting_sum_logic(self):
"""Test that match counting correctly sums matches from both sources."""
summary = "data science uses data analysis and data mining techniques"
webvtt = """WEBVTT
00:00:00.000 --> 00:00:02.000
Big data processing
00:00:02.000 --> 00:00:04.000
data visualization and data storage"""
snippets, total_count = SnippetGenerator.combine_sources(
summary, webvtt, "data", max_total=3
)
assert total_count == 6
assert len(snippets) <= 3
summary_snippets, summary_count = SnippetGenerator.combine_sources(
summary, None, "data", max_total=3
)
assert summary_count == 3
webvtt_snippets, webvtt_count = SnippetGenerator.combine_sources(
None, webvtt, "data", max_total=3
)
assert webvtt_count == 3
snippets_empty, count_empty = SnippetGenerator.combine_sources(
None, None, "data", max_total=3
)
assert snippets_empty == []
assert count_empty == 0
def test_edge_cases(self):
"""Test edge cases for the pure functions."""
text = "Test with special: @#$%^&*() characters"
snippets = SnippetGenerator.generate(text, "@#$%")
assert len(snippets) > 0
long_query = "a" * 100
snippets = SnippetGenerator.generate("Some text", long_query)
assert snippets == []
text = "Unicode test: café, naïve, 日本語"
snippets = SnippetGenerator.generate(text, "café")
assert len(snippets) > 0
assert "café" in snippets[0]