feat: search backend (#537)

* docs: transient docs

* chore: cleanup

* webvtt WIP

* webvtt field

* chore: webvtt tests comments

* chore: remove useless tests

* feat: search TASK.md

* feat: full text search by title/webvtt

* chore: search api task

* feat: search api

* feat: search API

* chore: rm task md

* chore: roll back unnecessary validators

* chore: pr review WIP

* chore: pr review WIP

* chore: pr review

* chore: top imports

* feat: better lint + ci

* feat: better lint + ci

* feat: better lint + ci

* feat: better lint + ci

* chore: lint

* chore: lint

* fix: db datetime definitions

* fix: flush() params

* fix: update transcript mutability expectation / test

* fix: update transcript mutability expectation / test

* chore: auto review

* chore: new controller extraction

* chore: new controller extraction

* chore: cleanup

* chore: review WIP

* chore: pr WIP

* chore: remove ci lint

* chore: openapi regeneration

* chore: openapi regeneration

* chore: postgres test doc

* fix: .dockerignore for arm binaries

* fix: .dockerignore for arm binaries

* fix: cap test loops

* fix: cap test loops

* fix: cap test loops

* fix: get_transcript_topics

* chore: remove flow.md docs and claude guidance

* chore: remove claude.md db doc

* chore: remove claude.md db doc

* chore: remove claude.md db doc

* chore: remove claude.md db doc
This commit is contained in:
Igor Loskutov
2025-08-13 10:03:38 -04:00
committed by GitHub
parent a42ed12982
commit 6fb5cb21c2
29 changed files with 3213 additions and 1493 deletions

163
server/tests/test_search.py Normal file
View File

@@ -0,0 +1,163 @@
"""Tests for full-text search functionality."""
import json
from datetime import datetime
import pytest
from pydantic import ValidationError
from reflector.db import database
from reflector.db.search import SearchParameters, search_controller
from reflector.db.transcripts import transcripts
from reflector.db.utils import is_postgresql
@pytest.mark.asyncio
async def test_search_postgresql_only():
await database.connect()
try:
params = SearchParameters(query_text="any query here")
results, total = await search_controller.search_transcripts(params)
assert results == []
assert total == 0
try:
SearchParameters(query_text="")
assert False, "Should have raised validation error"
except ValidationError:
pass # Expected
# Test that whitespace query raises validation error
try:
SearchParameters(query_text=" ")
assert False, "Should have raised validation error"
except ValidationError:
pass # Expected
finally:
await database.disconnect()
@pytest.mark.asyncio
async def test_search_input_validation():
await database.connect()
try:
try:
SearchParameters(query_text="")
assert False, "Should have raised ValidationError"
except ValidationError:
pass # Expected
# Test that whitespace query raises validation error
try:
SearchParameters(query_text=" \t\n ")
assert False, "Should have raised ValidationError"
except ValidationError:
pass # Expected
finally:
await database.disconnect()
@pytest.mark.asyncio
async def test_postgresql_search_with_data():
"""Test full-text search with actual data in PostgreSQL.
Example how to run: DATABASE_URL=postgresql://reflector:reflector@localhost:5432/reflector_test uv run pytest tests/test_search.py::test_postgresql_search_with_data -v -p no:env
"""
# Skip if not PostgreSQL
if not is_postgresql():
pytest.skip("Test requires PostgreSQL. Set DATABASE_URL=postgresql://...")
await database.connect()
# collision is improbable
test_id = "test-search-e2e-7f3a9b2c"
try:
await database.execute(transcripts.delete().where(transcripts.c.id == test_id))
test_data = {
"id": test_id,
"name": "Test Search Transcript",
"title": "Engineering Planning Meeting Q4 2024",
"status": "completed",
"locked": False,
"duration": 1800.0,
"created_at": datetime.now(),
"short_summary": "Team discussed search implementation",
"long_summary": "The engineering team met to plan the search feature",
"topics": json.dumps([]),
"events": json.dumps([]),
"participants": json.dumps([]),
"source_language": "en",
"target_language": "en",
"reviewed": False,
"audio_location": "local",
"share_mode": "private",
"source_kind": "room",
"webvtt": """WEBVTT
00:00:00.000 --> 00:00:10.000
Welcome to our engineering planning meeting for Q4 2024.
00:00:10.000 --> 00:00:20.000
Today we'll discuss the implementation of full-text search.
00:00:20.000 --> 00:00:30.000
The search feature should support complex queries with ranking.
00:00:30.000 --> 00:00:40.000
We need to implement PostgreSQL tsvector for better performance.""",
}
await database.execute(transcripts.insert().values(**test_data))
# Test 1: Search for a word in title
params = SearchParameters(query_text="planning")
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by title word"
# Test 2: Search for a word in webvtt content
params = SearchParameters(query_text="tsvector")
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by webvtt content"
# Test 3: Search with multiple words
params = SearchParameters(query_text="engineering planning")
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by multiple words"
# Test 4: Verify SearchResult structure
test_result = next((r for r in results if r.id == test_id), None)
if test_result:
assert test_result.title == "Engineering Planning Meeting Q4 2024"
assert test_result.status == "completed"
assert test_result.duration == 1800.0
assert test_result.source_kind == "room"
assert 0 <= test_result.rank <= 1, "Rank should be normalized to 0-1"
# Test 5: Search with OR operator
params = SearchParameters(query_text="tsvector OR nosuchword")
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript with OR query"
# Test 6: Quoted phrase search
params = SearchParameters(query_text='"full-text search"')
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by exact phrase"
finally:
await database.execute(transcripts.delete().where(transcripts.c.id == test_id))
await database.disconnect()

View File

@@ -0,0 +1,198 @@
"""Unit tests for search snippet generation."""
from reflector.db.search import SearchController
class TestExtractWebVTT:
"""Test WebVTT text extraction."""
def test_extract_webvtt_with_speakers(self):
"""Test extraction removes speaker tags and timestamps."""
webvtt = """WEBVTT
00:00:00.000 --> 00:00:10.000
<v Speaker0>Hello world, this is a test.
00:00:10.000 --> 00:00:20.000
<v Speaker1>Indeed it is a test of WebVTT parsing.
"""
result = SearchController._extract_webvtt_text(webvtt)
assert "Hello world, this is a test" in result
assert "Indeed it is a test" in result
assert "<v Speaker" not in result
assert "00:00" not in result
assert "-->" not in result
def test_extract_empty_webvtt(self):
"""Test empty WebVTT returns empty string."""
assert SearchController._extract_webvtt_text("") == ""
assert SearchController._extract_webvtt_text(None) == ""
def test_extract_malformed_webvtt(self):
"""Test malformed WebVTT returns empty string."""
result = SearchController._extract_webvtt_text("Not a valid WebVTT")
assert result == ""
class TestGenerateSnippets:
"""Test snippet generation from plain text."""
def test_multiple_matches(self):
"""Test finding multiple occurrences of search term in long text."""
# Create text with Python mentions far apart to get separate snippets
separator = " This is filler text. " * 20 # ~400 chars of padding
text = (
"Python is great for machine learning."
+ separator
+ "Many companies use Python for data science."
+ separator
+ "Python has excellent libraries for analysis."
+ separator
+ "The Python community is very supportive."
)
snippets = SearchController._generate_snippets(text, "Python")
# With enough separation, we should get multiple snippets
assert len(snippets) >= 2 # At least 2 distinct snippets
# Each snippet should contain "Python"
for snippet in snippets:
assert "python" in snippet.lower()
def test_single_match(self):
"""Test single occurrence returns one snippet."""
text = "This document discusses artificial intelligence and its applications."
snippets = SearchController._generate_snippets(text, "artificial intelligence")
assert len(snippets) == 1
assert "artificial intelligence" in snippets[0].lower()
def test_no_matches(self):
"""Test no matches returns empty list."""
text = "This is some random text without the search term."
snippets = SearchController._generate_snippets(text, "machine learning")
assert snippets == []
def test_case_insensitive_search(self):
"""Test search is case insensitive."""
# Add enough text between matches to get separate snippets
text = (
"MACHINE LEARNING is important for modern applications. "
+ "It requires lots of data and computational resources. " * 5 # Padding
+ "Machine Learning rocks and transforms industries. "
+ "Deep learning is a subset of it. " * 5 # More padding
+ "Finally, machine learning will shape our future."
)
snippets = SearchController._generate_snippets(text, "machine learning")
# Should find at least 2 (might be 3 if text is long enough)
assert len(snippets) >= 2
for snippet in snippets:
assert "machine learning" in snippet.lower()
def test_partial_match_fallback(self):
"""Test fallback to first word when exact phrase not found."""
text = "We use machine intelligence for processing."
snippets = SearchController._generate_snippets(text, "machine learning")
# Should fall back to finding "machine"
assert len(snippets) == 1
assert "machine" in snippets[0].lower()
def test_snippet_ellipsis(self):
"""Test ellipsis added for truncated snippets."""
# Long text where match is in the middle
text = "a " * 100 + "TARGET_WORD special content here" + " b" * 100
snippets = SearchController._generate_snippets(text, "TARGET_WORD")
assert len(snippets) == 1
assert "..." in snippets[0] # Should have ellipsis
assert "TARGET_WORD" in snippets[0]
def test_overlapping_snippets_deduplicated(self):
"""Test overlapping matches don't create duplicate snippets."""
text = "test test test word" * 10 # Repeated pattern
snippets = SearchController._generate_snippets(text, "test")
# Should get unique snippets, not duplicates
assert len(snippets) <= 3
assert len(snippets) == len(set(snippets)) # All unique
def test_empty_inputs(self):
"""Test empty text or search term returns empty list."""
assert SearchController._generate_snippets("", "search") == []
assert SearchController._generate_snippets("text", "") == []
assert SearchController._generate_snippets("", "") == []
def test_max_snippets_limit(self):
"""Test respects max_snippets parameter."""
# Create text with well-separated occurrences
separator = " filler " * 50 # Ensure snippets don't overlap
text = ("Python is amazing" + separator) * 10 # 10 occurrences
# Test with different limits
snippets_1 = SearchController._generate_snippets(text, "Python", max_snippets=1)
assert len(snippets_1) == 1
snippets_2 = SearchController._generate_snippets(text, "Python", max_snippets=2)
assert len(snippets_2) == 2
snippets_5 = SearchController._generate_snippets(text, "Python", max_snippets=5)
assert len(snippets_5) == 5 # Should get exactly 5 with enough separation
def test_snippet_length(self):
"""Test snippet length is reasonable."""
text = "word " * 200 # Long text
snippets = SearchController._generate_snippets(text, "word")
for snippet in snippets:
# Default max_length is 150 + some context
assert len(snippet) <= 200 # Some buffer for ellipsis
class TestFullPipeline:
"""Test the complete WebVTT to snippets pipeline."""
def test_webvtt_to_snippets_integration(self):
"""Test full pipeline from WebVTT to search snippets."""
# Create WebVTT with well-separated content for multiple snippets
webvtt = (
"""WEBVTT
00:00:00.000 --> 00:00:10.000
<v Speaker0>Let's discuss machine learning applications in modern technology.
00:00:10.000 --> 00:00:20.000
<v Speaker1>"""
+ "Various industries are adopting new technologies. " * 10
+ """
00:00:20.000 --> 00:00:30.000
<v Speaker2>Machine learning is revolutionizing healthcare and diagnostics.
00:00:30.000 --> 00:00:40.000
<v Speaker3>"""
+ "Financial markets show interesting patterns. " * 10
+ """
00:00:40.000 --> 00:00:50.000
<v Speaker0>Machine learning in education provides personalized experiences.
"""
)
# Extract and generate snippets
plain_text = SearchController._extract_webvtt_text(webvtt)
snippets = SearchController._generate_snippets(plain_text, "machine learning")
# Should find at least 2 snippets (text might still be close together)
assert len(snippets) >= 1 # At minimum one snippet containing matches
assert len(snippets) <= 3 # At most 3 by default
# No WebVTT artifacts in snippets
for snippet in snippets:
assert "machine learning" in snippet.lower()
assert "<v Speaker" not in snippet
assert "00:00" not in snippet
assert "-->" not in snippet

View File

@@ -1,4 +1,5 @@
import asyncio
import time
import pytest
from httpx import AsyncClient
@@ -39,14 +40,18 @@ async def test_transcript_process(
assert response.status_code == 200
assert response.json()["status"] == "ok"
# wait for processing to finish
while True:
# wait for processing to finish (max 10 minutes)
timeout_seconds = 600 # 10 minutes
start_time = time.monotonic()
while (time.monotonic() - start_time) < timeout_seconds:
# fetch the transcript and check if it is ended
resp = await ac.get(f"/transcripts/{tid}")
assert resp.status_code == 200
if resp.json()["status"] in ("ended", "error"):
break
await asyncio.sleep(1)
else:
pytest.fail(f"Initial processing timed out after {timeout_seconds} seconds")
# restart the processing
response = await ac.post(
@@ -55,14 +60,18 @@ async def test_transcript_process(
assert response.status_code == 200
assert response.json()["status"] == "ok"
# wait for processing to finish
while True:
# wait for processing to finish (max 10 minutes)
timeout_seconds = 600 # 10 minutes
start_time = time.monotonic()
while (time.monotonic() - start_time) < timeout_seconds:
# fetch the transcript and check if it is ended
resp = await ac.get(f"/transcripts/{tid}")
assert resp.status_code == 200
if resp.json()["status"] in ("ended", "error"):
break
await asyncio.sleep(1)
else:
pytest.fail(f"Restart processing timed out after {timeout_seconds} seconds")
# check the transcript is ended
transcript = resp.json()

View File

@@ -6,6 +6,7 @@
import asyncio
import json
import threading
import time
from pathlib import Path
import pytest
@@ -21,14 +22,31 @@ class ThreadedUvicorn:
async def start(self):
self.thread.start()
while not self.server.started:
timeout_seconds = 600 # 10 minutes
start_time = time.monotonic()
while (
not self.server.started
and (time.monotonic() - start_time) < timeout_seconds
):
await asyncio.sleep(0.1)
if not self.server.started:
raise TimeoutError(
f"Server failed to start after {timeout_seconds} seconds"
)
def stop(self):
if self.thread.is_alive():
self.server.should_exit = True
while self.thread.is_alive():
continue
timeout_seconds = 600 # 10 minutes
start_time = time.time()
while (
self.thread.is_alive() and (time.time() - start_time) < timeout_seconds
):
time.sleep(0.1)
if self.thread.is_alive():
raise TimeoutError(
f"Thread failed to stop after {timeout_seconds} seconds"
)
@pytest.fixture
@@ -92,12 +110,16 @@ async def test_transcript_rtc_and_websocket(
async with aconnect_ws(f"{base_url}/transcripts/{tid}/events") as ws:
print("Test websocket: CONNECTED")
try:
while True:
timeout_seconds = 600 # 10 minutes
start_time = time.monotonic()
while (time.monotonic() - start_time) < timeout_seconds:
msg = await ws.receive_json()
print(f"Test websocket: JSON {msg}")
if msg is None:
break
events.append(msg)
else:
print(f"Test websocket: TIMEOUT after {timeout_seconds} seconds")
except Exception as e:
print(f"Test websocket: EXCEPTION {e}")
finally:
@@ -145,9 +167,12 @@ async def test_transcript_rtc_and_websocket(
if resp.json()["status"] in ("ended", "error"):
break
await asyncio.sleep(1)
timeout -= 1
if timeout < 0:
raise TimeoutError("Timeout while waiting for transcript to be ended")
if resp.json()["status"] != "ended":
raise TimeoutError("Timeout while waiting for transcript to be ended")
raise TimeoutError("Transcript processing failed")
# stop websocket task
websocket_task.cancel()
@@ -253,12 +278,16 @@ async def test_transcript_rtc_and_websocket_and_fr(
async with aconnect_ws(f"{base_url}/transcripts/{tid}/events") as ws:
print("Test websocket: CONNECTED")
try:
while True:
timeout_seconds = 600 # 10 minutes
start_time = time.monotonic()
while (time.monotonic() - start_time) < timeout_seconds:
msg = await ws.receive_json()
print(f"Test websocket: JSON {msg}")
if msg is None:
break
events.append(msg)
else:
print(f"Test websocket: TIMEOUT after {timeout_seconds} seconds")
except Exception as e:
print(f"Test websocket: EXCEPTION {e}")
finally:
@@ -310,9 +339,12 @@ async def test_transcript_rtc_and_websocket_and_fr(
if resp.json()["status"] == "ended":
break
await asyncio.sleep(1)
timeout -= 1
if timeout < 0:
raise TimeoutError("Timeout while waiting for transcript to be ended")
if resp.json()["status"] != "ended":
raise TimeoutError("Timeout while waiting for transcript to be ended")
raise TimeoutError("Transcript processing failed")
await asyncio.sleep(2)

View File

@@ -1,4 +1,5 @@
import asyncio
import time
import pytest
from httpx import AsyncClient
@@ -39,14 +40,18 @@ async def test_transcript_upload_file(
assert response.status_code == 200
assert response.json()["status"] == "ok"
# wait the processing to finish
while True:
# wait the processing to finish (max 10 minutes)
timeout_seconds = 600 # 10 minutes
start_time = time.monotonic()
while (time.monotonic() - start_time) < timeout_seconds:
# fetch the transcript and check if it is ended
resp = await ac.get(f"/transcripts/{tid}")
assert resp.status_code == 200
if resp.json()["status"] in ("ended", "error"):
break
await asyncio.sleep(1)
else:
pytest.fail(f"Processing timed out after {timeout_seconds} seconds")
# check the transcript is ended
transcript = resp.json()

151
server/tests/test_webvtt.py Normal file
View File

@@ -0,0 +1,151 @@
"""Tests for WebVTT utilities."""
import pytest
from reflector.processors.types import Transcript, Word, words_to_segments
from reflector.utils.webvtt import topics_to_webvtt, words_to_webvtt
class TestWordsToWebVTT:
"""Test words_to_webvtt function with TDD approach."""
def test_empty_words_returns_empty_webvtt(self):
"""Should return empty WebVTT structure for empty words list."""
result = words_to_webvtt([])
assert "WEBVTT" in result
assert result.strip() == "WEBVTT"
def test_single_word_creates_single_caption(self):
"""Should create one caption for a single word."""
words = [Word(text="Hello", start=0.0, end=1.0, speaker=0)]
result = words_to_webvtt(words)
assert "WEBVTT" in result
assert "00:00:00.000 --> 00:00:01.000" in result
assert "Hello" in result
assert "<v Speaker0>" in result
def test_multiple_words_same_speaker_groups_properly(self):
"""Should group consecutive words from same speaker."""
words = [
Word(text="Hello", start=0.0, end=0.5, speaker=0),
Word(text=" world", start=0.5, end=1.0, speaker=0),
]
result = words_to_webvtt(words)
assert "WEBVTT" in result
assert "Hello world" in result
assert "<v Speaker0>" in result
def test_speaker_change_creates_new_caption(self):
"""Should create new caption when speaker changes."""
words = [
Word(text="Hello", start=0.0, end=0.5, speaker=0),
Word(text="Hi", start=0.6, end=1.0, speaker=1),
]
result = words_to_webvtt(words)
lines = result.split("\n")
assert "<v Speaker0>" in result
assert "<v Speaker1>" in result
assert "Hello" in result
assert "Hi" in result
def test_punctuation_creates_segment_boundary(self):
"""Should respect punctuation boundaries from segmentation logic."""
words = [
Word(text="Hello.", start=0.0, end=0.5, speaker=0),
Word(text=" How", start=0.6, end=1.0, speaker=0),
Word(text=" are", start=1.0, end=1.3, speaker=0),
Word(text=" you?", start=1.3, end=1.8, speaker=0),
]
result = words_to_webvtt(words)
assert "WEBVTT" in result
assert "<v Speaker0>" in result
class TestTopicsToWebVTT:
"""Test topics_to_webvtt function."""
def test_empty_topics_returns_empty_webvtt(self):
"""Should handle empty topics list."""
result = topics_to_webvtt([])
assert "WEBVTT" in result
assert result.strip() == "WEBVTT"
def test_extracts_words_from_topics(self):
"""Should extract all words from topics in sequence."""
class MockTopic:
def __init__(self, words):
self.words = words
topics = [
MockTopic(
[
Word(text="First", start=0.0, end=0.5, speaker=1),
Word(text="Second", start=1.0, end=1.5, speaker=0),
]
)
]
result = topics_to_webvtt(topics)
assert "WEBVTT" in result
first_pos = result.find("First")
second_pos = result.find("Second")
assert first_pos < second_pos
def test_non_sequential_topics_raises_assertion(self):
"""Should raise assertion error when words are not in chronological sequence."""
class MockTopic:
def __init__(self, words):
self.words = words
topics = [
MockTopic(
[
Word(text="Second", start=1.0, end=1.5, speaker=0),
Word(text="First", start=0.0, end=0.5, speaker=1),
]
)
]
with pytest.raises(AssertionError) as exc_info:
topics_to_webvtt(topics)
assert "Words are not in sequence" in str(exc_info.value)
assert "Second and First" in str(exc_info.value)
class TestTranscriptWordsToSegments:
"""Test static words_to_segments method (TDD for making it static)."""
def test_static_method_exists(self):
"""Should have static words_to_segments method."""
words = [Word(text="Hello", start=0.0, end=1.0, speaker=0)]
segments = words_to_segments(words)
assert isinstance(segments, list)
assert len(segments) == 1
assert segments[0].text == "Hello"
assert segments[0].speaker == 0
def test_backward_compatibility(self):
"""Should maintain backward compatibility with instance method."""
words = [Word(text="Hello", start=0.0, end=1.0, speaker=0)]
transcript = Transcript(words=words)
segments = transcript.as_segments()
assert isinstance(segments, list)
assert len(segments) == 1
assert segments[0].text == "Hello"

View File

@@ -0,0 +1,34 @@
"""Test WebVTT auto-update functionality and edge cases."""
import pytest
from reflector.db.transcripts import (
TranscriptController,
)
@pytest.mark.asyncio
class TestWebVTTAutoUpdateImplementation:
async def test_handle_topics_update_handles_dict_conversion(self):
"""
Verify that _handle_topics_update() properly converts dict data to TranscriptTopic objects.
"""
values = {
"topics": [
{
"id": "topic1",
"title": "Test",
"summary": "Test",
"timestamp": 0.0,
"words": [
{"text": "Hello", "start": 0.0, "end": 1.0, "speaker": 0}
],
}
]
}
updated_values = TranscriptController._handle_topics_update(values)
assert "webvtt" in updated_values
assert updated_values["webvtt"] is not None
assert "WEBVTT" in updated_values["webvtt"]

View File

@@ -0,0 +1,234 @@
"""Integration tests for WebVTT auto-update functionality in Transcript model."""
import pytest
from reflector.db import database
from reflector.db.transcripts import (
SourceKind,
TranscriptController,
TranscriptTopic,
transcripts,
)
from reflector.processors.types import Word
@pytest.mark.asyncio
class TestWebVTTAutoUpdate:
"""Test that WebVTT field auto-updates when Transcript is created or modified."""
async def test_webvtt_not_updated_on_transcript_creation_without_topics(self):
"""WebVTT should be None when creating transcript without topics."""
controller = TranscriptController()
transcript = await controller.add(
name="Test Transcript",
source_kind=SourceKind.FILE,
)
try:
result = await database.fetch_one(
transcripts.select().where(transcripts.c.id == transcript.id)
)
assert result is not None
assert result["webvtt"] is None
finally:
await controller.remove_by_id(transcript.id)
async def test_webvtt_updated_on_upsert_topic(self):
"""WebVTT should update when upserting topics via upsert_topic method."""
controller = TranscriptController()
transcript = await controller.add(
name="Test Transcript",
source_kind=SourceKind.FILE,
)
try:
topic = TranscriptTopic(
id="topic1",
title="Test Topic",
summary="Test summary",
timestamp=0.0,
words=[
Word(text="Hello", start=0.0, end=0.5, speaker=0),
Word(text=" world", start=0.5, end=1.0, speaker=0),
],
)
await controller.upsert_topic(transcript, topic)
result = await database.fetch_one(
transcripts.select().where(transcripts.c.id == transcript.id)
)
assert result is not None
webvtt = result["webvtt"]
assert webvtt is not None
assert "WEBVTT" in webvtt
assert "Hello world" in webvtt
assert "<v Speaker0>" in webvtt
finally:
await controller.remove_by_id(transcript.id)
async def test_webvtt_updated_on_direct_topics_update(self):
"""WebVTT should update when updating topics field directly."""
controller = TranscriptController()
transcript = await controller.add(
name="Test Transcript",
source_kind=SourceKind.FILE,
)
try:
topics_data = [
{
"id": "topic1",
"title": "First Topic",
"summary": "First sentence test",
"timestamp": 0.0,
"words": [
{"text": "First", "start": 0.0, "end": 0.5, "speaker": 0},
{"text": " sentence", "start": 0.5, "end": 1.0, "speaker": 0},
],
}
]
await controller.update(transcript, {"topics": topics_data})
# Fetch from DB
result = await database.fetch_one(
transcripts.select().where(transcripts.c.id == transcript.id)
)
assert result is not None
webvtt = result["webvtt"]
assert webvtt is not None
assert "WEBVTT" in webvtt
assert "First sentence" in webvtt
finally:
await controller.remove_by_id(transcript.id)
async def test_webvtt_updated_manually_with_handle_topics_update(self):
"""Test that _handle_topics_update works when called manually."""
controller = TranscriptController()
transcript = await controller.add(
name="Test Transcript",
source_kind=SourceKind.FILE,
)
try:
topic1 = TranscriptTopic(
id="topic1",
title="Topic 1",
summary="Manual test",
timestamp=0.0,
words=[
Word(text="Manual", start=0.0, end=0.5, speaker=0),
Word(text=" test", start=0.5, end=1.0, speaker=0),
],
)
transcript.upsert_topic(topic1)
values = {"topics": transcript.topics_dump()}
await controller.update(transcript, values)
# Fetch from DB
result = await database.fetch_one(
transcripts.select().where(transcripts.c.id == transcript.id)
)
assert result is not None
webvtt = result["webvtt"]
assert webvtt is not None
assert "WEBVTT" in webvtt
assert "Manual test" in webvtt
assert "<v Speaker0>" in webvtt
finally:
await controller.remove_by_id(transcript.id)
async def test_webvtt_update_with_non_sequential_topics_fails(self):
"""Test that non-sequential topics raise assertion error."""
controller = TranscriptController()
transcript = await controller.add(
name="Test Transcript",
source_kind=SourceKind.FILE,
)
try:
topic1 = TranscriptTopic(
id="topic1",
title="Bad Topic",
summary="Bad order test",
timestamp=1.0,
words=[
Word(text="Second", start=2.0, end=2.5, speaker=0),
Word(text="First", start=1.0, end=1.5, speaker=0),
],
)
transcript.upsert_topic(topic1)
values = {"topics": transcript.topics_dump()}
with pytest.raises(AssertionError) as exc_info:
TranscriptController._handle_topics_update(values)
assert "Words are not in sequence" in str(exc_info.value)
finally:
await controller.remove_by_id(transcript.id)
async def test_multiple_speakers_in_webvtt(self):
"""Test WebVTT generation with multiple speakers."""
controller = TranscriptController()
transcript = await controller.add(
name="Test Transcript",
source_kind=SourceKind.FILE,
)
try:
topic = TranscriptTopic(
id="topic1",
title="Multi Speaker",
summary="Multi speaker test",
timestamp=0.0,
words=[
Word(text="Hello", start=0.0, end=0.5, speaker=0),
Word(text="Hi", start=1.0, end=1.5, speaker=1),
Word(text="Goodbye", start=2.0, end=2.5, speaker=0),
],
)
transcript.upsert_topic(topic)
values = {"topics": transcript.topics_dump()}
await controller.update(transcript, values)
# Fetch from DB
result = await database.fetch_one(
transcripts.select().where(transcripts.c.id == transcript.id)
)
assert result is not None
webvtt = result["webvtt"]
assert webvtt is not None
assert "<v Speaker0>" in webvtt
assert "<v Speaker1>" in webvtt
assert "Hello" in webvtt
assert "Hi" in webvtt
assert "Goodbye" in webvtt
finally:
await controller.remove_by_id(transcript.id)