Files
reflector/server/tests/test_transcripts_rtc_ws.py

167 lines
5.1 KiB
Python

# === further tests
# FIXME test status of transcript
# FIXME test that a websocket connection made after RTC is finished is still sent the full events
# FIXME try with locked session, RTC should not work
import pytest
import json
from unittest.mock import patch
from httpx import AsyncClient
from reflector.app import app
from uvicorn import Config, Server
import threading
import asyncio
from pathlib import Path
from httpx_ws import aconnect_ws
class ThreadedUvicorn:
    """Run a uvicorn ``Server`` in a background daemon thread.

    The server is driven by ``Server.run`` on its own thread, so the calling
    test keeps its own event loop free while talking to the server over
    real sockets.
    """

    def __init__(self, config: "Config"):
        """Create the server and its (not yet started) daemon thread.

        Args:
            config: uvicorn configuration used to build the ``Server``.
        """
        self.server = Server(config)
        self.thread = threading.Thread(daemon=True, target=self.server.run)

    async def start(self):
        """Start the server thread and wait until the server reports started.

        Polls ``self.server.started`` every 0.1 second.

        Raises:
            RuntimeError: if the server thread ends before ``started`` is set,
                which would otherwise make this coroutine poll forever.
        """
        self.thread.start()
        while not self.server.started:
            if not self.thread.is_alive():
                raise RuntimeError(
                    "Server thread exited before the server reported started"
                )
            await asyncio.sleep(0.1)

    def stop(self):
        """Ask the server to exit and block until its thread has finished.

        Does nothing when the thread is not running.
        """
        if self.thread.is_alive():
            self.server.should_exit = True
            # join() blocks without spinning the CPU while the thread winds down
            self.thread.join()
@pytest.fixture
async def dummy_transcript():
    """Replace the automatic audio transcript processor with a canned one.

    While the fixture is active,
    ``AudioTranscriptAutoProcessor.get_instance`` is patched to return a
    processor whose ``_transcript`` always yields the text "Hello world"
    made of two one-second words, whatever audio it is given.
    """
    from reflector.processors.audio_transcript import AudioTranscriptProcessor
    from reflector.processors.types import AudioFile, Transcript, Word

    class TestAudioTranscriptProcessor(AudioTranscriptProcessor):
        async def _transcript(self, data: AudioFile):
            # Fixed result: the input audio is ignored on purpose
            hello = Word(start=0.0, end=1.0, text="Hello")
            world = Word(start=1.0, end=2.0, text="world")
            return Transcript(text="Hello world", words=[hello, world])

    target = (
        "reflector.processors.audio_transcript_auto"
        ".AudioTranscriptAutoProcessor.get_instance"
    )
    with patch(target) as get_instance:
        get_instance.return_value = TestAudioTranscriptProcessor()
        yield
@pytest.fixture
async def dummy_llm():
    """Replace the LLM with one that returns a fixed title and summary.

    While the fixture is active, ``LLM.get_instance`` is patched to return
    an LLM whose ``_generate`` ignores the prompt and returns a JSON string
    with the keys ``title`` and ``summary``.
    """
    from reflector.llm.base import LLM

    class TestLLM(LLM):
        async def _generate(self, prompt: str, **kwargs):
            # Same canned answer for every prompt
            payload = {"title": "LLM TITLE", "summary": "LLM SUMMARY"}
            return json.dumps(payload)

    with patch("reflector.llm.base.LLM.get_instance") as get_instance:
        get_instance.return_value = TestLLM()
        yield
@pytest.mark.asyncio
async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm):
    """Stream a recording over WebRTC and check the websocket events.

    The test starts the app in a background uvicorn thread (aiortc needs a
    real server to connect to), creates a transcript, listens to its
    websocket event stream in a background task, plays ``test_short.wav``
    through a ``StreamClient`` and finally asserts on the collected events.

    Raises:
        TimeoutError: if the stream client has not ended after about
            20 seconds.
    """
    # goal: start the server, exchange RTC, receive websocket events
    # because of that, we need to start the server in a thread
    # to be able to connect with aiortc

    # start server
    host = "127.0.0.1"
    port = 1255
    base_url = f"http://{host}:{port}/v1"
    config = Config(app=app, host=host, port=port)
    server = ThreadedUvicorn(config)
    await server.start()

    # create a transcript; the client is only needed for this one request,
    # so it is closed as soon as the response has been read
    async with AsyncClient(base_url=base_url) as ac:
        response = await ac.post("/transcripts", json={"name": "Test RTC"})
        assert response.status_code == 200
        tid = response.json()["id"]

    # create a websocket connection as a task
    events = []

    async def collect_events():
        print("Test websocket: TASK STARTED")
        async with aconnect_ws(f"{base_url}/transcripts/{tid}/events") as ws:
            print("Test websocket: CONNECTED")
            try:
                while True:
                    msg = await ws.receive_json()
                    print(f"Test websocket: JSON {msg}")
                    if msg is None:
                        break
                    events.append(msg)
            except Exception as e:
                print(f"Test websocket: EXCEPTION {e}")
            finally:
                # No explicit ws.close() here: it is a coroutine, and the
                # enclosing `async with` closes the session on exit.
                print("Test websocket: DISCONNECTED")

    websocket_task = asyncio.create_task(collect_events())

    # create stream client
    import argparse
    from reflector.stream_client import StreamClient
    from aiortc.contrib.signaling import add_signaling_arguments, create_signaling

    parser = argparse.ArgumentParser()
    add_signaling_arguments(parser)
    args = parser.parse_args(["-s", "tcp-socket"])
    signaling = create_signaling(args)

    url = f"{base_url}/transcripts/{tid}/record/webrtc"
    path = Path(__file__).parent / "records" / "test_short.wav"
    client = StreamClient(signaling, url=url, play_from=path.as_posix())
    await client.start()

    timeout = 20
    while not client.is_ended():
        await asyncio.sleep(1)
        timeout -= 1
        if timeout < 0:
            raise TimeoutError("Timeout while waiting for RTC to end")

    # XXX aiortc is long to close the connection
    # instead of waiting a long time, we just send a STOP
    client.channel.send(json.dumps({"cmd": "STOP"}))

    # wait the processing to finish
    await asyncio.sleep(2)

    await client.stop()

    # wait the processing to finish
    await asyncio.sleep(2)

    # stop websocket task, and give it a bounded chance to unwind so the
    # task is not still pending when the event loop is torn down
    websocket_task.cancel()
    await asyncio.wait({websocket_task}, timeout=5)

    # check events: the assertions below look at the first event and at the
    # last two, which must be three distinct events
    assert len(events) >= 3, f"expected at least 3 events, got {len(events)}"

    assert events[0]["event"] == "TRANSCRIPT"
    assert events[0]["data"]["text"] == "Hello world"

    assert events[-2]["event"] == "TOPIC"
    assert events[-2]["data"]["id"]
    assert events[-2]["data"]["summary"] == "LLM SUMMARY"
    assert events[-2]["data"]["transcript"].startswith("Hello world")
    assert events[-2]["data"]["timestamp"] == 0.0

    assert events[-1]["event"] == "FINAL_SUMMARY"
    assert events[-1]["data"]["summary"] == "LLM SUMMARY"

    # stop server
    # NOTE(review): stopping is disabled, so the daemon server thread keeps
    # running after the test; confirm whether this can be re-enabled.
    # server.stop()