mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-02-04 09:56:47 +00:00
Implements task fn-1.2: WebVTT context generation for transcript chat.
- Import `topics_to_webvtt_named` and the recordings controller
- Add a `_get_is_multitrack` helper function
- Generate WebVTT context on WebSocket connection
- Add a `get_context` message type to retrieve the WebVTT
- Maintain backward compatibility by echoing all other messages
- Add a test fixture and a test for WebVTT context generation
158 lines
5.2 KiB
Python
158 lines
5.2 KiB
Python
"""Tests for transcript chat WebSocket endpoint."""
|
|
|
|
import pytest
|
|
|
|
from reflector.db.transcripts import (
|
|
SourceKind,
|
|
TranscriptParticipant,
|
|
TranscriptTopic,
|
|
transcripts_controller,
|
|
)
|
|
from reflector.processors.types import Word
|
|
|
|
|
|
@pytest.fixture
async def test_transcript(setup_database):
    """Provide a newly created, content-free transcript for WebSocket tests.

    The ``setup_database`` fixture is requested before the transcript is
    created (presumably so the database is ready — see that fixture).
    The transcript is added with ``SourceKind.FILE`` and returned as-is.
    """
    return await transcripts_controller.add(
        name="Test Transcript for Chat",
        source_kind=SourceKind.FILE,
    )
|
|
|
|
|
|
@pytest.fixture
async def test_transcript_with_content(setup_database):
    """Provide a transcript with participants and one worded topic.

    Two participants are attached (speaker 0 is "Alice", speaker 1 is
    "Bob"), then a single "Introduction" topic is upserted whose four
    words alternate between those speakers. Intended as input for WebVTT
    generation tests.
    """
    created = await transcripts_controller.add(
        name="Test Transcript with Content", source_kind=SourceKind.FILE
    )

    # Position in this list is the speaker index; participant ids are the
    # 1-based position rendered as a string ("1", "2").
    speaker_names = ["Alice", "Bob"]
    participant_dicts = [
        TranscriptParticipant(
            id=str(position + 1), speaker=position, name=who
        ).model_dump()
        for position, who in enumerate(speaker_names)
    ]
    await transcripts_controller.update(
        created, {"participants": participant_dicts}
    )

    # (text, start, end, speaker): each word spans one second. The trailing
    # spaces are part of the word text.
    spoken = [
        ("Hello ", 0.0, 1.0, 0),
        ("everyone.", 1.0, 2.0, 0),
        ("Hi ", 2.0, 3.0, 1),
        ("there!", 3.0, 4.0, 1),
    ]
    intro_topic = TranscriptTopic(
        title="Introduction",
        summary="Opening remarks",
        timestamp=0.0,
        words=[
            Word(text=text, start=begin, end=finish, speaker=who)
            for text, begin, finish, who in spoken
        ],
    )
    await transcripts_controller.upsert_topic(created, intro_topic)

    return created
|
|
|
|
|
|
def test_chat_websocket_connection_success(test_transcript):
    """A chat message sent over the socket comes back as an echo."""
    from starlette.testclient import TestClient

    from reflector.app import app

    chat_url = f"/v1/transcripts/{test_transcript.id}/chat"

    with TestClient(app) as client, client.websocket_connect(chat_url) as ws:
        ws.send_json({"type": "message", "text": "Hello"})
        reply = ws.receive_json()

        # The reply is tagged "echo" and carries the sent payload under "data".
        assert reply["type"] == "echo"
        echoed = reply["data"]
        assert echoed["type"] == "message"
        assert echoed["text"] == "Hello"
|
|
|
|
|
|
def test_chat_websocket_nonexistent_transcript():
    """Opening the chat socket for an unknown transcript id is refused."""
    from starlette.testclient import TestClient
    from starlette.websockets import WebSocketDisconnect

    from reflector.app import app

    unknown_url = "/v1/transcripts/nonexistent-id/chat"

    with TestClient(app) as client:
        # The disconnect is expected while the socket is being opened. The
        # send is only reached if the server unexpectedly accepts.
        with pytest.raises(WebSocketDisconnect), client.websocket_connect(
            unknown_url
        ) as ws:
            ws.send_json({"type": "message", "text": "Hello"})
|
|
|
|
|
|
def test_chat_websocket_multiple_messages(test_transcript):
    """Each of several messages on one connection is echoed back in turn."""
    from starlette.testclient import TestClient

    from reflector.app import app

    chat_url = f"/v1/transcripts/{test_transcript.id}/chat"

    with TestClient(app) as client, client.websocket_connect(chat_url) as ws:
        # Strict send/receive alternation: each reply must match the message
        # that was just sent.
        for text in ("First message", "Second message", "Third message"):
            ws.send_json({"type": "message", "text": text})
            reply = ws.receive_json()
            assert reply["type"] == "echo"
            assert reply["data"]["text"] == text
|
|
|
|
|
|
def test_chat_websocket_disconnect_graceful(test_transcript):
    """Closing the socket after one round trip raises nothing.

    There is no explicit assertion. The test passes as long as leaving the
    ``with`` block, which closes the connection, does not raise.
    """
    from starlette.testclient import TestClient

    from reflector.app import app

    chat_url = f"/v1/transcripts/{test_transcript.id}/chat"

    with TestClient(app) as client, client.websocket_connect(chat_url) as ws:
        ws.send_json({"type": "message", "text": "Hello"})
        ws.receive_json()  # content of the reply is deliberately not checked
|
|
|
|
|
|
def test_chat_websocket_context_generation(test_transcript_with_content):
    """A ``get_context`` request returns the transcript rendered as WebVTT."""
    from starlette.testclient import TestClient

    from reflector.app import app

    chat_url = f"/v1/transcripts/{test_transcript_with_content.id}/chat"

    with TestClient(app) as client, client.websocket_connect(chat_url) as ws:
        ws.send_json({"type": "get_context"})
        payload = ws.receive_json()

        assert payload["type"] == "context"
        assert "webvtt" in payload

        vtt_text = payload["webvtt"]
        # The document must carry the WebVTT header, named speaker voice
        # tags, and the joined cue text from the fixture's words.
        assert vtt_text.startswith("WEBVTT")
        for fragment in ("<v Alice>", "<v Bob>", "Hello everyone.", "Hi there!"):
            assert fragment in vtt_text
|