reflector/server/tests/test_transcripts_rtc_ws.py
Mathieu Virbel 5267ab2d37 feat: retake summary using NousResearch/Hermes-3-Llama-3.1-8B model (#415)
This feature adds a new modal endpoint and a completely new way to build the
summary.

## SummaryBuilder

The summary builder is based on a conversational model, where an exchange
takes place between the model and the user. This allows more context to be
included and better adherence to the rules.

It requires an endpoint exposing an OpenAI-like chat completions API
(/v1/chat/completions).
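
For illustration, a minimal sketch of such an exchange against an OpenAI-like
/v1/chat/completions endpoint. The base URL, prompts, and conversation flow
below are placeholders, not the actual SummaryBuilder code:

```python
import httpx

# Hypothetical endpoint; the real deployment URL may differ.
BASE_URL = "http://localhost:8000"
MODEL = "NousResearch/Hermes-3-Llama-3.1-8B"


def chat(messages: list[dict]) -> str:
    """Send one turn to an OpenAI-like chat completions endpoint."""
    resp = httpx.post(
        f"{BASE_URL}/v1/chat/completions",
        json={"model": MODEL, "messages": messages, "temperature": 0.2},
        timeout=120,
    )
    resp.raise_for_status()
    return resp.json()["choices"][0]["message"]["content"]


# Multi-turn exchange: every answer is appended back to the history,
# so later questions keep the full context of the conversation.
messages = [
    {"role": "system", "content": "You summarize meeting transcripts."},
    {"role": "user", "content": "Here is the transcript:\n..."},
]
messages.append({"role": "assistant", "content": chat(messages)})
messages.append({"role": "user", "content": "Now give a quick recap."})
recap = chat(messages)
```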

## vLLM Hermes3

Unlike the previous deployment, this one uses vLLM, which provides an
OpenAI-like completions endpoint out of the box. It can also handle guided
JSON generation, so jsonformer is not needed. That said, the model is quite
good at following a JSON schema when asked to do so in the prompt.
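
As a hedged example of the prompt-based approach (the schema and wording below
are made up for illustration; the real prompts live in the summary builder):

```python
import json

# Hypothetical schema; the real SummaryBuilder uses its own structures.
schema = {
    "type": "object",
    "properties": {
        "participants": {"type": "array", "items": {"type": "string"}},
        "subjects": {"type": "array", "items": {"type": "string"}},
    },
    "required": ["participants", "subjects"],
}

prompt = (
    "Identify the participants and the key subjects of the transcript. "
    "Answer with JSON only, matching this schema:\n" + json.dumps(schema, indent=2)
)

# `chat` is the helper sketched above; json.loads fails fast if the model
# drifted away from pure JSON output.
data = json.loads(chat([{"role": "user", "content": prompt}]))
```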

## Conversion of long/short into summary builder

The builder identifies the participants, finds the key subjects, gets a
summary for each, then gets a quick recap.

The quick recap is used as the short_summary, while the markdown including
the quick recap + key subjects + summaries is used as the long_summary.
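
One plausible way to assemble both fields (the heading layout is an assumption
based on this description, not the exact output format):

```python
def build_summaries(recap: str, subjects: dict[str, str]) -> tuple[str, str]:
    # The quick recap alone becomes the short summary.
    short_summary = recap
    # The long summary is markdown: recap first, then one section per key subject.
    parts = [f"# Quick recap\n\n{recap}", "# Key subjects"]
    for subject, summary in subjects.items():
        parts.append(f"## {subject}\n\n{summary}")
    return short_summary, "\n\n".join(parts)


short_summary, long_summary = build_summaries(
    "The team agreed on the release plan.",
    {"Release planning": "Dates were reviewed and Friday was chosen."},
)
```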

This is why the Next.js component has to be updated, to correctly style
h1 and keep the newlines of the markdown.
2024-09-14 02:28:38 +02:00

362 lines
12 KiB
Python

# === further tests
# FIXME test status of transcript
# FIXME test websocket connection after RTC is finished still send the full events
# FIXME try with locked session, RTC should not work
import asyncio
import json
import threading
from pathlib import Path

import pytest
from httpx import AsyncClient
from httpx_ws import aconnect_ws
from uvicorn import Config, Server


class ThreadedUvicorn:
    def __init__(self, config: Config):
        self.server = Server(config)
        self.thread = threading.Thread(daemon=True, target=self.server.run)

    async def start(self):
        self.thread.start()
        while not self.server.started:
            await asyncio.sleep(0.1)

    def stop(self):
        if self.thread.is_alive():
            self.server.should_exit = True
            while self.thread.is_alive():
                continue


@pytest.fixture
async def appserver(tmpdir, setup_database, celery_session_app, celery_session_worker):
    from reflector.settings import settings
    from reflector.app import app

    DATA_DIR = settings.DATA_DIR
    settings.DATA_DIR = Path(tmpdir)

    # start server
    host = "127.0.0.1"
    port = 1255
    config = Config(app=app, host=host, port=port)
    server = ThreadedUvicorn(config)
    await server.start()
    yield (server, host, port)
    server.stop()
    settings.DATA_DIR = DATA_DIR


@pytest.fixture(scope="session")
def celery_includes():
    return ["reflector.pipelines.main_live_pipeline"]

@pytest.mark.usefixtures("setup_database")
@pytest.mark.usefixtures("celery_session_app")
@pytest.mark.usefixtures("celery_session_worker")
@pytest.mark.asyncio
async def test_transcript_rtc_and_websocket(
    tmpdir,
    dummy_llm,
    dummy_transcript,
    dummy_processors,
    dummy_diarization,
    dummy_storage,
    fake_mp3_upload,
    ensure_casing,
    nltk,
    appserver,
):
    # goal: start the server, exchange RTC, receive websocket events
    # because of that, we need to start the server in a thread
    # to be able to connect with aiortc
    server, host, port = appserver

    # create a transcript
    base_url = f"http://{host}:{port}/v1"
    ac = AsyncClient(base_url=base_url)
    response = await ac.post("/transcripts", json={"name": "Test RTC"})
    assert response.status_code == 200
    tid = response.json()["id"]

    # create a websocket connection as a task
    events = []

    async def websocket_task():
        print("Test websocket: TASK STARTED")
        async with aconnect_ws(f"{base_url}/transcripts/{tid}/events") as ws:
            print("Test websocket: CONNECTED")
            try:
                while True:
                    msg = await ws.receive_json()
                    print(f"Test websocket: JSON {msg}")
                    if msg is None:
                        break
                    events.append(msg)
            except Exception as e:
                print(f"Test websocket: EXCEPTION {e}")
            finally:
                ws.close()
                print("Test websocket: DISCONNECTED")

    websocket_task = asyncio.get_event_loop().create_task(websocket_task())
    print("Test websocket: TASK CREATED", websocket_task)

    # create stream client
    import argparse

    from reflector.stream_client import StreamClient
    from aiortc.contrib.signaling import add_signaling_arguments, create_signaling

    parser = argparse.ArgumentParser()
    add_signaling_arguments(parser)
    args = parser.parse_args(["-s", "tcp-socket"])
    signaling = create_signaling(args)

    url = f"{base_url}/transcripts/{tid}/record/webrtc"
    path = Path(__file__).parent / "records" / "test_short.wav"
    client = StreamClient(signaling, url=url, play_from=path.as_posix())
    await client.start()

    timeout = 20
    while not client.is_ended():
        await asyncio.sleep(1)
        timeout -= 1
        if timeout < 0:
            raise TimeoutError("Timeout while waiting for RTC to end")

    # XXX aiortc is long to close the connection
    # instead of waiting a long time, we just send a STOP
    client.channel.send(json.dumps({"cmd": "STOP"}))
    await client.stop()
    # wait for the processing to finish
    timeout = 20
    while True:
        # fetch the transcript and check if it is ended
        resp = await ac.get(f"/transcripts/{tid}")
        assert resp.status_code == 200
        if resp.json()["status"] in ("ended", "error"):
            break
        await asyncio.sleep(1)
        # decrement the timeout so the loop cannot poll forever
        timeout -= 1
        if timeout < 0:
            break
    if resp.json()["status"] != "ended":
        raise TimeoutError("Timeout while waiting for transcript to be ended")
    # stop websocket task
    websocket_task.cancel()

    # check events
    assert len(events) > 0
    from pprint import pprint

    pprint(events)

    # get events list
    eventnames = [e["event"] for e in events]

    # check events
    assert "TRANSCRIPT" in eventnames
    ev = events[eventnames.index("TRANSCRIPT")]
    assert ev["data"]["text"].startswith("Hello world.")
    assert ev["data"]["translation"] == "Bonjour le monde"

    assert "TOPIC" in eventnames
    ev = events[eventnames.index("TOPIC")]
    assert ev["data"]["id"]
    assert ev["data"]["summary"] == "LLM SUMMARY"
    assert ev["data"]["transcript"].startswith("Hello world.")
    assert ev["data"]["timestamp"] == 0.0

    assert "FINAL_LONG_SUMMARY" in eventnames
    ev = events[eventnames.index("FINAL_LONG_SUMMARY")]
    assert ev["data"]["long_summary"] == "LLM LONG SUMMARY"

    assert "FINAL_SHORT_SUMMARY" in eventnames
    ev = events[eventnames.index("FINAL_SHORT_SUMMARY")]
    assert ev["data"]["short_summary"] == "LLM SHORT SUMMARY"

    assert "FINAL_TITLE" in eventnames
    ev = events[eventnames.index("FINAL_TITLE")]
    assert ev["data"]["title"] == "LLM TITLE"

    assert "WAVEFORM" in eventnames
    ev = events[eventnames.index("WAVEFORM")]
    assert isinstance(ev["data"]["waveform"], list)
    assert len(ev["data"]["waveform"]) >= 250

    waveform_resp = await ac.get(f"/transcripts/{tid}/audio/waveform")
    assert waveform_resp.status_code == 200
    assert waveform_resp.headers["content-type"] == "application/json"
    assert isinstance(waveform_resp.json()["data"], list)
    assert len(waveform_resp.json()["data"]) >= 250

    # check status order
    statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
    assert statuses.index("recording") < statuses.index("processing")
    assert statuses.index("processing") < statuses.index("ended")

    # ensure the last event received is ended
    assert events[-1]["event"] == "STATUS"
    assert events[-1]["data"]["value"] == "ended"

    # check on the latest response that the audio duration is > 0
    assert resp.json()["duration"] > 0
    assert "DURATION" in eventnames

    # check that audio/mp3 is available
    audio_resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
    assert audio_resp.status_code == 200
    assert audio_resp.headers["Content-Type"] == "audio/mpeg"

@pytest.mark.usefixtures("setup_database")
@pytest.mark.usefixtures("celery_session_app")
@pytest.mark.usefixtures("celery_session_worker")
@pytest.mark.asyncio
async def test_transcript_rtc_and_websocket_and_fr(
    tmpdir,
    dummy_llm,
    dummy_transcript,
    dummy_processors,
    dummy_diarization,
    dummy_storage,
    fake_mp3_upload,
    ensure_casing,
    nltk,
    appserver,
):
    # goal: start the server, exchange RTC, receive websocket events
    # because of that, we need to start the server in a thread
    # to be able to connect with aiortc
    # with target french language
    server, host, port = appserver

    # create a transcript
    base_url = f"http://{host}:{port}/v1"
    ac = AsyncClient(base_url=base_url)
    response = await ac.post(
        "/transcripts", json={"name": "Test RTC", "target_language": "fr"}
    )
    assert response.status_code == 200
    tid = response.json()["id"]

    # create a websocket connection as a task
    events = []

    async def websocket_task():
        print("Test websocket: TASK STARTED")
        async with aconnect_ws(f"{base_url}/transcripts/{tid}/events") as ws:
            print("Test websocket: CONNECTED")
            try:
                while True:
                    msg = await ws.receive_json()
                    print(f"Test websocket: JSON {msg}")
                    if msg is None:
                        break
                    events.append(msg)
            except Exception as e:
                print(f"Test websocket: EXCEPTION {e}")
            finally:
                ws.close()
                print("Test websocket: DISCONNECTED")

    websocket_task = asyncio.get_event_loop().create_task(websocket_task())
    print("Test websocket: TASK CREATED", websocket_task)

    # create stream client
    import argparse

    from reflector.stream_client import StreamClient
    from aiortc.contrib.signaling import add_signaling_arguments, create_signaling

    parser = argparse.ArgumentParser()
    add_signaling_arguments(parser)
    args = parser.parse_args(["-s", "tcp-socket"])
    signaling = create_signaling(args)

    url = f"{base_url}/transcripts/{tid}/record/webrtc"
    path = Path(__file__).parent / "records" / "test_short.wav"
    client = StreamClient(signaling, url=url, play_from=path.as_posix())
    await client.start()

    timeout = 20
    while not client.is_ended():
        await asyncio.sleep(1)
        timeout -= 1
        if timeout < 0:
            raise TimeoutError("Timeout while waiting for RTC to end")

    # XXX aiortc is long to close the connection
    # instead of waiting a long time, we just send a STOP
    client.channel.send(json.dumps({"cmd": "STOP"}))
    # wait the processing to finish
    await asyncio.sleep(2)
    await client.stop()
    # wait for the processing to finish
    timeout = 20
    while True:
        # fetch the transcript and check if it is ended
        resp = await ac.get(f"/transcripts/{tid}")
        assert resp.status_code == 200
        if resp.json()["status"] == "ended":
            break
        await asyncio.sleep(1)
        # decrement the timeout so the loop cannot poll forever
        timeout -= 1
        if timeout < 0:
            break
    if resp.json()["status"] != "ended":
        raise TimeoutError("Timeout while waiting for transcript to be ended")
    await asyncio.sleep(2)

    # stop websocket task
    websocket_task.cancel()

    # check events
    assert len(events) > 0
    from pprint import pprint

    pprint(events)

    # get events list
    eventnames = [e["event"] for e in events]

    # check events
    assert "TRANSCRIPT" in eventnames
    ev = events[eventnames.index("TRANSCRIPT")]
    assert ev["data"]["text"].startswith("Hello world.")
    assert ev["data"]["translation"] == "Bonjour le monde"

    assert "TOPIC" in eventnames
    ev = events[eventnames.index("TOPIC")]
    assert ev["data"]["id"]
    assert ev["data"]["summary"] == "LLM SUMMARY"
    assert ev["data"]["transcript"].startswith("Hello world.")
    assert ev["data"]["timestamp"] == 0.0

    assert "FINAL_LONG_SUMMARY" in eventnames
    ev = events[eventnames.index("FINAL_LONG_SUMMARY")]
    assert ev["data"]["long_summary"] == "LLM LONG SUMMARY"

    assert "FINAL_SHORT_SUMMARY" in eventnames
    ev = events[eventnames.index("FINAL_SHORT_SUMMARY")]
    assert ev["data"]["short_summary"] == "LLM SHORT SUMMARY"

    assert "FINAL_TITLE" in eventnames
    ev = events[eventnames.index("FINAL_TITLE")]
    assert ev["data"]["title"] == "LLM TITLE"

    # check status order
    statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
    assert statuses.index("recording") < statuses.index("processing")
    assert statuses.index("processing") < statuses.index("ended")

    # ensure the last event received is ended
    assert events[-1]["event"] == "STATUS"
    assert events[-1]["data"]["value"] == "ended"