mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
feat: search frontend (#551)
* feat: better highlight * feat(search): add long_summary to search vector for improved search results - Update search vector to include long_summary with weight B (between title A and webvtt C) - Modify SearchController to fetch long_summary and prioritize its snippets - Generate snippets from long_summary first (max 2), then from webvtt for remaining slots - Add comprehensive tests for long_summary search functionality - Create migration to update search_vector_en column in PostgreSQL This improves search quality by including summarized content which often contains key topics and themes that may not be explicitly mentioned in the transcript. * fix: address code review feedback for search enhancements - Fix test file inconsistencies by removing references to non-existent model fields - Comment out tests for unimplemented features (room_ids, status filters, date ranges) - Update tests to only use currently available fields (room_id singular, no room_name/processing_status) - Mark future functionality tests with @pytest.mark.skip - Make snippet counts configurable - Add LONG_SUMMARY_MAX_SNIPPETS constant (default: 2) - Replace hardcoded value with configurable constant - Improve error handling consistency in WebVTT parsing - Use different log levels for different error types (debug for malformed, warning for decode, error for unexpected) - Add catch-all exception handler for unexpected errors - Include stack trace for critical errors All existing tests pass with these changes. * fix: correct datetime test to include required duration field * feat: better highlight * feat: search room names * feat: acknowledge deleted room * feat: search filters fix and rank removal * chore: minor refactoring * feat: better matches frontend * chore: self-review (vibe) * chore: self-review WIP * chore: self-review WIP * chore: self-review WIP * chore: self-review WIP * chore: self-review WIP * chore: self-review WIP * chore: self-review WIP * remove swc (vibe) * search url query sync (vibe) * search url query sync (vibe) * better casts and cap while * PR review + simplify frontend hook * pr: remove search db timeouts * cleanup tests * tests cleanup * frontend cleanup * index declarations * refactor frontend (self-review) * fix search pagination * clear "x" for search input * pagination max pages fix * chore: cleanup * cleanup * cleanup * cleanup * cleanup * cleanup * cleanup * cleanup * lockfile * pr review
This commit is contained in:
@@ -2,13 +2,18 @@
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from reflector.db import get_database
|
||||
from reflector.db.search import SearchParameters, search_controller
|
||||
from reflector.db.transcripts import transcripts
|
||||
from reflector.db.search import (
|
||||
SearchController,
|
||||
SearchParameters,
|
||||
SearchResult,
|
||||
search_controller,
|
||||
)
|
||||
from reflector.db.transcripts import SourceKind, transcripts
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -18,39 +23,135 @@ async def test_search_postgresql_only():
|
||||
assert results == []
|
||||
assert total == 0
|
||||
|
||||
try:
|
||||
SearchParameters(query_text="")
|
||||
assert False, "Should have raised validation error"
|
||||
except ValidationError:
|
||||
pass # Expected
|
||||
|
||||
# Test that whitespace query raises validation error
|
||||
try:
|
||||
SearchParameters(query_text=" ")
|
||||
assert False, "Should have raised validation error"
|
||||
except ValidationError:
|
||||
pass # Expected
|
||||
params_empty = SearchParameters(query_text="")
|
||||
results_empty, total_empty = await search_controller.search_transcripts(
|
||||
params_empty
|
||||
)
|
||||
assert isinstance(results_empty, list)
|
||||
assert isinstance(total_empty, int)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_input_validation():
|
||||
try:
|
||||
SearchParameters(query_text="")
|
||||
assert False, "Should have raised ValidationError"
|
||||
except ValidationError:
|
||||
pass # Expected
|
||||
async def test_search_with_empty_query():
|
||||
"""Test that empty query returns all transcripts."""
|
||||
params = SearchParameters(query_text="")
|
||||
results, total = await search_controller.search_transcripts(params)
|
||||
|
||||
assert isinstance(results, list)
|
||||
assert isinstance(total, int)
|
||||
if len(results) > 1:
|
||||
for i in range(len(results) - 1):
|
||||
assert results[i].created_at >= results[i + 1].created_at
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_transcript_title_only_match():
|
||||
"""Test that transcripts with title-only matches return empty snippets."""
|
||||
test_id = "test-empty-9b3f2a8d"
|
||||
|
||||
# Test that whitespace query raises validation error
|
||||
try:
|
||||
SearchParameters(query_text=" \t\n ")
|
||||
assert False, "Should have raised ValidationError"
|
||||
except ValidationError:
|
||||
pass # Expected
|
||||
await get_database().execute(
|
||||
transcripts.delete().where(transcripts.c.id == test_id)
|
||||
)
|
||||
|
||||
test_data = {
|
||||
"id": test_id,
|
||||
"name": "Empty Transcript",
|
||||
"title": "Empty Meeting",
|
||||
"status": "completed",
|
||||
"locked": False,
|
||||
"duration": 0.0,
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"short_summary": None,
|
||||
"long_summary": None,
|
||||
"topics": json.dumps([]),
|
||||
"events": json.dumps([]),
|
||||
"participants": json.dumps([]),
|
||||
"source_language": "en",
|
||||
"target_language": "en",
|
||||
"reviewed": False,
|
||||
"audio_location": "local",
|
||||
"share_mode": "private",
|
||||
"source_kind": "room",
|
||||
"webvtt": None,
|
||||
}
|
||||
|
||||
await get_database().execute(transcripts.insert().values(**test_data))
|
||||
|
||||
params = SearchParameters(query_text="empty")
|
||||
results, total = await search_controller.search_transcripts(params)
|
||||
|
||||
assert total >= 1
|
||||
found = next((r for r in results if r.id == test_id), None)
|
||||
assert found is not None, "Should find transcript by title match"
|
||||
assert found.search_snippets == []
|
||||
assert found.total_match_count == 0
|
||||
|
||||
finally:
|
||||
await get_database().execute(
|
||||
transcripts.delete().where(transcripts.c.id == test_id)
|
||||
)
|
||||
await get_database().disconnect()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_with_long_summary():
|
||||
"""Test that long_summary content is searchable."""
|
||||
test_id = "test-long-summary-8a9f3c2d"
|
||||
|
||||
try:
|
||||
await get_database().execute(
|
||||
transcripts.delete().where(transcripts.c.id == test_id)
|
||||
)
|
||||
|
||||
test_data = {
|
||||
"id": test_id,
|
||||
"name": "Test Long Summary",
|
||||
"title": "Regular Meeting",
|
||||
"status": "completed",
|
||||
"locked": False,
|
||||
"duration": 1800.0,
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"short_summary": "Brief overview",
|
||||
"long_summary": "Detailed discussion about quantum computing applications and blockchain technology integration",
|
||||
"topics": json.dumps([]),
|
||||
"events": json.dumps([]),
|
||||
"participants": json.dumps([]),
|
||||
"source_language": "en",
|
||||
"target_language": "en",
|
||||
"reviewed": False,
|
||||
"audio_location": "local",
|
||||
"share_mode": "private",
|
||||
"source_kind": "room",
|
||||
"webvtt": """WEBVTT
|
||||
|
||||
00:00:00.000 --> 00:00:10.000
|
||||
Basic meeting content without special keywords.""",
|
||||
}
|
||||
|
||||
await get_database().execute(transcripts.insert().values(**test_data))
|
||||
|
||||
params = SearchParameters(query_text="quantum computing")
|
||||
results, total = await search_controller.search_transcripts(params)
|
||||
|
||||
assert total >= 1
|
||||
found = any(r.id == test_id for r in results)
|
||||
assert found, "Should find transcript by long_summary content"
|
||||
|
||||
test_result = next((r for r in results if r.id == test_id), None)
|
||||
assert test_result
|
||||
assert len(test_result.search_snippets) > 0
|
||||
assert "quantum computing" in test_result.search_snippets[0].lower()
|
||||
|
||||
finally:
|
||||
await get_database().execute(
|
||||
transcripts.delete().where(transcripts.c.id == test_id)
|
||||
)
|
||||
await get_database().disconnect()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_postgresql_search_with_data():
|
||||
# collision is improbable
|
||||
test_id = "test-search-e2e-7f3a9b2c"
|
||||
|
||||
try:
|
||||
@@ -94,28 +195,24 @@ We need to implement PostgreSQL tsvector for better performance.""",
|
||||
|
||||
await get_database().execute(transcripts.insert().values(**test_data))
|
||||
|
||||
# Test 1: Search for a word in title
|
||||
params = SearchParameters(query_text="planning")
|
||||
results, total = await search_controller.search_transcripts(params)
|
||||
assert total >= 1
|
||||
found = any(r.id == test_id for r in results)
|
||||
assert found, "Should find test transcript by title word"
|
||||
|
||||
# Test 2: Search for a word in webvtt content
|
||||
params = SearchParameters(query_text="tsvector")
|
||||
results, total = await search_controller.search_transcripts(params)
|
||||
assert total >= 1
|
||||
found = any(r.id == test_id for r in results)
|
||||
assert found, "Should find test transcript by webvtt content"
|
||||
|
||||
# Test 3: Search with multiple words
|
||||
params = SearchParameters(query_text="engineering planning")
|
||||
results, total = await search_controller.search_transcripts(params)
|
||||
assert total >= 1
|
||||
found = any(r.id == test_id for r in results)
|
||||
assert found, "Should find test transcript by multiple words"
|
||||
|
||||
# Test 4: Verify SearchResult structure
|
||||
test_result = next((r for r in results if r.id == test_id), None)
|
||||
if test_result:
|
||||
assert test_result.title == "Engineering Planning Meeting Q4 2024"
|
||||
@@ -123,14 +220,12 @@ We need to implement PostgreSQL tsvector for better performance.""",
|
||||
assert test_result.duration == 1800.0
|
||||
assert 0 <= test_result.rank <= 1, "Rank should be normalized to 0-1"
|
||||
|
||||
# Test 5: Search with OR operator
|
||||
params = SearchParameters(query_text="tsvector OR nosuchword")
|
||||
results, total = await search_controller.search_transcripts(params)
|
||||
assert total >= 1
|
||||
found = any(r.id == test_id for r in results)
|
||||
assert found, "Should find test transcript with OR query"
|
||||
|
||||
# Test 6: Quoted phrase search
|
||||
params = SearchParameters(query_text='"full-text search"')
|
||||
results, total = await search_controller.search_transcripts(params)
|
||||
assert total >= 1
|
||||
@@ -142,3 +237,240 @@ We need to implement PostgreSQL tsvector for better performance.""",
|
||||
transcripts.delete().where(transcripts.c.id == test_id)
|
||||
)
|
||||
await get_database().disconnect()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_search_params():
|
||||
"""Create sample search parameters for testing."""
|
||||
return SearchParameters(
|
||||
query_text="test query",
|
||||
limit=20,
|
||||
offset=0,
|
||||
user_id="test-user",
|
||||
room_id="room1",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_result():
|
||||
"""Create a mock database result."""
|
||||
return {
|
||||
"id": "test-transcript-id",
|
||||
"title": "Test Transcript",
|
||||
"created_at": datetime(2024, 6, 15, tzinfo=timezone.utc),
|
||||
"duration": 3600.0,
|
||||
"status": "completed",
|
||||
"user_id": "test-user",
|
||||
"room_id": "room1",
|
||||
"source_kind": SourceKind.LIVE,
|
||||
"webvtt": "WEBVTT\n\n00:00:00.000 --> 00:00:05.000\nThis is a test transcript",
|
||||
"rank": 0.95,
|
||||
}
|
||||
|
||||
|
||||
class TestSearchParameters:
|
||||
"""Test SearchParameters model validation and functionality."""
|
||||
|
||||
def test_search_parameters_with_available_filters(self):
|
||||
"""Test creating SearchParameters with currently available filter options."""
|
||||
params = SearchParameters(
|
||||
query_text="search term",
|
||||
limit=50,
|
||||
offset=10,
|
||||
user_id="user123",
|
||||
room_id="room1",
|
||||
)
|
||||
|
||||
assert params.query_text == "search term"
|
||||
assert params.limit == 50
|
||||
assert params.offset == 10
|
||||
assert params.user_id == "user123"
|
||||
assert params.room_id == "room1"
|
||||
|
||||
def test_search_parameters_defaults(self):
|
||||
"""Test SearchParameters with default values."""
|
||||
params = SearchParameters(query_text="test")
|
||||
|
||||
assert params.query_text == "test"
|
||||
assert params.limit == 20
|
||||
assert params.offset == 0
|
||||
assert params.user_id is None
|
||||
assert params.room_id is None
|
||||
|
||||
|
||||
class TestSearchControllerFilters:
|
||||
"""Test SearchController functionality with various filters."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_with_source_kind_filter(self):
|
||||
"""Test search filtering by source_kind."""
|
||||
controller = SearchController()
|
||||
with (
|
||||
patch("reflector.db.search.is_postgresql", return_value=True),
|
||||
patch("reflector.db.search.get_database") as mock_db,
|
||||
):
|
||||
mock_db.return_value.fetch_all = AsyncMock(return_value=[])
|
||||
mock_db.return_value.fetch_val = AsyncMock(return_value=0)
|
||||
|
||||
params = SearchParameters(query_text="test", source_kind=SourceKind.LIVE)
|
||||
|
||||
results, total = await controller.search_transcripts(params)
|
||||
|
||||
assert results == []
|
||||
assert total == 0
|
||||
|
||||
mock_db.return_value.fetch_all.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_with_single_room_id(self):
|
||||
"""Test search filtering by single room ID (currently supported)."""
|
||||
controller = SearchController()
|
||||
with (
|
||||
patch("reflector.db.search.is_postgresql", return_value=True),
|
||||
patch("reflector.db.search.get_database") as mock_db,
|
||||
):
|
||||
mock_db.return_value.fetch_all = AsyncMock(return_value=[])
|
||||
mock_db.return_value.fetch_val = AsyncMock(return_value=0)
|
||||
|
||||
params = SearchParameters(
|
||||
query_text="test",
|
||||
room_id="room1",
|
||||
)
|
||||
|
||||
results, total = await controller.search_transcripts(params)
|
||||
|
||||
assert results == []
|
||||
assert total == 0
|
||||
mock_db.return_value.fetch_all.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_result_includes_available_fields(self, mock_db_result):
|
||||
"""Test that search results include available fields like source_kind."""
|
||||
controller = SearchController()
|
||||
with (
|
||||
patch("reflector.db.search.is_postgresql", return_value=True),
|
||||
patch("reflector.db.search.get_database") as mock_db,
|
||||
):
|
||||
|
||||
class MockRow:
|
||||
def __init__(self, data):
|
||||
self._data = data
|
||||
self._mapping = data
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._data.items())
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._data[key]
|
||||
|
||||
def keys(self):
|
||||
return self._data.keys()
|
||||
|
||||
mock_row = MockRow(mock_db_result)
|
||||
|
||||
mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row])
|
||||
mock_db.return_value.fetch_val = AsyncMock(return_value=1)
|
||||
|
||||
params = SearchParameters(query_text="test")
|
||||
|
||||
results, total = await controller.search_transcripts(params)
|
||||
|
||||
assert total == 1
|
||||
assert len(results) == 1
|
||||
|
||||
result = results[0]
|
||||
assert isinstance(result, SearchResult)
|
||||
assert result.id == "test-transcript-id"
|
||||
assert result.title == "Test Transcript"
|
||||
assert result.rank == 0.95
|
||||
|
||||
|
||||
class TestSearchEndpointParsing:
|
||||
"""Test parameter parsing in the search endpoint."""
|
||||
|
||||
def test_parse_comma_separated_room_ids(self):
|
||||
"""Test parsing comma-separated room IDs."""
|
||||
room_ids_str = "room1,room2,room3"
|
||||
parsed = [rid.strip() for rid in room_ids_str.split(",") if rid.strip()]
|
||||
assert parsed == ["room1", "room2", "room3"]
|
||||
|
||||
room_ids_str = "room1, room2 , room3"
|
||||
parsed = [rid.strip() for rid in room_ids_str.split(",") if rid.strip()]
|
||||
assert parsed == ["room1", "room2", "room3"]
|
||||
|
||||
room_ids_str = "room1,,room3,"
|
||||
parsed = [rid.strip() for rid in room_ids_str.split(",") if rid.strip()]
|
||||
assert parsed == ["room1", "room3"]
|
||||
|
||||
def test_parse_source_kind(self):
|
||||
"""Test parsing source_kind values."""
|
||||
for kind_str in ["live", "file", "room"]:
|
||||
parsed = SourceKind(kind_str)
|
||||
assert parsed == SourceKind(kind_str)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
SourceKind("invalid_kind")
|
||||
|
||||
|
||||
class TestSearchResultModel:
|
||||
"""Test SearchResult model and serialization."""
|
||||
|
||||
def test_search_result_with_available_fields(self):
|
||||
"""Test SearchResult model with currently available fields populated."""
|
||||
result = SearchResult(
|
||||
id="test-id",
|
||||
title="Test Title",
|
||||
user_id="user-123",
|
||||
room_id="room-456",
|
||||
source_kind=SourceKind.ROOM,
|
||||
created_at=datetime(2024, 6, 15, tzinfo=timezone.utc),
|
||||
status="completed",
|
||||
rank=0.85,
|
||||
duration=1800.5,
|
||||
search_snippets=["snippet 1", "snippet 2"],
|
||||
)
|
||||
|
||||
assert result.id == "test-id"
|
||||
assert result.title == "Test Title"
|
||||
assert result.user_id == "user-123"
|
||||
assert result.room_id == "room-456"
|
||||
assert result.status == "completed"
|
||||
assert result.rank == 0.85
|
||||
assert result.duration == 1800.5
|
||||
assert len(result.search_snippets) == 2
|
||||
|
||||
def test_search_result_with_optional_fields_none(self):
|
||||
"""Test SearchResult model with optional fields as None."""
|
||||
result = SearchResult(
|
||||
id="test-id",
|
||||
source_kind=SourceKind.FILE,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
status="processing",
|
||||
rank=0.5,
|
||||
search_snippets=[],
|
||||
title=None,
|
||||
user_id=None,
|
||||
room_id=None,
|
||||
duration=None,
|
||||
)
|
||||
|
||||
assert result.title is None
|
||||
assert result.user_id is None
|
||||
assert result.room_id is None
|
||||
assert result.duration is None
|
||||
|
||||
def test_search_result_datetime_field(self):
|
||||
"""Test that SearchResult accepts datetime field."""
|
||||
result = SearchResult(
|
||||
id="test-id",
|
||||
source_kind=SourceKind.LIVE,
|
||||
created_at=datetime(2024, 6, 15, 12, 30, 45, tzinfo=timezone.utc),
|
||||
status="completed",
|
||||
rank=0.9,
|
||||
duration=None,
|
||||
search_snippets=[],
|
||||
)
|
||||
|
||||
assert result.created_at == datetime(
|
||||
2024, 6, 15, 12, 30, 45, tzinfo=timezone.utc
|
||||
)
|
||||
|
||||
164
server/tests/test_search_long_summary.py
Normal file
164
server/tests/test_search_long_summary.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""Tests for long_summary in search functionality."""
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from reflector.db import get_database
|
||||
from reflector.db.search import SearchParameters, search_controller
|
||||
from reflector.db.transcripts import transcripts
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_long_summary_snippet_prioritization():
|
||||
"""Test that snippets from long_summary are prioritized over webvtt content."""
|
||||
test_id = "test-snippet-priority-3f9a2b8c"
|
||||
|
||||
try:
|
||||
# Clean up any existing test data
|
||||
await get_database().execute(
|
||||
transcripts.delete().where(transcripts.c.id == test_id)
|
||||
)
|
||||
|
||||
test_data = {
|
||||
"id": test_id,
|
||||
"name": "Test Snippet Priority",
|
||||
"title": "Meeting About Projects",
|
||||
"status": "completed",
|
||||
"locked": False,
|
||||
"duration": 1800.0,
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"short_summary": "Project discussion",
|
||||
"long_summary": (
|
||||
"The team discussed advanced robotics applications including "
|
||||
"autonomous navigation systems and sensor fusion techniques. "
|
||||
"Robotics development will focus on real-time processing."
|
||||
),
|
||||
"topics": json.dumps([]),
|
||||
"events": json.dumps([]),
|
||||
"participants": json.dumps([]),
|
||||
"source_language": "en",
|
||||
"target_language": "en",
|
||||
"reviewed": False,
|
||||
"audio_location": "local",
|
||||
"share_mode": "private",
|
||||
"source_kind": "room",
|
||||
"webvtt": """WEBVTT
|
||||
|
||||
00:00:00.000 --> 00:00:10.000
|
||||
We talked about many different topics today.
|
||||
|
||||
00:00:10.000 --> 00:00:20.000
|
||||
The robotics project is making good progress.
|
||||
|
||||
00:00:20.000 --> 00:00:30.000
|
||||
We need to consider various implementation approaches.""",
|
||||
}
|
||||
|
||||
await get_database().execute(transcripts.insert().values(**test_data))
|
||||
|
||||
# Search for "robotics" which appears in both long_summary and webvtt
|
||||
params = SearchParameters(query_text="robotics")
|
||||
results, total = await search_controller.search_transcripts(params)
|
||||
|
||||
assert total >= 1
|
||||
test_result = next((r for r in results if r.id == test_id), None)
|
||||
assert test_result, "Should find the test transcript"
|
||||
|
||||
snippets = test_result.search_snippets
|
||||
assert len(snippets) > 0, "Should have at least one snippet"
|
||||
|
||||
# The first snippets should be from long_summary (more detailed content)
|
||||
first_snippet = snippets[0].lower()
|
||||
assert (
|
||||
"advanced robotics" in first_snippet or "autonomous" in first_snippet
|
||||
), f"First snippet should be from long_summary with detailed content. Got: {snippets[0]}"
|
||||
|
||||
# With max 3 snippets, we should get both from long_summary and webvtt
|
||||
assert len(snippets) <= 3, "Should respect max snippets limit"
|
||||
|
||||
# All snippets should contain the search term
|
||||
for snippet in snippets:
|
||||
assert (
|
||||
"robotics" in snippet.lower()
|
||||
), f"Snippet should contain search term: {snippet}"
|
||||
|
||||
finally:
|
||||
await get_database().execute(
|
||||
transcripts.delete().where(transcripts.c.id == test_id)
|
||||
)
|
||||
await get_database().disconnect()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_long_summary_only_search():
|
||||
"""Test searching for content that only exists in long_summary."""
|
||||
test_id = "test-long-only-8b3c9f2a"
|
||||
|
||||
try:
|
||||
await get_database().execute(
|
||||
transcripts.delete().where(transcripts.c.id == test_id)
|
||||
)
|
||||
|
||||
test_data = {
|
||||
"id": test_id,
|
||||
"name": "Test Long Only",
|
||||
"title": "Standard Meeting",
|
||||
"status": "completed",
|
||||
"locked": False,
|
||||
"duration": 1800.0,
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"short_summary": "Team sync",
|
||||
"long_summary": (
|
||||
"Detailed analysis of cryptocurrency market trends and "
|
||||
"decentralized finance protocols. Discussion included "
|
||||
"yield farming strategies and liquidity pool mechanics."
|
||||
),
|
||||
"topics": json.dumps([]),
|
||||
"events": json.dumps([]),
|
||||
"participants": json.dumps([]),
|
||||
"source_language": "en",
|
||||
"target_language": "en",
|
||||
"reviewed": False,
|
||||
"audio_location": "local",
|
||||
"share_mode": "private",
|
||||
"source_kind": "room",
|
||||
"webvtt": """WEBVTT
|
||||
|
||||
00:00:00.000 --> 00:00:10.000
|
||||
Team meeting about general project updates.
|
||||
|
||||
00:00:10.000 --> 00:00:20.000
|
||||
Discussion of timeline and deliverables.""",
|
||||
}
|
||||
|
||||
await get_database().execute(transcripts.insert().values(**test_data))
|
||||
|
||||
# Search for terms only in long_summary
|
||||
params = SearchParameters(query_text="cryptocurrency")
|
||||
results, total = await search_controller.search_transcripts(params)
|
||||
|
||||
found = any(r.id == test_id for r in results)
|
||||
assert found, "Should find transcript by long_summary-only content"
|
||||
|
||||
test_result = next((r for r in results if r.id == test_id), None)
|
||||
assert test_result
|
||||
assert len(test_result.search_snippets) > 0
|
||||
|
||||
# Verify the snippet is about cryptocurrency
|
||||
snippet = test_result.search_snippets[0].lower()
|
||||
assert "cryptocurrency" in snippet, "Snippet should contain the search term"
|
||||
|
||||
# Search for "yield farming" - a more specific term
|
||||
params2 = SearchParameters(query_text="yield farming")
|
||||
results2, total2 = await search_controller.search_transcripts(params2)
|
||||
|
||||
found2 = any(r.id == test_id for r in results2)
|
||||
assert found2, "Should find transcript by specific long_summary phrase"
|
||||
|
||||
finally:
|
||||
await get_database().execute(
|
||||
transcripts.delete().where(transcripts.c.id == test_id)
|
||||
)
|
||||
await get_database().disconnect()
|
||||
@@ -1,6 +1,10 @@
|
||||
"""Unit tests for search snippet generation."""
|
||||
|
||||
from reflector.db.search import SearchController
|
||||
from reflector.db.search import (
|
||||
SnippetCandidate,
|
||||
SnippetGenerator,
|
||||
WebVTTProcessor,
|
||||
)
|
||||
|
||||
|
||||
class TestExtractWebVTT:
|
||||
@@ -16,7 +20,7 @@ class TestExtractWebVTT:
|
||||
00:00:10.000 --> 00:00:20.000
|
||||
<v Speaker1>Indeed it is a test of WebVTT parsing.
|
||||
"""
|
||||
result = SearchController._extract_webvtt_text(webvtt)
|
||||
result = WebVTTProcessor.extract_text(webvtt)
|
||||
assert "Hello world, this is a test" in result
|
||||
assert "Indeed it is a test" in result
|
||||
assert "<v Speaker" not in result
|
||||
@@ -25,12 +29,11 @@ class TestExtractWebVTT:
|
||||
|
||||
def test_extract_empty_webvtt(self):
|
||||
"""Test empty WebVTT returns empty string."""
|
||||
assert SearchController._extract_webvtt_text("") == ""
|
||||
assert SearchController._extract_webvtt_text(None) == ""
|
||||
assert WebVTTProcessor.extract_text("") == ""
|
||||
|
||||
def test_extract_malformed_webvtt(self):
|
||||
"""Test malformed WebVTT returns empty string."""
|
||||
result = SearchController._extract_webvtt_text("Not a valid WebVTT")
|
||||
result = WebVTTProcessor.extract_text("Not a valid WebVTT")
|
||||
assert result == ""
|
||||
|
||||
|
||||
@@ -39,8 +42,7 @@ class TestGenerateSnippets:
|
||||
|
||||
def test_multiple_matches(self):
|
||||
"""Test finding multiple occurrences of search term in long text."""
|
||||
# Create text with Python mentions far apart to get separate snippets
|
||||
separator = " This is filler text. " * 20 # ~400 chars of padding
|
||||
separator = " This is filler text. " * 20
|
||||
text = (
|
||||
"Python is great for machine learning."
|
||||
+ separator
|
||||
@@ -51,18 +53,16 @@ class TestGenerateSnippets:
|
||||
+ "The Python community is very supportive."
|
||||
)
|
||||
|
||||
snippets = SearchController._generate_snippets(text, "Python")
|
||||
# With enough separation, we should get multiple snippets
|
||||
assert len(snippets) >= 2 # At least 2 distinct snippets
|
||||
snippets = SnippetGenerator.generate(text, "Python")
|
||||
assert len(snippets) >= 2
|
||||
|
||||
# Each snippet should contain "Python"
|
||||
for snippet in snippets:
|
||||
assert "python" in snippet.lower()
|
||||
|
||||
def test_single_match(self):
|
||||
"""Test single occurrence returns one snippet."""
|
||||
text = "This document discusses artificial intelligence and its applications."
|
||||
snippets = SearchController._generate_snippets(text, "artificial intelligence")
|
||||
snippets = SnippetGenerator.generate(text, "artificial intelligence")
|
||||
|
||||
assert len(snippets) == 1
|
||||
assert "artificial intelligence" in snippets[0].lower()
|
||||
@@ -70,24 +70,22 @@ class TestGenerateSnippets:
|
||||
def test_no_matches(self):
|
||||
"""Test no matches returns empty list."""
|
||||
text = "This is some random text without the search term."
|
||||
snippets = SearchController._generate_snippets(text, "machine learning")
|
||||
snippets = SnippetGenerator.generate(text, "machine learning")
|
||||
|
||||
assert snippets == []
|
||||
|
||||
def test_case_insensitive_search(self):
|
||||
"""Test search is case insensitive."""
|
||||
# Add enough text between matches to get separate snippets
|
||||
text = (
|
||||
"MACHINE LEARNING is important for modern applications. "
|
||||
+ "It requires lots of data and computational resources. " * 5 # Padding
|
||||
+ "It requires lots of data and computational resources. " * 5
|
||||
+ "Machine Learning rocks and transforms industries. "
|
||||
+ "Deep learning is a subset of it. " * 5 # More padding
|
||||
+ "Deep learning is a subset of it. " * 5
|
||||
+ "Finally, machine learning will shape our future."
|
||||
)
|
||||
|
||||
snippets = SearchController._generate_snippets(text, "machine learning")
|
||||
snippets = SnippetGenerator.generate(text, "machine learning")
|
||||
|
||||
# Should find at least 2 (might be 3 if text is long enough)
|
||||
assert len(snippets) >= 2
|
||||
for snippet in snippets:
|
||||
assert "machine learning" in snippet.lower()
|
||||
@@ -95,61 +93,55 @@ class TestGenerateSnippets:
|
||||
def test_partial_match_fallback(self):
|
||||
"""Test fallback to first word when exact phrase not found."""
|
||||
text = "We use machine intelligence for processing."
|
||||
snippets = SearchController._generate_snippets(text, "machine learning")
|
||||
snippets = SnippetGenerator.generate(text, "machine learning")
|
||||
|
||||
# Should fall back to finding "machine"
|
||||
assert len(snippets) == 1
|
||||
assert "machine" in snippets[0].lower()
|
||||
|
||||
def test_snippet_ellipsis(self):
|
||||
"""Test ellipsis added for truncated snippets."""
|
||||
# Long text where match is in the middle
|
||||
text = "a " * 100 + "TARGET_WORD special content here" + " b" * 100
|
||||
snippets = SearchController._generate_snippets(text, "TARGET_WORD")
|
||||
snippets = SnippetGenerator.generate(text, "TARGET_WORD")
|
||||
|
||||
assert len(snippets) == 1
|
||||
assert "..." in snippets[0] # Should have ellipsis
|
||||
assert "..." in snippets[0]
|
||||
assert "TARGET_WORD" in snippets[0]
|
||||
|
||||
def test_overlapping_snippets_deduplicated(self):
|
||||
"""Test overlapping matches don't create duplicate snippets."""
|
||||
text = "test test test word" * 10 # Repeated pattern
|
||||
snippets = SearchController._generate_snippets(text, "test")
|
||||
text = "test test test word" * 10
|
||||
snippets = SnippetGenerator.generate(text, "test")
|
||||
|
||||
# Should get unique snippets, not duplicates
|
||||
assert len(snippets) <= 3
|
||||
assert len(snippets) == len(set(snippets)) # All unique
|
||||
assert len(snippets) == len(set(snippets))
|
||||
|
||||
def test_empty_inputs(self):
|
||||
"""Test empty text or search term returns empty list."""
|
||||
assert SearchController._generate_snippets("", "search") == []
|
||||
assert SearchController._generate_snippets("text", "") == []
|
||||
assert SearchController._generate_snippets("", "") == []
|
||||
assert SnippetGenerator.generate("", "search") == []
|
||||
assert SnippetGenerator.generate("text", "") == []
|
||||
assert SnippetGenerator.generate("", "") == []
|
||||
|
||||
def test_max_snippets_limit(self):
|
||||
"""Test respects max_snippets parameter."""
|
||||
# Create text with well-separated occurrences
|
||||
separator = " filler " * 50 # Ensure snippets don't overlap
|
||||
text = ("Python is amazing" + separator) * 10 # 10 occurrences
|
||||
separator = " filler " * 50
|
||||
text = ("Python is amazing" + separator) * 10
|
||||
|
||||
# Test with different limits
|
||||
snippets_1 = SearchController._generate_snippets(text, "Python", max_snippets=1)
|
||||
snippets_1 = SnippetGenerator.generate(text, "Python", max_snippets=1)
|
||||
assert len(snippets_1) == 1
|
||||
|
||||
snippets_2 = SearchController._generate_snippets(text, "Python", max_snippets=2)
|
||||
snippets_2 = SnippetGenerator.generate(text, "Python", max_snippets=2)
|
||||
assert len(snippets_2) == 2
|
||||
|
||||
snippets_5 = SearchController._generate_snippets(text, "Python", max_snippets=5)
|
||||
assert len(snippets_5) == 5 # Should get exactly 5 with enough separation
|
||||
snippets_5 = SnippetGenerator.generate(text, "Python", max_snippets=5)
|
||||
assert len(snippets_5) == 5
|
||||
|
||||
def test_snippet_length(self):
|
||||
"""Test snippet length is reasonable."""
|
||||
text = "word " * 200 # Long text
|
||||
snippets = SearchController._generate_snippets(text, "word")
|
||||
text = "word " * 200
|
||||
snippets = SnippetGenerator.generate(text, "word")
|
||||
|
||||
for snippet in snippets:
|
||||
# Default max_length is 150 + some context
|
||||
assert len(snippet) <= 200 # Some buffer for ellipsis
|
||||
assert len(snippet) <= 200
|
||||
|
||||
|
||||
class TestFullPipeline:
|
||||
@@ -157,7 +149,6 @@ class TestFullPipeline:
|
||||
|
||||
def test_webvtt_to_snippets_integration(self):
|
||||
"""Test full pipeline from WebVTT to search snippets."""
|
||||
# Create WebVTT with well-separated content for multiple snippets
|
||||
webvtt = (
|
||||
"""WEBVTT
|
||||
|
||||
@@ -182,17 +173,362 @@ class TestFullPipeline:
|
||||
"""
|
||||
)
|
||||
|
||||
# Extract and generate snippets
|
||||
plain_text = SearchController._extract_webvtt_text(webvtt)
|
||||
snippets = SearchController._generate_snippets(plain_text, "machine learning")
|
||||
plain_text = WebVTTProcessor.extract_text(webvtt)
|
||||
snippets = SnippetGenerator.generate(plain_text, "machine learning")
|
||||
|
||||
# Should find at least 2 snippets (text might still be close together)
|
||||
assert len(snippets) >= 1 # At minimum one snippet containing matches
|
||||
assert len(snippets) <= 3 # At most 3 by default
|
||||
assert len(snippets) >= 1
|
||||
assert len(snippets) <= 3
|
||||
|
||||
# No WebVTT artifacts in snippets
|
||||
for snippet in snippets:
|
||||
assert "machine learning" in snippet.lower()
|
||||
assert "<v Speaker" not in snippet
|
||||
assert "00:00" not in snippet
|
||||
assert "-->" not in snippet
|
||||
|
||||
|
||||
class TestMultiWordQueryBehavior:
|
||||
"""Tests for multi-word query behavior and exact phrase matching."""
|
||||
|
||||
def test_multi_word_query_snippet_behavior(self):
|
||||
"""Test that multi-word queries generate snippets based on exact phrase matching."""
|
||||
sample_text = """This is a sample transcript where user Alice is talking.
|
||||
Later in the conversation, jordan mentions something important.
|
||||
The user jordan collaboration was successful.
|
||||
Another user named Bob joins the discussion."""
|
||||
|
||||
user_snippets = SnippetGenerator.generate(sample_text, "user")
|
||||
assert len(user_snippets) == 2, "Should find 2 snippets for 'user'"
|
||||
|
||||
jordan_snippets = SnippetGenerator.generate(sample_text, "jordan")
|
||||
assert len(jordan_snippets) >= 1, "Should find at least 1 snippet for 'jordan'"
|
||||
|
||||
multi_word_snippets = SnippetGenerator.generate(sample_text, "user jordan")
|
||||
assert len(multi_word_snippets) == 1, (
|
||||
"Should return exactly 1 snippet for 'user jordan' "
|
||||
"(only the exact phrase match, not individual word occurrences)"
|
||||
)
|
||||
|
||||
snippet = multi_word_snippets[0]
|
||||
assert (
|
||||
"user jordan" in snippet.lower()
|
||||
), "The snippet should contain the exact phrase 'user jordan'"
|
||||
|
||||
assert (
|
||||
"alice" not in snippet.lower()
|
||||
), "The snippet should not include the first standalone 'user' with Alice"
|
||||
|
||||
def test_multi_word_query_without_exact_match(self):
|
||||
"""Test snippet generation when exact phrase is not found."""
|
||||
sample_text = """User Alice is here. Bob and jordan are talking.
|
||||
Later jordan mentions something. The user is happy."""
|
||||
|
||||
snippets = SnippetGenerator.generate(sample_text, "user jordan")
|
||||
|
||||
assert (
|
||||
len(snippets) >= 1
|
||||
), "Should find at least 1 snippet when falling back to first word"
|
||||
|
||||
all_snippets_text = " ".join(snippets).lower()
|
||||
assert (
|
||||
"user" in all_snippets_text
|
||||
), "Snippets should contain 'user' (the first word)"
|
||||
|
||||
def test_exact_phrase_at_text_boundaries(self):
|
||||
"""Test snippet generation when exact phrase appears at text boundaries."""
|
||||
|
||||
text_start = "user jordan started the meeting. Other content here."
|
||||
snippets = SnippetGenerator.generate(text_start, "user jordan")
|
||||
assert len(snippets) == 1
|
||||
assert "user jordan" in snippets[0].lower()
|
||||
|
||||
text_end = "Other content here. The meeting ended with user jordan"
|
||||
snippets = SnippetGenerator.generate(text_end, "user jordan")
|
||||
assert len(snippets) == 1
|
||||
assert "user jordan" in snippets[0].lower()
|
||||
|
||||
def test_multi_word_query_matches_words_appearing_separately_and_together(self):
|
||||
"""Test that multi-word queries prioritize exact phrase matches over individual word occurrences."""
|
||||
sample_text = """This is a sample transcript where user Alice is talking.
|
||||
Later in the conversation, jordan mentions something important.
|
||||
The user jordan collaboration was successful.
|
||||
Another user named Bob joins the discussion."""
|
||||
|
||||
search_query = "user jordan"
|
||||
snippets = SnippetGenerator.generate(sample_text, search_query)
|
||||
|
||||
assert len(snippets) == 1, (
|
||||
f"Expected exactly 1 snippet for '{search_query}' when exact phrase exists, "
|
||||
f"got {len(snippets)}. Should ignore individual word occurrences."
|
||||
)
|
||||
|
||||
snippet = snippets[0]
|
||||
|
||||
assert (
|
||||
search_query in snippet.lower()
|
||||
), f"Snippet should contain the exact phrase '{search_query}'. Got: {snippet}"
|
||||
|
||||
assert (
|
||||
"jordan mentions" in snippet.lower()
|
||||
), f"Snippet should include context before the exact phrase match. Got: {snippet}"
|
||||
|
||||
assert (
|
||||
"alice" not in snippet.lower()
|
||||
), f"Snippet should not include separate occurrences of individual words. Got: {snippet}"
|
||||
|
||||
text_2 = """The alpha version was released.
|
||||
Beta testing started yesterday.
|
||||
The alpha beta integration is complete."""
|
||||
|
||||
snippets_2 = SnippetGenerator.generate(text_2, "alpha beta")
|
||||
assert len(snippets_2) == 1, "Should return 1 snippet for exact phrase match"
|
||||
assert "alpha beta" in snippets_2[0].lower(), "Should contain exact phrase"
|
||||
assert (
|
||||
"version" not in snippets_2[0].lower()
|
||||
), "Should not include first separate occurrence"
|
||||
|
||||
|
||||
class TestSnippetGenerationEnhanced:
|
||||
"""Additional snippet generation tests from test_search_enhancements.py."""
|
||||
|
||||
def test_snippet_generation_from_webvtt(self):
|
||||
"""Test snippet generation from WebVTT content."""
|
||||
webvtt_content = """WEBVTT
|
||||
|
||||
00:00:00.000 --> 00:00:05.000
|
||||
This is the beginning of the transcript
|
||||
|
||||
00:00:05.000 --> 00:00:10.000
|
||||
The search term appears here in the middle
|
||||
|
||||
00:00:10.000 --> 00:00:15.000
|
||||
And this is the end of the content"""
|
||||
|
||||
plain_text = WebVTTProcessor.extract_text(webvtt_content)
|
||||
snippets = SnippetGenerator.generate(plain_text, "search term")
|
||||
|
||||
assert len(snippets) > 0
|
||||
assert any("search term" in snippet.lower() for snippet in snippets)
|
||||
|
||||
def test_extract_webvtt_text_with_malformed_variations(self):
|
||||
"""Test WebVTT extraction with various malformed content."""
|
||||
malformed_vtt = "This is not valid WebVTT content"
|
||||
result = WebVTTProcessor.extract_text(malformed_vtt)
|
||||
assert result == ""
|
||||
|
||||
partial_vtt = "WEBVTT\nNo timestamps here"
|
||||
result = WebVTTProcessor.extract_text(partial_vtt)
|
||||
assert result == "" or "No timestamps" not in result
|
||||
|
||||
|
||||
class TestPureFunctions:
|
||||
"""Test the pure functions extracted for functional programming."""
|
||||
|
||||
def test_find_all_matches(self):
|
||||
"""Test finding all match positions in text."""
|
||||
text = "Python is great. Python is powerful. I love Python."
|
||||
matches = list(SnippetGenerator.find_all_matches(text, "Python"))
|
||||
assert matches == [0, 17, 44]
|
||||
|
||||
matches = list(SnippetGenerator.find_all_matches(text, "python"))
|
||||
assert matches == [0, 17, 44]
|
||||
|
||||
matches = list(SnippetGenerator.find_all_matches(text, "Ruby"))
|
||||
assert matches == []
|
||||
|
||||
matches = list(SnippetGenerator.find_all_matches("", "test"))
|
||||
assert matches == []
|
||||
matches = list(SnippetGenerator.find_all_matches("test", ""))
|
||||
assert matches == []
|
||||
|
||||
def test_create_snippet(self):
|
||||
"""Test creating a snippet from a match position."""
|
||||
text = "This is a long text with the word Python in the middle and more text after."
|
||||
|
||||
snippet = SnippetGenerator.create_snippet(text, 35, max_length=150)
|
||||
assert "Python" in snippet.text()
|
||||
assert snippet.start >= 0
|
||||
assert snippet.end <= len(text)
|
||||
assert isinstance(snippet, SnippetCandidate)
|
||||
|
||||
assert len(snippet.text()) > 0
|
||||
assert snippet.start <= snippet.end
|
||||
|
||||
long_text = "A" * 200
|
||||
snippet = SnippetGenerator.create_snippet(long_text, 100, max_length=50)
|
||||
assert snippet.text().startswith("...")
|
||||
assert snippet.text().endswith("...")
|
||||
|
||||
snippet = SnippetGenerator.create_snippet("short text", 0, max_length=100)
|
||||
assert snippet.start == 0
|
||||
assert "short text" in snippet.text()
|
||||
|
||||
def test_filter_non_overlapping(self):
|
||||
"""Test filtering overlapping snippets."""
|
||||
candidates = [
|
||||
SnippetCandidate(_text="First snippet", start=0, _original_text_length=100),
|
||||
SnippetCandidate(_text="Overlapping", start=10, _original_text_length=100),
|
||||
SnippetCandidate(
|
||||
_text="Third snippet", start=40, _original_text_length=100
|
||||
),
|
||||
SnippetCandidate(
|
||||
_text="Fourth snippet", start=65, _original_text_length=100
|
||||
),
|
||||
]
|
||||
|
||||
filtered = list(SnippetGenerator.filter_non_overlapping(iter(candidates)))
|
||||
assert filtered == [
|
||||
"First snippet...",
|
||||
"...Third snippet...",
|
||||
"...Fourth snippet...",
|
||||
]
|
||||
|
||||
filtered = list(SnippetGenerator.filter_non_overlapping(iter([])))
|
||||
assert filtered == []
|
||||
|
||||
def test_generate_integration(self):
|
||||
"""Test the main SnippetGenerator.generate function."""
|
||||
text = "Machine learning is amazing. Machine learning transforms data. Learn machine learning today."
|
||||
|
||||
snippets = SnippetGenerator.generate(text, "machine learning")
|
||||
assert len(snippets) <= 3
|
||||
assert all("machine learning" in s.lower() for s in snippets)
|
||||
|
||||
snippets = SnippetGenerator.generate(text, "machine learning", max_snippets=2)
|
||||
assert len(snippets) <= 2
|
||||
|
||||
snippets = SnippetGenerator.generate(text, "machine vision")
|
||||
assert len(snippets) > 0
|
||||
assert any("machine" in s.lower() for s in snippets)
|
||||
|
||||
def test_extract_webvtt_text_basic(self):
|
||||
"""Test WebVTT text extraction (basic test, full tests exist elsewhere)."""
|
||||
webvtt = """WEBVTT
|
||||
|
||||
00:00:00.000 --> 00:00:02.000
|
||||
Hello world
|
||||
|
||||
00:00:02.000 --> 00:00:04.000
|
||||
This is a test"""
|
||||
|
||||
result = WebVTTProcessor.extract_text(webvtt)
|
||||
assert "Hello world" in result
|
||||
assert "This is a test" in result
|
||||
|
||||
# Test empty input
|
||||
assert WebVTTProcessor.extract_text("") == ""
|
||||
assert WebVTTProcessor.extract_text(None) == ""
|
||||
|
||||
def test_generate_webvtt_snippets(self):
|
||||
"""Test generating snippets from WebVTT content."""
|
||||
webvtt = """WEBVTT
|
||||
|
||||
00:00:00.000 --> 00:00:02.000
|
||||
Python programming is great
|
||||
|
||||
00:00:02.000 --> 00:00:04.000
|
||||
Learn Python today"""
|
||||
|
||||
snippets = WebVTTProcessor.generate_snippets(webvtt, "Python")
|
||||
assert len(snippets) > 0
|
||||
assert any("Python" in s for s in snippets)
|
||||
|
||||
snippets = WebVTTProcessor.generate_snippets("", "Python")
|
||||
assert snippets == []
|
||||
|
||||
def test_from_summary(self):
|
||||
"""Test generating snippets from summary text."""
|
||||
summary = "This meeting discussed Python development and machine learning applications."
|
||||
|
||||
snippets = SnippetGenerator.from_summary(summary, "Python")
|
||||
assert len(snippets) > 0
|
||||
assert any("Python" in s for s in snippets)
|
||||
|
||||
long_summary = "Python " * 20
|
||||
snippets = SnippetGenerator.from_summary(long_summary, "Python")
|
||||
assert len(snippets) <= 2
|
||||
|
||||
def test_combine_sources(self):
|
||||
"""Test combining snippets from multiple sources."""
|
||||
summary = "Python is a great programming language."
|
||||
webvtt = """WEBVTT
|
||||
|
||||
00:00:00.000 --> 00:00:02.000
|
||||
Learn Python programming
|
||||
|
||||
00:00:02.000 --> 00:00:04.000
|
||||
Python is powerful"""
|
||||
|
||||
snippets, total_count = SnippetGenerator.combine_sources(
|
||||
summary, webvtt, "Python", max_total=3
|
||||
)
|
||||
assert len(snippets) <= 3
|
||||
assert len(snippets) > 0
|
||||
assert total_count > 0
|
||||
|
||||
snippets, total_count = SnippetGenerator.combine_sources(
|
||||
summary, None, "Python", max_total=3
|
||||
)
|
||||
assert len(snippets) > 0
|
||||
assert all("Python" in s for s in snippets)
|
||||
assert total_count == 1
|
||||
|
||||
snippets, total_count = SnippetGenerator.combine_sources(
|
||||
None, webvtt, "Python", max_total=3
|
||||
)
|
||||
assert len(snippets) > 0
|
||||
assert total_count == 2
|
||||
|
||||
long_summary = "Python " * 10
|
||||
snippets, total_count = SnippetGenerator.combine_sources(
|
||||
long_summary, webvtt, "Python", max_total=2
|
||||
)
|
||||
assert len(snippets) == 2
|
||||
assert total_count >= 10
|
||||
|
||||
def test_match_counting_sum_logic(self):
|
||||
"""Test that match counting correctly sums matches from both sources."""
|
||||
summary = "data science uses data analysis and data mining techniques"
|
||||
webvtt = """WEBVTT
|
||||
|
||||
00:00:00.000 --> 00:00:02.000
|
||||
Big data processing
|
||||
|
||||
00:00:02.000 --> 00:00:04.000
|
||||
data visualization and data storage"""
|
||||
|
||||
snippets, total_count = SnippetGenerator.combine_sources(
|
||||
summary, webvtt, "data", max_total=3
|
||||
)
|
||||
assert total_count == 6
|
||||
assert len(snippets) <= 3
|
||||
|
||||
summary_snippets, summary_count = SnippetGenerator.combine_sources(
|
||||
summary, None, "data", max_total=3
|
||||
)
|
||||
assert summary_count == 3
|
||||
|
||||
webvtt_snippets, webvtt_count = SnippetGenerator.combine_sources(
|
||||
None, webvtt, "data", max_total=3
|
||||
)
|
||||
assert webvtt_count == 3
|
||||
|
||||
snippets_empty, count_empty = SnippetGenerator.combine_sources(
|
||||
None, None, "data", max_total=3
|
||||
)
|
||||
assert snippets_empty == []
|
||||
assert count_empty == 0
|
||||
|
||||
def test_edge_cases(self):
|
||||
"""Test edge cases for the pure functions."""
|
||||
text = "Test with special: @#$%^&*() characters"
|
||||
snippets = SnippetGenerator.generate(text, "@#$%")
|
||||
assert len(snippets) > 0
|
||||
|
||||
long_query = "a" * 100
|
||||
snippets = SnippetGenerator.generate("Some text", long_query)
|
||||
assert snippets == []
|
||||
|
||||
text = "Unicode test: café, naïve, 日本語"
|
||||
snippets = SnippetGenerator.generate(text, "café")
|
||||
assert len(snippets) > 0
|
||||
assert "café" in snippets[0]
|
||||
|
||||
Reference in New Issue
Block a user