server: add basic rtc test with local audio and fake llm

This commit is contained in:
Mathieu Virbel
2023-07-27 18:04:26 +02:00
parent ee080e1ab2
commit fe85005e8e
5 changed files with 198 additions and 40 deletions

85
server/poetry.lock generated
View File

@@ -1054,6 +1054,17 @@ files = [
{file = "ifaddr-0.2.0.tar.gz", hash = "sha256:cc0cbfcaabf765d44595825fb96a99bb12c79716b73b44330ea38ee2b0c4aed4"}, {file = "ifaddr-0.2.0.tar.gz", hash = "sha256:cc0cbfcaabf765d44595825fb96a99bb12c79716b73b44330ea38ee2b0c4aed4"},
] ]
[[package]]
name = "iniconfig"
version = "2.0.0"
description = "brain-dead simple config-ini parsing"
optional = false
python-versions = ">=3.7"
files = [
{file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"},
{file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
]
[[package]] [[package]]
name = "loguru" name = "loguru"
version = "0.7.0" version = "0.7.0"
@@ -1295,6 +1306,21 @@ files = [
docs = ["furo (>=2023.5.20)", "proselint (>=0.13)", "sphinx (>=7.0.1)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)"] docs = ["furo (>=2023.5.20)", "proselint (>=0.13)", "sphinx (>=7.0.1)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)"]
test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.3.1)", "pytest-cov (>=4.1)", "pytest-mock (>=3.10)"] test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.3.1)", "pytest-cov (>=4.1)", "pytest-mock (>=3.10)"]
[[package]]
name = "pluggy"
version = "1.2.0"
description = "plugin and hook calling mechanisms for python"
optional = false
python-versions = ">=3.7"
files = [
{file = "pluggy-1.2.0-py3-none-any.whl", hash = "sha256:c2fd55a7d7a3863cba1a013e4e2414658b1d07b6bc57b3919e0c63c9abb99849"},
{file = "pluggy-1.2.0.tar.gz", hash = "sha256:d12f0c4b579b15f5e054301bb226ee85eeeba08ffec228092f8defbaa3a4c4b3"},
]
[package.extras]
dev = ["pre-commit", "tox"]
testing = ["pytest", "pytest-benchmark"]
[[package]] [[package]]
name = "protobuf" name = "protobuf"
version = "4.23.4" version = "4.23.4"
@@ -1606,6 +1632,63 @@ files = [
{file = "pyreadline3-3.4.1.tar.gz", hash = "sha256:6f3d1f7b8a31ba32b73917cefc1f28cc660562f39aea8646d30bd6eff21f7bae"}, {file = "pyreadline3-3.4.1.tar.gz", hash = "sha256:6f3d1f7b8a31ba32b73917cefc1f28cc660562f39aea8646d30bd6eff21f7bae"},
] ]
[[package]]
name = "pytest"
version = "7.4.0"
description = "pytest: simple powerful testing with Python"
optional = false
python-versions = ">=3.7"
files = [
{file = "pytest-7.4.0-py3-none-any.whl", hash = "sha256:78bf16451a2eb8c7a2ea98e32dc119fd2aa758f1d5d66dbf0a59d69a3969df32"},
{file = "pytest-7.4.0.tar.gz", hash = "sha256:b4bf8c45bd59934ed84001ad51e11b4ee40d40a1229d2c79f9c592b0a3f6bd8a"},
]
[package.dependencies]
colorama = {version = "*", markers = "sys_platform == \"win32\""}
iniconfig = "*"
packaging = "*"
pluggy = ">=0.12,<2.0"
[package.extras]
testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
[[package]]
name = "pytest-aiohttp"
version = "1.0.4"
description = "Pytest plugin for aiohttp support"
optional = false
python-versions = ">=3.7"
files = [
{file = "pytest-aiohttp-1.0.4.tar.gz", hash = "sha256:39ff3a0d15484c01d1436cbedad575c6eafbf0f57cdf76fb94994c97b5b8c5a4"},
{file = "pytest_aiohttp-1.0.4-py3-none-any.whl", hash = "sha256:1d2dc3a304c2be1fd496c0c2fb6b31ab60cd9fc33984f761f951f8ea1eb4ca95"},
]
[package.dependencies]
aiohttp = ">=3.8.1"
pytest = ">=6.1.0"
pytest-asyncio = ">=0.17.2"
[package.extras]
testing = ["coverage (==6.2)", "mypy (==0.931)"]
[[package]]
name = "pytest-asyncio"
version = "0.21.1"
description = "Pytest support for asyncio"
optional = false
python-versions = ">=3.7"
files = [
{file = "pytest-asyncio-0.21.1.tar.gz", hash = "sha256:40a7eae6dded22c7b604986855ea48400ab15b069ae38116e8c01238e9eeb64d"},
{file = "pytest_asyncio-0.21.1-py3-none-any.whl", hash = "sha256:8666c1c8ac02631d7c51ba282e0c69a8a452b211ffedf2599099845da5c5c37b"},
]
[package.dependencies]
pytest = ">=7.0.0"
[package.extras]
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"]
[[package]] [[package]]
name = "python-dotenv" name = "python-dotenv"
version = "1.0.0" version = "1.0.0"
@@ -1983,4 +2066,4 @@ multidict = ">=4.0"
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.11" python-versions = "^3.11"
content-hash = "32981f838c232fdf2274aadbc933ef107c820d053bc9c2ceec563b2a22c1ea4c" content-hash = "d2b64390d1ea9038b6703b12060cdde1970b680a0ad891f24405323ff2ca0a60"

View File

@@ -29,6 +29,12 @@ httpx = "^0.24.1"
pyaudio = "^0.2.13" pyaudio = "^0.2.13"
stamina = "^23.1.0" stamina = "^23.1.0"
[tool.poetry.group.tests.dependencies]
pytest-aiohttp = "^1.0.4"
pytest-asyncio = "^0.21.1"
pytest = "^7.4.0"
[build-system] [build-system]
requires = ["poetry-core"] requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"

View File

@@ -348,15 +348,10 @@ async def on_shutdown(application: web.Application) -> NoReturn:
pcs.clear() pcs.clear()
if __name__ == "__main__": def create_app() -> web.Application:
parser = argparse.ArgumentParser(description="WebRTC based server for Reflector") """
parser.add_argument( Create the web application
"--host", default="0.0.0.0", help="Server host IP (def: 0.0.0.0)" """
)
parser.add_argument(
"--port", type=int, default=1250, help="Server port (def: 1250)"
)
args = parser.parse_args()
app = web.Application() app = web.Application()
cors = aiohttp_cors.setup( cors = aiohttp_cors.setup(
app, app,
@@ -370,4 +365,17 @@ if __name__ == "__main__":
offer_resource = cors.add(app.router.add_resource("/offer")) offer_resource = cors.add(app.router.add_resource("/offer"))
cors.add(offer_resource.add_route("POST", offer)) cors.add(offer_resource.add_route("POST", offer))
app.on_shutdown.append(on_shutdown) app.on_shutdown.append(on_shutdown)
return app
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="WebRTC based server for Reflector")
parser.add_argument(
"--host", default="0.0.0.0", help="Server host IP (def: 0.0.0.0)"
)
parser.add_argument(
"--port", type=int, default=1250, help="Server port (def: 1250)"
)
args = parser.parse_args()
app = create_app()
web.run_app(app, access_log=None, host=args.host, port=args.port) web.run_app(app, access_log=None, host=args.host, port=args.port)

View File

@@ -4,7 +4,6 @@ import uuid
import httpx import httpx
import pyaudio import pyaudio
import requests
import stamina import stamina
from aiortc import RTCPeerConnection, RTCSessionDescription from aiortc import RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import MediaPlayer, MediaRelay from aiortc.contrib.media import MediaPlayer, MediaRelay
@@ -15,7 +14,7 @@ from reflector.settings import settings
class StreamClient: class StreamClient:
def __init__( def __init__(
self, signaling, url="http://0.0.0.0:1250", play_from=None, ping_pong=False self, signaling, url="http://0.0.0.0:1250/offer", play_from=None, ping_pong=False
): ):
self.signaling = signaling self.signaling = signaling
self.server_url = url self.server_url = url
@@ -25,21 +24,15 @@ class StreamClient:
self.pc = RTCPeerConnection() self.pc = RTCPeerConnection()
self.loop = asyncio.get_event_loop()
self.relay = None self.relay = None
self.pcs = set() self.pcs = set()
self.time_start = None self.time_start = None
self.queue = asyncio.Queue() self.queue = asyncio.Queue()
self.player = MediaPlayer( self.logger = logger.bind(stream_client=id(self))
f":{settings.AUDIO_AV_FOUNDATION_DEVICE_ID}",
format="avfoundation",
options={"channels": "2"},
)
def stop(self): async def stop(self):
self.loop.run_until_complete(self.signaling.close()) await self.signaling.close()
self.loop.run_until_complete(self.pc.close()) await self.pc.close()
# self.loop.close()
def create_local_tracks(self, play_from): def create_local_tracks(self, play_from):
if play_from: if play_from:
@@ -48,11 +41,13 @@ class StreamClient:
else: else:
if self.relay is None: if self.relay is None:
self.relay = MediaRelay() self.relay = MediaRelay()
self.player = MediaPlayer(
f":{settings.AUDIO_AV_FOUNDATION_DEVICE_ID}",
format="avfoundation",
options={"channels": "2"},
)
return self.relay.subscribe(self.player.audio), None return self.relay.subscribe(self.player.audio), None
def channel_log(self, channel, t, message):
print("channel(%s) %s %s" % (channel.label, t, message))
def channel_send(self, channel, message): def channel_send(self, channel, message):
# self.channel_log(channel, ">", message) # self.channel_log(channel, ">", message)
channel.send(message) channel.send(message)
@@ -67,32 +62,31 @@ class StreamClient:
async def run_offer(self, pc, signaling): async def run_offer(self, pc, signaling):
# microphone # microphone
audio, video = self.create_local_tracks(self.play_from) audio, video = self.create_local_tracks(self.play_from)
pc_id = "PeerConnection(%s)" % uuid.uuid4() pc_id = uuid.uuid4().hex
self.pcs.add(pc) self.pcs.add(pc)
self.logger = self.logger.bind(pc_id=pc_id)
def log_info(msg, *args):
logger.info(pc_id + " " + msg, *args)
@pc.on("connectionstatechange") @pc.on("connectionstatechange")
async def on_connectionstatechange(): async def on_connectionstatechange():
print("Connection state is %s" % pc.connectionState) self.logger.info(f"Connection state is {pc.connectionState}")
if pc.connectionState == "failed": if pc.connectionState == "failed":
await pc.close() await pc.close()
self.pcs.discard(pc) self.pcs.discard(pc)
@pc.on("track") @pc.on("track")
def on_track(track): def on_track(track):
print("Sending %s" % track.kind) self.logger.info(f"Sending {track.kind}")
self.pc.addTrack(track) self.pc.addTrack(track)
@track.on("ended") @track.on("ended")
async def on_ended(): async def on_ended():
log_info("Track %s ended", track.kind) self.logger.info(f"Track {track.kind} ended")
self.pc.addTrack(audio) self.pc.addTrack(audio)
channel = pc.createDataChannel("data-channel") channel = pc.createDataChannel("data-channel")
self.channel_log(channel, "-", "created by local party") self.logger = self.logger.bind(channel=channel.label)
self.logger.info("Created by local party")
async def send_pings(): async def send_pings():
while True: while True:
@@ -108,23 +102,24 @@ class StreamClient:
def on_message(message): def on_message(message):
self.queue.put_nowait(message) self.queue.put_nowait(message)
if self.ping_pong: if self.ping_pong:
self.channel_log(channel, "<", message) self.logger.info(f"Message: {message}")
if isinstance(message, str) and message.startswith("pong"): if isinstance(message, str) and message.startswith("pong"):
elapsed_ms = (self.current_stamp() - int(message[5:])) / 1000 elapsed_ms = (self.current_stamp() - int(message[5:])) / 1000
print(" RTT %.2f ms" % elapsed_ms) self.logger.debug("RTT %.2f ms" % elapsed_ms)
await pc.setLocalDescription(await pc.createOffer()) await pc.setLocalDescription(await pc.createOffer())
sdp = {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type} sdp = {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
@stamina.retry(on=httpx.HTTPError, attempts=5) @stamina.retry(on=httpx.HTTPError, attempts=5)
def connect_to_server(): async def connect_to_server():
response = requests.post(self.server_url, json=sdp, timeout=10) async with httpx.AsyncClient() as client:
response = await client.post(self.server_url, json=sdp, timeout=10)
response.raise_for_status() response.raise_for_status()
return response return response.json()
params = connect_to_server().json() params = await connect_to_server()
answer = RTCSessionDescription(sdp=params["sdp"], type=params["type"]) answer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
await pc.setRemoteDescription(answer) await pc.setRemoteDescription(answer)

View File

@@ -0,0 +1,66 @@
import pytest
from unittest.mock import patch
@pytest.mark.asyncio
async def test_basic_rtc_server(aiohttp_server, event_loop):
    """Stream a local recording to an in-process server and check its events.

    The application returned by ``create_app()`` is served through the
    ``aiohttp_server`` fixture, a ``StreamClient`` plays a WAV file to its
    ``/offer`` endpoint, and the messages yielded by ``client.get_reader()``
    are validated until one ``SHOW_TRANSCRIPTION`` and one ``UPDATE_TOPICS``
    message have been seen.

    ``reflector.server.get_title_and_summary`` is patched so that topic
    generation returns a fixed ``TitleSummaryOutput`` instead of calling a
    real LLM.
    """
    # NOTE(review): ``event_loop`` is not used directly in the body; it is
    # presumably requested so the asyncio fixtures share one loop -- confirm.
    import argparse
    import json
    from pathlib import Path
    from reflector.server import create_app
    from reflector.stream_client import StreamClient
    from reflector.models import TitleSummaryOutput
    from aiortc.contrib.signaling import add_signaling_arguments, create_signaling

    # customize settings to have a mock LLM server
    with patch("reflector.server.get_title_and_summary") as mock_llm:
        # any response from mock_llm will be test topic
        mock_llm.return_value = TitleSummaryOutput(["topic_test"])

        # create the server
        app = create_app()
        server = await aiohttp_server(app)
        url = f"http://{server.host}:{server.port}/offer"

        # create signaling
        parser = argparse.ArgumentParser()
        add_signaling_arguments(parser)
        args = parser.parse_args(["-s", "tcp-socket"])
        signaling = create_signaling(args)

        # create the client
        path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
        client = StreamClient(signaling, url=url, play_from=path.as_posix())

        try:
            await client.start()

            # we just want the first transcription
            # and topic update messages
            marks = {
                "SHOW_TRANSCRIPTION": False,
                "UPDATE_TOPICS": False,
            }

            async for rawmsg in client.get_reader():
                msg = json.loads(rawmsg)
                cmd = msg["cmd"]
                if cmd == "SHOW_TRANSCRIPTION":
                    assert "text" in msg
                    assert "want to share my incredible experience" in msg["text"]
                elif cmd == "UPDATE_TOPICS":
                    assert "topics" in msg
                    assert "topic_test" in msg["topics"]
                marks[cmd] = True

                # break if we have all the events we need
                if all(marks.values()):
                    break
        finally:
            # Always release the server and the client, even when an
            # assertion above fails, so a failing run does not leave the
            # peer connection and signaling open.
            await server.close()
            await client.stop()