mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
feat: search backend (#537)
* docs: transient docs * chore: cleanup * webvtt WIP * webvtt field * chore: webvtt tests comments * chore: remove useless tests * feat: search TASK.md * feat: full text search by title/webvtt * chore: search api task * feat: search api * feat: search API * chore: rm task md * chore: roll back unnecessary validators * chore: pr review WIP * chore: pr review WIP * chore: pr review * chore: top imports * feat: better lint + ci * feat: better lint + ci * feat: better lint + ci * feat: better lint + ci * chore: lint * chore: lint * fix: db datetime definitions * fix: flush() params * fix: update transcript mutability expectation / test * fix: update transcript mutability expectation / test * chore: auto review * chore: new controller extraction * chore: new controller extraction * chore: cleanup * chore: review WIP * chore: pr WIP * chore: remove ci lint * chore: openapi regeneration * chore: openapi regeneration * chore: postgres test doc * fix: .dockerignore for arm binaries * fix: .dockerignore for arm binaries * fix: cap test loops * fix: cap test loops * fix: cap test loops * fix: get_transcript_topics * chore: remove flow.md docs and claude guidance * chore: remove claude.md db doc * chore: remove claude.md db doc * chore: remove claude.md db doc * chore: remove claude.md db doc
This commit is contained in:
163
server/tests/test_search.py
Normal file
163
server/tests/test_search.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""Tests for full-text search functionality."""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from reflector.db import database
|
||||
from reflector.db.search import SearchParameters, search_controller
|
||||
from reflector.db.transcripts import transcripts
|
||||
from reflector.db.utils import is_postgresql
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_search_postgresql_only():
    """Run one search through the controller and check query validation.

    The search for an arbitrary phrase is expected to come back as an empty
    result list with a total of zero. Empty and whitespace-only query strings
    must be rejected by ``SearchParameters`` with a ``ValidationError``.

    NOTE(review): the test name suggests the empty result is the behaviour of
    a non-PostgreSQL backend -- confirm against ``search_transcripts``.
    """
    await database.connect()

    try:
        params = SearchParameters(query_text="any query here")
        results, total = await search_controller.search_transcripts(params)
        assert results == []
        assert total == 0

        # pytest.raises fails the test on its own when nothing is raised, so
        # no ``assert False`` sentinel (stripped under ``python -O``) is needed.
        with pytest.raises(ValidationError):
            SearchParameters(query_text="")

        # Test that whitespace query raises validation error
        with pytest.raises(ValidationError):
            SearchParameters(query_text=" ")
    finally:
        # Always release the connection, even when an assertion above fails.
        await database.disconnect()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_search_input_validation():
    """Check that ``SearchParameters`` rejects empty and blank query text.

    Both the empty string and a string made only of spaces, a tab and a
    newline must raise ``ValidationError`` when the model is constructed.
    """
    await database.connect()

    try:
        # pytest.raises reports "DID NOT RAISE" by itself; the former
        # ``assert False`` sentinel inside a try/except is not required.
        with pytest.raises(ValidationError):
            SearchParameters(query_text="")

        # Test that whitespace query raises validation error
        with pytest.raises(ValidationError):
            SearchParameters(query_text=" \t\n ")
    finally:
        await database.disconnect()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_postgresql_search_with_data():
    """Test full-text search with actual data in PostgreSQL.

    Inserts one transcript row with a fixed id, runs several queries against
    it (title word, webvtt word, multiple words, ``OR`` operator, quoted
    phrase) and checks the fields of the returned search result. The row is
    deleted again and the connection closed in all cases.

    Example how to run: DATABASE_URL=postgresql://reflector:reflector@localhost:5432/reflector_test uv run pytest tests/test_search.py::test_postgresql_search_with_data -v -p no:env
    """
    # Skip if not PostgreSQL
    if not is_postgresql():
        pytest.skip("Test requires PostgreSQL. Set DATABASE_URL=postgresql://...")

    await database.connect()

    # collision is improbable
    test_id = "test-search-e2e-7f3a9b2c"

    try:
        # Remove leftovers from an earlier, interrupted run before inserting.
        await database.execute(transcripts.delete().where(transcripts.c.id == test_id))

        test_data = {
            "id": test_id,
            "name": "Test Search Transcript",
            "title": "Engineering Planning Meeting Q4 2024",
            "status": "completed",
            "locked": False,
            "duration": 1800.0,
            "created_at": datetime.now(),
            "short_summary": "Team discussed search implementation",
            "long_summary": "The engineering team met to plan the search feature",
            "topics": json.dumps([]),
            "events": json.dumps([]),
            "participants": json.dumps([]),
            "source_language": "en",
            "target_language": "en",
            "reviewed": False,
            "audio_location": "local",
            "share_mode": "private",
            "source_kind": "room",
            "webvtt": """WEBVTT

00:00:00.000 --> 00:00:10.000
Welcome to our engineering planning meeting for Q4 2024.

00:00:10.000 --> 00:00:20.000
Today we'll discuss the implementation of full-text search.

00:00:20.000 --> 00:00:30.000
The search feature should support complex queries with ranking.

00:00:30.000 --> 00:00:40.000
We need to implement PostgreSQL tsvector for better performance.""",
        }

        await database.execute(transcripts.insert().values(**test_data))

        # Test 1: Search for a word in title
        params = SearchParameters(query_text="planning")
        results, total = await search_controller.search_transcripts(params)
        assert total >= 1
        found = any(r.id == test_id for r in results)
        assert found, "Should find test transcript by title word"

        # Test 2: Search for a word in webvtt content
        params = SearchParameters(query_text="tsvector")
        results, total = await search_controller.search_transcripts(params)
        assert total >= 1
        found = any(r.id == test_id for r in results)
        assert found, "Should find test transcript by webvtt content"

        # Test 3: Search with multiple words
        params = SearchParameters(query_text="engineering planning")
        results, total = await search_controller.search_transcripts(params)
        assert total >= 1
        found = any(r.id == test_id for r in results)
        assert found, "Should find test transcript by multiple words"

        # Test 4: Verify SearchResult structure
        test_result = next((r for r in results if r.id == test_id), None)
        # Assert instead of guarding with ``if`` so the structure checks below
        # can never be skipped silently.
        assert test_result is not None, "Test transcript missing from results"
        assert test_result.title == "Engineering Planning Meeting Q4 2024"
        assert test_result.status == "completed"
        assert test_result.duration == 1800.0
        assert test_result.source_kind == "room"
        assert 0 <= test_result.rank <= 1, "Rank should be normalized to 0-1"

        # Test 5: Search with OR operator
        params = SearchParameters(query_text="tsvector OR nosuchword")
        results, total = await search_controller.search_transcripts(params)
        assert total >= 1
        found = any(r.id == test_id for r in results)
        assert found, "Should find test transcript with OR query"

        # Test 6: Quoted phrase search
        params = SearchParameters(query_text='"full-text search"')
        results, total = await search_controller.search_transcripts(params)
        assert total >= 1
        found = any(r.id == test_id for r in results)
        assert found, "Should find test transcript by exact phrase"

    finally:
        # Nested try/finally: a failing cleanup delete must not prevent the
        # connection from being closed.
        try:
            await database.execute(
                transcripts.delete().where(transcripts.c.id == test_id)
            )
        finally:
            await database.disconnect()
|
||||
198
server/tests/test_search_snippets.py
Normal file
198
server/tests/test_search_snippets.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""Unit tests for search snippet generation."""
|
||||
|
||||
from reflector.db.search import SearchController
|
||||
|
||||
|
||||
class TestExtractWebVTT:
    """Checks for ``SearchController._extract_webvtt_text``."""

    def test_extract_webvtt_with_speakers(self):
        """Cue text is kept; voice tags and timing lines are removed."""
        source = """WEBVTT

00:00:00.000 --> 00:00:10.000
<v Speaker0>Hello world, this is a test.

00:00:10.000 --> 00:00:20.000
<v Speaker1>Indeed it is a test of WebVTT parsing.
"""
        extracted = SearchController._extract_webvtt_text(source)

        # Spoken text from both cues must be present.
        for spoken in ("Hello world, this is a test", "Indeed it is a test"):
            assert spoken in extracted

        # No markup or timing residue may leak into the plain text.
        for artifact in ("<v Speaker", "00:00", "-->"):
            assert artifact not in extracted

    def test_extract_empty_webvtt(self):
        """An empty string and ``None`` both yield an empty string."""
        for blank in ("", None):
            assert SearchController._extract_webvtt_text(blank) == ""

    def test_extract_malformed_webvtt(self):
        """Input that is not WebVTT yields an empty string."""
        assert SearchController._extract_webvtt_text("Not a valid WebVTT") == ""
|
||||
|
||||
|
||||
class TestGenerateSnippets:
    """Checks for ``SearchController._generate_snippets`` on plain text."""

    def test_multiple_matches(self):
        """Widely separated occurrences produce more than one snippet."""
        # Padding between the sentences keeps the matches far apart.
        filler = " This is filler text. " * 20
        sentences = (
            "Python is great for machine learning.",
            "Many companies use Python for data science.",
            "Python has excellent libraries for analysis.",
            "The Python community is very supportive.",
        )
        haystack = filler.join(sentences)

        found = SearchController._generate_snippets(haystack, "Python")

        assert len(found) >= 2
        # Every snippet has to contain the search term.
        assert all("python" in piece.lower() for piece in found)

    def test_single_match(self):
        """One occurrence gives exactly one snippet containing the phrase."""
        found = SearchController._generate_snippets(
            "This document discusses artificial intelligence and its applications.",
            "artificial intelligence",
        )

        assert len(found) == 1
        assert "artificial intelligence" in found[0].lower()

    def test_no_matches(self):
        """A term that does not occur gives an empty list."""
        found = SearchController._generate_snippets(
            "This is some random text without the search term.",
            "machine learning",
        )

        assert found == []

    def test_case_insensitive_search(self):
        """Matches are found regardless of letter case."""
        # Repeated sentences act as padding between the three matches.
        haystack = "".join(
            [
                "MACHINE LEARNING is important for modern applications. ",
                "It requires lots of data and computational resources. " * 5,
                "Machine Learning rocks and transforms industries. ",
                "Deep learning is a subset of it. " * 5,
                "Finally, machine learning will shape our future.",
            ]
        )

        found = SearchController._generate_snippets(haystack, "machine learning")

        # Two at minimum; a third depends on how far apart the matches end up.
        assert len(found) >= 2
        assert all("machine learning" in piece.lower() for piece in found)

    def test_partial_match_fallback(self):
        """When the phrase is absent, the first word is searched instead."""
        found = SearchController._generate_snippets(
            "We use machine intelligence for processing.", "machine learning"
        )

        # Only "machine" occurs in the text.
        assert len(found) == 1
        assert "machine" in found[0].lower()

    def test_snippet_ellipsis(self):
        """A snippet cut out of the middle of long text carries an ellipsis."""
        prefix = "a " * 100
        suffix = " b" * 100
        haystack = prefix + "TARGET_WORD special content here" + suffix

        found = SearchController._generate_snippets(haystack, "TARGET_WORD")

        assert len(found) == 1
        only = found[0]
        assert "..." in only
        assert "TARGET_WORD" in only

    def test_overlapping_snippets_deduplicated(self):
        """Densely repeated matches do not yield duplicate snippets."""
        haystack = "test test test word" * 10

        found = SearchController._generate_snippets(haystack, "test")

        assert len(found) <= 3
        # A set drops duplicates, so equal sizes mean every snippet is unique.
        assert len(set(found)) == len(found)

    def test_empty_inputs(self):
        """Empty text and/or empty search term give an empty list."""
        for text, term in (("", "search"), ("text", ""), ("", "")):
            assert SearchController._generate_snippets(text, term) == []

    def test_max_snippets_limit(self):
        """The ``max_snippets`` argument caps the number of snippets."""
        # Ten occurrences, each followed by enough filler to avoid overlap.
        gap = " filler " * 50
        haystack = ("Python is amazing" + gap) * 10

        for limit in (1, 2, 5):
            capped = SearchController._generate_snippets(
                haystack, "Python", max_snippets=limit
            )
            assert len(capped) == limit

    def test_snippet_length(self):
        """No snippet exceeds 200 characters with default settings."""
        haystack = "word " * 200

        found = SearchController._generate_snippets(haystack, "word")

        # 200 leaves some room on top of the default length for ellipses.
        assert all(len(piece) <= 200 for piece in found)
|
||||
|
||||
|
||||
class TestFullPipeline:
    """End-to-end check: WebVTT text extraction followed by snippet generation."""

    def test_webvtt_to_snippets_integration(self):
        """Snippets built from extracted WebVTT text contain no WebVTT markup."""
        # Two long filler cues separate the three "machine learning" cues.
        industry_filler = "Various industries are adopting new technologies. " * 10
        finance_filler = "Financial markets show interesting patterns. " * 10
        source = (
            """WEBVTT

00:00:00.000 --> 00:00:10.000
<v Speaker0>Let's discuss machine learning applications in modern technology.

00:00:10.000 --> 00:00:20.000
<v Speaker1>"""
            + industry_filler
            + """

00:00:20.000 --> 00:00:30.000
<v Speaker2>Machine learning is revolutionizing healthcare and diagnostics.

00:00:30.000 --> 00:00:40.000
<v Speaker3>"""
            + finance_filler
            + """

00:00:40.000 --> 00:00:50.000
<v Speaker0>Machine learning in education provides personalized experiences.
"""
        )

        flattened = SearchController._extract_webvtt_text(source)
        found = SearchController._generate_snippets(flattened, "machine learning")

        # Between one and three snippets: matches may merge into one snippet,
        # and three is the cap when max_snippets is left at its default.
        assert len(found) >= 1
        assert len(found) <= 3

        for piece in found:
            assert "machine learning" in piece.lower()
            # None of the cue markup may survive into a snippet.
            for artifact in ("<v Speaker", "00:00", "-->"):
                assert artifact not in piece
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
@@ -39,14 +40,18 @@ async def test_transcript_process(
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "ok"
|
||||
|
||||
# wait for processing to finish
|
||||
while True:
|
||||
# wait for processing to finish (max 10 minutes)
|
||||
timeout_seconds = 600 # 10 minutes
|
||||
start_time = time.monotonic()
|
||||
while (time.monotonic() - start_time) < timeout_seconds:
|
||||
# fetch the transcript and check if it is ended
|
||||
resp = await ac.get(f"/transcripts/{tid}")
|
||||
assert resp.status_code == 200
|
||||
if resp.json()["status"] in ("ended", "error"):
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
else:
|
||||
pytest.fail(f"Initial processing timed out after {timeout_seconds} seconds")
|
||||
|
||||
# restart the processing
|
||||
response = await ac.post(
|
||||
@@ -55,14 +60,18 @@ async def test_transcript_process(
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "ok"
|
||||
|
||||
# wait for processing to finish
|
||||
while True:
|
||||
# wait for processing to finish (max 10 minutes)
|
||||
timeout_seconds = 600 # 10 minutes
|
||||
start_time = time.monotonic()
|
||||
while (time.monotonic() - start_time) < timeout_seconds:
|
||||
# fetch the transcript and check if it is ended
|
||||
resp = await ac.get(f"/transcripts/{tid}")
|
||||
assert resp.status_code == 200
|
||||
if resp.json()["status"] in ("ended", "error"):
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
else:
|
||||
pytest.fail(f"Restart processing timed out after {timeout_seconds} seconds")
|
||||
|
||||
# check the transcript is ended
|
||||
transcript = resp.json()
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
import asyncio
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
@@ -21,14 +22,31 @@ class ThreadedUvicorn:
|
||||
|
||||
async def start(self):
|
||||
self.thread.start()
|
||||
while not self.server.started:
|
||||
timeout_seconds = 600 # 10 minutes
|
||||
start_time = time.monotonic()
|
||||
while (
|
||||
not self.server.started
|
||||
and (time.monotonic() - start_time) < timeout_seconds
|
||||
):
|
||||
await asyncio.sleep(0.1)
|
||||
if not self.server.started:
|
||||
raise TimeoutError(
|
||||
f"Server failed to start after {timeout_seconds} seconds"
|
||||
)
|
||||
|
||||
def stop(self):
|
||||
if self.thread.is_alive():
|
||||
self.server.should_exit = True
|
||||
while self.thread.is_alive():
|
||||
continue
|
||||
timeout_seconds = 600 # 10 minutes
|
||||
start_time = time.time()
|
||||
while (
|
||||
self.thread.is_alive() and (time.time() - start_time) < timeout_seconds
|
||||
):
|
||||
time.sleep(0.1)
|
||||
if self.thread.is_alive():
|
||||
raise TimeoutError(
|
||||
f"Thread failed to stop after {timeout_seconds} seconds"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -92,12 +110,16 @@ async def test_transcript_rtc_and_websocket(
|
||||
async with aconnect_ws(f"{base_url}/transcripts/{tid}/events") as ws:
|
||||
print("Test websocket: CONNECTED")
|
||||
try:
|
||||
while True:
|
||||
timeout_seconds = 600 # 10 minutes
|
||||
start_time = time.monotonic()
|
||||
while (time.monotonic() - start_time) < timeout_seconds:
|
||||
msg = await ws.receive_json()
|
||||
print(f"Test websocket: JSON {msg}")
|
||||
if msg is None:
|
||||
break
|
||||
events.append(msg)
|
||||
else:
|
||||
print(f"Test websocket: TIMEOUT after {timeout_seconds} seconds")
|
||||
except Exception as e:
|
||||
print(f"Test websocket: EXCEPTION {e}")
|
||||
finally:
|
||||
@@ -145,9 +167,12 @@ async def test_transcript_rtc_and_websocket(
|
||||
if resp.json()["status"] in ("ended", "error"):
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
timeout -= 1
|
||||
if timeout < 0:
|
||||
raise TimeoutError("Timeout while waiting for transcript to be ended")
|
||||
|
||||
if resp.json()["status"] != "ended":
|
||||
raise TimeoutError("Timeout while waiting for transcript to be ended")
|
||||
raise TimeoutError("Transcript processing failed")
|
||||
|
||||
# stop websocket task
|
||||
websocket_task.cancel()
|
||||
@@ -253,12 +278,16 @@ async def test_transcript_rtc_and_websocket_and_fr(
|
||||
async with aconnect_ws(f"{base_url}/transcripts/{tid}/events") as ws:
|
||||
print("Test websocket: CONNECTED")
|
||||
try:
|
||||
while True:
|
||||
timeout_seconds = 600 # 10 minutes
|
||||
start_time = time.monotonic()
|
||||
while (time.monotonic() - start_time) < timeout_seconds:
|
||||
msg = await ws.receive_json()
|
||||
print(f"Test websocket: JSON {msg}")
|
||||
if msg is None:
|
||||
break
|
||||
events.append(msg)
|
||||
else:
|
||||
print(f"Test websocket: TIMEOUT after {timeout_seconds} seconds")
|
||||
except Exception as e:
|
||||
print(f"Test websocket: EXCEPTION {e}")
|
||||
finally:
|
||||
@@ -310,9 +339,12 @@ async def test_transcript_rtc_and_websocket_and_fr(
|
||||
if resp.json()["status"] == "ended":
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
timeout -= 1
|
||||
if timeout < 0:
|
||||
raise TimeoutError("Timeout while waiting for transcript to be ended")
|
||||
|
||||
if resp.json()["status"] != "ended":
|
||||
raise TimeoutError("Timeout while waiting for transcript to be ended")
|
||||
raise TimeoutError("Transcript processing failed")
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
@@ -39,14 +40,18 @@ async def test_transcript_upload_file(
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "ok"
|
||||
|
||||
# wait the processing to finish
|
||||
while True:
|
||||
# wait the processing to finish (max 10 minutes)
|
||||
timeout_seconds = 600 # 10 minutes
|
||||
start_time = time.monotonic()
|
||||
while (time.monotonic() - start_time) < timeout_seconds:
|
||||
# fetch the transcript and check if it is ended
|
||||
resp = await ac.get(f"/transcripts/{tid}")
|
||||
assert resp.status_code == 200
|
||||
if resp.json()["status"] in ("ended", "error"):
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
else:
|
||||
pytest.fail(f"Processing timed out after {timeout_seconds} seconds")
|
||||
|
||||
# check the transcript is ended
|
||||
transcript = resp.json()
|
||||
|
||||
151
server/tests/test_webvtt.py
Normal file
151
server/tests/test_webvtt.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""Tests for WebVTT utilities."""
|
||||
|
||||
import pytest
|
||||
|
||||
from reflector.processors.types import Transcript, Word, words_to_segments
|
||||
from reflector.utils.webvtt import topics_to_webvtt, words_to_webvtt
|
||||
|
||||
|
||||
class TestWordsToWebVTT:
    """Test words_to_webvtt function with TDD approach."""

    def test_empty_words_returns_empty_webvtt(self):
        """Should return empty WebVTT structure for empty words list."""
        result = words_to_webvtt([])

        assert "WEBVTT" in result
        # Nothing but the header may be emitted for an empty input.
        assert result.strip() == "WEBVTT"

    def test_single_word_creates_single_caption(self):
        """Should create one caption for a single word."""
        words = [Word(text="Hello", start=0.0, end=1.0, speaker=0)]
        result = words_to_webvtt(words)

        assert "WEBVTT" in result
        assert "00:00:00.000 --> 00:00:01.000" in result
        assert "Hello" in result
        assert "<v Speaker0>" in result

    def test_multiple_words_same_speaker_groups_properly(self):
        """Should group consecutive words from same speaker."""
        words = [
            Word(text="Hello", start=0.0, end=0.5, speaker=0),
            Word(text=" world", start=0.5, end=1.0, speaker=0),
        ]
        result = words_to_webvtt(words)

        assert "WEBVTT" in result
        # Both words end up joined in one caption text.
        assert "Hello world" in result
        assert "<v Speaker0>" in result

    def test_speaker_change_creates_new_caption(self):
        """Should create new caption when speaker changes."""
        words = [
            Word(text="Hello", start=0.0, end=0.5, speaker=0),
            Word(text="Hi", start=0.6, end=1.0, speaker=1),
        ]
        result = words_to_webvtt(words)

        # Only substring checks are made, so the output is not split into lines.
        assert "<v Speaker0>" in result
        assert "<v Speaker1>" in result
        assert "Hello" in result
        assert "Hi" in result

    def test_punctuation_creates_segment_boundary(self):
        """Should respect punctuation boundaries from segmentation logic."""
        words = [
            Word(text="Hello.", start=0.0, end=0.5, speaker=0),
            Word(text=" How", start=0.6, end=1.0, speaker=0),
            Word(text=" are", start=1.0, end=1.3, speaker=0),
            Word(text=" you?", start=1.3, end=1.8, speaker=0),
        ]
        result = words_to_webvtt(words)

        # NOTE(review): only the header and the speaker tag are checked here;
        # the boundary after "Hello." is not asserted yet.
        assert "WEBVTT" in result
        assert "<v Speaker0>" in result
|
||||
|
||||
|
||||
class TestTopicsToWebVTT:
    """Test topics_to_webvtt function."""

    class _MockTopic:
        """Minimal topic stand-in exposing only a ``words`` attribute."""

        def __init__(self, words):
            self.words = words

    def test_empty_topics_returns_empty_webvtt(self):
        """Should handle empty topics list."""
        result = topics_to_webvtt([])

        assert "WEBVTT" in result
        assert result.strip() == "WEBVTT"

    def test_extracts_words_from_topics(self):
        """Should extract all words from topics in sequence."""
        topics = [
            self._MockTopic(
                [
                    Word(text="First", start=0.0, end=0.5, speaker=1),
                    Word(text="Second", start=1.0, end=1.5, speaker=0),
                ]
            )
        ]

        result = topics_to_webvtt(topics)

        assert "WEBVTT" in result
        first_pos = result.find("First")
        second_pos = result.find("Second")
        # str.find returns -1 for a missing substring, and -1 < second_pos
        # would pass; require "First" to be present as well as ordered.
        assert 0 <= first_pos < second_pos

    def test_non_sequential_topics_raises_assertion(self):
        """Should raise assertion error when words are not in chronological sequence."""
        topics = [
            self._MockTopic(
                [
                    Word(text="Second", start=1.0, end=1.5, speaker=0),
                    Word(text="First", start=0.0, end=0.5, speaker=1),
                ]
            )
        ]

        with pytest.raises(AssertionError) as exc_info:
            topics_to_webvtt(topics)

        # The message names the offending pair of words.
        message = str(exc_info.value)
        assert "Words are not in sequence" in message
        assert "Second and First" in message
|
||||
|
||||
|
||||
class TestTranscriptWordsToSegments:
    """Checks for ``words_to_segments`` and ``Transcript.as_segments``."""

    def test_static_method_exists(self):
        """``words_to_segments`` turns one word into one segment."""
        single_word = Word(text="Hello", start=0.0, end=1.0, speaker=0)

        produced = words_to_segments([single_word])

        assert isinstance(produced, list)
        assert len(produced) == 1
        head = produced[0]
        assert head.text == "Hello"
        assert head.speaker == 0

    def test_backward_compatibility(self):
        """``Transcript.as_segments`` gives the same single segment."""
        single_word = Word(text="Hello", start=0.0, end=1.0, speaker=0)

        produced = Transcript(words=[single_word]).as_segments()

        assert isinstance(produced, list)
        assert len(produced) == 1
        assert produced[0].text == "Hello"
|
||||
34
server/tests/test_webvtt_implementation.py
Normal file
34
server/tests/test_webvtt_implementation.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""Test WebVTT auto-update functionality and edge cases."""
|
||||
|
||||
import pytest
|
||||
|
||||
from reflector.db.transcripts import (
|
||||
TranscriptController,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
class TestWebVTTAutoUpdateImplementation:
    """Checks for ``TranscriptController._handle_topics_update``."""

    async def test_handle_topics_update_handles_dict_conversion(self):
        """Topics given as plain dicts result in a ``webvtt`` entry being added."""
        hello_word = {"text": "Hello", "start": 0.0, "end": 1.0, "speaker": 0}
        topic_as_dict = {
            "id": "topic1",
            "title": "Test",
            "summary": "Test",
            "timestamp": 0.0,
            "words": [hello_word],
        }
        payload = {"topics": [topic_as_dict]}

        outcome = TranscriptController._handle_topics_update(payload)

        # The returned mapping must carry a non-empty WebVTT document.
        assert "webvtt" in outcome
        generated = outcome["webvtt"]
        assert generated is not None
        assert "WEBVTT" in generated
|
||||
234
server/tests/test_webvtt_integration.py
Normal file
234
server/tests/test_webvtt_integration.py
Normal file
@@ -0,0 +1,234 @@
|
||||
"""Integration tests for WebVTT auto-update functionality in Transcript model."""
|
||||
|
||||
import pytest
|
||||
|
||||
from reflector.db import database
|
||||
from reflector.db.transcripts import (
|
||||
SourceKind,
|
||||
TranscriptController,
|
||||
TranscriptTopic,
|
||||
transcripts,
|
||||
)
|
||||
from reflector.processors.types import Word
|
||||
|
||||
|
||||
@pytest.mark.asyncio
class TestWebVTTAutoUpdate:
    """Checks that the stored ``webvtt`` column follows changes to topics."""

    @staticmethod
    async def _stored_row(transcript_id):
        """Read the transcript row with the given id straight from the database."""
        query = transcripts.select().where(transcripts.c.id == transcript_id)
        return await database.fetch_one(query)

    async def test_webvtt_not_updated_on_transcript_creation_without_topics(self):
        """A transcript created without topics has ``webvtt`` set to None."""
        ctrl = TranscriptController()
        created = await ctrl.add(
            name="Test Transcript",
            source_kind=SourceKind.FILE,
        )

        try:
            row = await self._stored_row(created.id)

            assert row is not None
            assert row["webvtt"] is None
        finally:
            # Remove the transcript whatever the outcome of the assertions.
            await ctrl.remove_by_id(created.id)

    async def test_webvtt_updated_on_upsert_topic(self):
        """``upsert_topic`` on the controller writes WebVTT to the row."""
        ctrl = TranscriptController()
        created = await ctrl.add(
            name="Test Transcript",
            source_kind=SourceKind.FILE,
        )

        try:
            spoken = [
                Word(text="Hello", start=0.0, end=0.5, speaker=0),
                Word(text=" world", start=0.5, end=1.0, speaker=0),
            ]
            new_topic = TranscriptTopic(
                id="topic1",
                title="Test Topic",
                summary="Test summary",
                timestamp=0.0,
                words=spoken,
            )

            await ctrl.upsert_topic(created, new_topic)

            row = await self._stored_row(created.id)
            assert row is not None

            vtt = row["webvtt"]
            assert vtt is not None
            for fragment in ("WEBVTT", "Hello world", "<v Speaker0>"):
                assert fragment in vtt
        finally:
            await ctrl.remove_by_id(created.id)

    async def test_webvtt_updated_on_direct_topics_update(self):
        """Updating the ``topics`` field with plain dicts writes WebVTT."""
        ctrl = TranscriptController()
        created = await ctrl.add(
            name="Test Transcript",
            source_kind=SourceKind.FILE,
        )

        try:
            raw_words = [
                {"text": "First", "start": 0.0, "end": 0.5, "speaker": 0},
                {"text": " sentence", "start": 0.5, "end": 1.0, "speaker": 0},
            ]
            raw_topic = {
                "id": "topic1",
                "title": "First Topic",
                "summary": "First sentence test",
                "timestamp": 0.0,
                "words": raw_words,
            }

            await ctrl.update(created, {"topics": [raw_topic]})

            # Read back what was persisted.
            row = await self._stored_row(created.id)
            assert row is not None

            vtt = row["webvtt"]
            assert vtt is not None
            for fragment in ("WEBVTT", "First sentence"):
                assert fragment in vtt
        finally:
            await ctrl.remove_by_id(created.id)

    async def test_webvtt_updated_manually_with_handle_topics_update(self):
        """Topics dumped from the model and sent through ``update`` write WebVTT."""
        ctrl = TranscriptController()
        created = await ctrl.add(
            name="Test Transcript",
            source_kind=SourceKind.FILE,
        )

        try:
            spoken = [
                Word(text="Manual", start=0.0, end=0.5, speaker=0),
                Word(text=" test", start=0.5, end=1.0, speaker=0),
            ]
            manual_topic = TranscriptTopic(
                id="topic1",
                title="Topic 1",
                summary="Manual test",
                timestamp=0.0,
                words=spoken,
            )

            # Change the in-memory model first, then persist its topic dump.
            created.upsert_topic(manual_topic)
            await ctrl.update(created, {"topics": created.topics_dump()})

            row = await self._stored_row(created.id)
            assert row is not None

            vtt = row["webvtt"]
            assert vtt is not None
            for fragment in ("WEBVTT", "Manual test", "<v Speaker0>"):
                assert fragment in vtt
        finally:
            await ctrl.remove_by_id(created.id)

    async def test_webvtt_update_with_non_sequential_topics_fails(self):
        """Words out of chronological order make ``_handle_topics_update`` assert."""
        ctrl = TranscriptController()
        created = await ctrl.add(
            name="Test Transcript",
            source_kind=SourceKind.FILE,
        )

        try:
            # The second word starts before the first one.
            out_of_order = [
                Word(text="Second", start=2.0, end=2.5, speaker=0),
                Word(text="First", start=1.0, end=1.5, speaker=0),
            ]
            bad_topic = TranscriptTopic(
                id="topic1",
                title="Bad Topic",
                summary="Bad order test",
                timestamp=1.0,
                words=out_of_order,
            )

            created.upsert_topic(bad_topic)
            payload = {"topics": created.topics_dump()}

            with pytest.raises(AssertionError) as exc_info:
                TranscriptController._handle_topics_update(payload)

            assert "Words are not in sequence" in str(exc_info.value)
        finally:
            await ctrl.remove_by_id(created.id)

    async def test_multiple_speakers_in_webvtt(self):
        """Each speaker gets a voice tag and every word shows up in the WebVTT."""
        ctrl = TranscriptController()
        created = await ctrl.add(
            name="Test Transcript",
            source_kind=SourceKind.FILE,
        )

        try:
            dialogue = [
                Word(text="Hello", start=0.0, end=0.5, speaker=0),
                Word(text="Hi", start=1.0, end=1.5, speaker=1),
                Word(text="Goodbye", start=2.0, end=2.5, speaker=0),
            ]
            shared_topic = TranscriptTopic(
                id="topic1",
                title="Multi Speaker",
                summary="Multi speaker test",
                timestamp=0.0,
                words=dialogue,
            )

            created.upsert_topic(shared_topic)
            await ctrl.update(created, {"topics": created.topics_dump()})

            row = await self._stored_row(created.id)
            assert row is not None

            vtt = row["webvtt"]
            assert vtt is not None
            for fragment in ("<v Speaker0>", "<v Speaker1>", "Hello", "Hi", "Goodbye"):
                assert fragment in vtt
        finally:
            await ctrl.remove_by_id(created.id)
|
||||
Reference in New Issue
Block a user