From fe85005e8e15700c716cae8b17456c530c2128ae Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 27 Jul 2023 18:04:26 +0200 Subject: [PATCH] server: add basic rtc test with local audio and fake llm --- server/poetry.lock | 85 ++++++++++++++++++++++++++++++- server/pyproject.toml | 6 +++ server/reflector/server.py | 26 ++++++---- server/reflector/stream_client.py | 55 +++++++++----------- server/tests/test_basic_rtc.py | 66 ++++++++++++++++++++++++ 5 files changed, 198 insertions(+), 40 deletions(-) create mode 100644 server/tests/test_basic_rtc.py diff --git a/server/poetry.lock b/server/poetry.lock index d9157b1c..1d871c03 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -1054,6 +1054,17 @@ files = [ {file = "ifaddr-0.2.0.tar.gz", hash = "sha256:cc0cbfcaabf765d44595825fb96a99bb12c79716b73b44330ea38ee2b0c4aed4"}, ] +[[package]] +name = "iniconfig" +version = "2.0.0" +description = "brain-dead simple config-ini parsing" +optional = false +python-versions = ">=3.7" +files = [ + {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, + {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, +] + [[package]] name = "loguru" version = "0.7.0" @@ -1295,6 +1306,21 @@ files = [ docs = ["furo (>=2023.5.20)", "proselint (>=0.13)", "sphinx (>=7.0.1)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)"] test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.3.1)", "pytest-cov (>=4.1)", "pytest-mock (>=3.10)"] +[[package]] +name = "pluggy" +version = "1.2.0" +description = "plugin and hook calling mechanisms for python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pluggy-1.2.0-py3-none-any.whl", hash = "sha256:c2fd55a7d7a3863cba1a013e4e2414658b1d07b6bc57b3919e0c63c9abb99849"}, + {file = "pluggy-1.2.0.tar.gz", hash = "sha256:d12f0c4b579b15f5e054301bb226ee85eeeba08ffec228092f8defbaa3a4c4b3"}, +] + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["pytest", "pytest-benchmark"] + [[package]] name = "protobuf" version = "4.23.4" @@ -1606,6 +1632,63 @@ files = [ {file = "pyreadline3-3.4.1.tar.gz", hash = "sha256:6f3d1f7b8a31ba32b73917cefc1f28cc660562f39aea8646d30bd6eff21f7bae"}, ] +[[package]] +name = "pytest" +version = "7.4.0" +description = "pytest: simple powerful testing with Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-7.4.0-py3-none-any.whl", hash = "sha256:78bf16451a2eb8c7a2ea98e32dc119fd2aa758f1d5d66dbf0a59d69a3969df32"}, + {file = "pytest-7.4.0.tar.gz", hash = "sha256:b4bf8c45bd59934ed84001ad51e11b4ee40d40a1229d2c79f9c592b0a3f6bd8a"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=0.12,<2.0" + +[package.extras] +testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] + +[[package]] +name = "pytest-aiohttp" +version = "1.0.4" +description = "Pytest plugin for aiohttp support" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-aiohttp-1.0.4.tar.gz", hash = "sha256:39ff3a0d15484c01d1436cbedad575c6eafbf0f57cdf76fb94994c97b5b8c5a4"}, + {file = "pytest_aiohttp-1.0.4-py3-none-any.whl", hash = "sha256:1d2dc3a304c2be1fd496c0c2fb6b31ab60cd9fc33984f761f951f8ea1eb4ca95"}, +] + +[package.dependencies] +aiohttp = ">=3.8.1" +pytest = ">=6.1.0" +pytest-asyncio = ">=0.17.2" + +[package.extras] +testing = ["coverage (==6.2)", "mypy (==0.931)"] + +[[package]] +name = "pytest-asyncio" +version = "0.21.1" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-asyncio-0.21.1.tar.gz", hash = "sha256:40a7eae6dded22c7b604986855ea48400ab15b069ae38116e8c01238e9eeb64d"}, + {file = "pytest_asyncio-0.21.1-py3-none-any.whl", hash = "sha256:8666c1c8ac02631d7c51ba282e0c69a8a452b211ffedf2599099845da5c5c37b"}, +] + +[package.dependencies] +pytest = ">=7.0.0" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"] + [[package]] name = "python-dotenv" version = "1.0.0" @@ -1983,4 +2066,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "32981f838c232fdf2274aadbc933ef107c820d053bc9c2ceec563b2a22c1ea4c" +content-hash = "d2b64390d1ea9038b6703b12060cdde1970b680a0ad891f24405323ff2ca0a60" diff --git a/server/pyproject.toml b/server/pyproject.toml index aebb26f2..c6704a27 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -29,6 +29,12 @@ httpx = "^0.24.1" pyaudio = "^0.2.13" stamina = "^23.1.0" + +[tool.poetry.group.tests.dependencies] +pytest-aiohttp = "^1.0.4" +pytest-asyncio = "^0.21.1" +pytest = "^7.4.0" + [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" diff --git a/server/reflector/server.py b/server/reflector/server.py index dff8b15b..8e28b583 100644 --- a/server/reflector/server.py +++ b/server/reflector/server.py @@ -348,15 +348,10 @@ async def on_shutdown(application: web.Application) -> NoReturn: pcs.clear() -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="WebRTC based server for Reflector") - parser.add_argument( - "--host", default="0.0.0.0", help="Server host IP (def: 0.0.0.0)" - ) - parser.add_argument( - "--port", type=int, default=1250, help="Server port (def: 1250)" - ) - args = parser.parse_args() +def create_app() -> web.Application: + """ + Create the web application + """ app = web.Application() cors = aiohttp_cors.setup( app, @@ -370,4 +365,17 @@ if __name__ == "__main__": offer_resource = cors.add(app.router.add_resource("/offer")) cors.add(offer_resource.add_route("POST", offer)) app.on_shutdown.append(on_shutdown) + return app + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="WebRTC based server for Reflector") + parser.add_argument( + "--host", default="0.0.0.0", help="Server host IP (def: 0.0.0.0)" + ) + parser.add_argument( + "--port", type=int, default=1250, help="Server port (def: 1250)" + ) + args = parser.parse_args() + app = create_app() web.run_app(app, access_log=None, host=args.host, port=args.port) diff --git a/server/reflector/stream_client.py b/server/reflector/stream_client.py index e6d5f497..39b0e0db 100644 --- a/server/reflector/stream_client.py +++ b/server/reflector/stream_client.py @@ -4,7 +4,6 @@ import uuid import httpx import pyaudio -import requests import stamina from aiortc import RTCPeerConnection, RTCSessionDescription from aiortc.contrib.media import MediaPlayer, MediaRelay @@ -15,7 +14,7 @@ from reflector.settings import settings class StreamClient: def __init__( - self, signaling, url="http://0.0.0.0:1250", play_from=None, ping_pong=False + self, signaling, url="http://0.0.0.0:1250/offer", play_from=None, ping_pong=False ): self.signaling = signaling self.server_url = url @@ -25,21 +24,15 @@ class StreamClient: self.pc = RTCPeerConnection() - self.loop = asyncio.get_event_loop() self.relay = None self.pcs = set() self.time_start = None self.queue = asyncio.Queue() - self.player = MediaPlayer( - f":{settings.AUDIO_AV_FOUNDATION_DEVICE_ID}", - format="avfoundation", - options={"channels": "2"}, - ) + self.logger = logger.bind(stream_client=id(self)) - def stop(self): - self.loop.run_until_complete(self.signaling.close()) - self.loop.run_until_complete(self.pc.close()) - # self.loop.close() + async def stop(self): + await self.signaling.close() + await self.pc.close() def create_local_tracks(self, play_from): if play_from: @@ -48,11 +41,13 @@ class StreamClient: else: if self.relay is None: self.relay = MediaRelay() + self.player = MediaPlayer( + f":{settings.AUDIO_AV_FOUNDATION_DEVICE_ID}", + format="avfoundation", + options={"channels": "2"}, + ) return self.relay.subscribe(self.player.audio), None - def channel_log(self, channel, t, message): - print("channel(%s) %s %s" % (channel.label, t, message)) - def channel_send(self, channel, message): # self.channel_log(channel, ">", message) channel.send(message) @@ -67,32 +62,31 @@ class StreamClient: async def run_offer(self, pc, signaling): # microphone audio, video = self.create_local_tracks(self.play_from) - pc_id = "PeerConnection(%s)" % uuid.uuid4() + pc_id = uuid.uuid4().hex self.pcs.add(pc) - - def log_info(msg, *args): - logger.info(pc_id + " " + msg, *args) + self.logger = self.logger.bind(pc_id=pc_id) @pc.on("connectionstatechange") async def on_connectionstatechange(): - print("Connection state is %s" % pc.connectionState) + self.logger.info(f"Connection state is {pc.connectionState}") if pc.connectionState == "failed": await pc.close() self.pcs.discard(pc) @pc.on("track") def on_track(track): - print("Sending %s" % track.kind) + self.logger.info(f"Sending {track.kind}") self.pc.addTrack(track) @track.on("ended") async def on_ended(): - log_info("Track %s ended", track.kind) + self.logger.info(f"Track {track.kind} ended") self.pc.addTrack(audio) channel = pc.createDataChannel("data-channel") - self.channel_log(channel, "-", "created by local party") + self.logger = self.logger.bind(channel=channel.label) + self.logger.info("Created by local party") async def send_pings(): while True: @@ -108,23 +102,24 @@ class StreamClient: def on_message(message): self.queue.put_nowait(message) if self.ping_pong: - self.channel_log(channel, "<", message) + self.logger.info(f"Message: {message}") if isinstance(message, str) and message.startswith("pong"): elapsed_ms = (self.current_stamp() - int(message[5:])) / 1000 - print(" RTT %.2f ms" % elapsed_ms) + self.logger.debug("RTT %.2f ms" % elapsed_ms) await pc.setLocalDescription(await pc.createOffer()) sdp = {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type} @stamina.retry(on=httpx.HTTPError, attempts=5) - def connect_to_server(): - response = requests.post(self.server_url, json=sdp, timeout=10) - response.raise_for_status() - return response + async def connect_to_server(): + async with httpx.AsyncClient() as client: + response = await client.post(self.server_url, json=sdp, timeout=10) + response.raise_for_status() + return response.json() - params = connect_to_server().json() + params = await connect_to_server() answer = RTCSessionDescription(sdp=params["sdp"], type=params["type"]) await pc.setRemoteDescription(answer) diff --git a/server/tests/test_basic_rtc.py b/server/tests/test_basic_rtc.py new file mode 100644 index 00000000..98d6fca7 --- /dev/null +++ b/server/tests/test_basic_rtc.py @@ -0,0 +1,66 @@ +import pytest +from unittest.mock import patch + + +@pytest.mark.asyncio +async def test_basic_rtc_server(aiohttp_server, event_loop): + # goal is to start the server, and send rtc audio to it + # validate the events received + import argparse + import json + from pathlib import Path + from reflector.server import create_app + from reflector.stream_client import StreamClient + from reflector.models import TitleSummaryOutput + from aiortc.contrib.signaling import add_signaling_arguments, create_signaling + + # customize settings to have a mock LLM server + with patch("reflector.server.get_title_and_summary") as mock_llm: + # any response from mock_llm will be test topic + mock_llm.return_value = TitleSummaryOutput(["topic_test"]) + + # create the server + app = create_app() + server = await aiohttp_server(app) + url = f"http://{server.host}:{server.port}/offer" + + # create signaling + parser = argparse.ArgumentParser() + add_signaling_arguments(parser) + args = parser.parse_args(["-s", "tcp-socket"]) + signaling = create_signaling(args) + + # create the client + path = Path(__file__).parent / "records" / "test_mathieu_hello.wav" + client = StreamClient(signaling, url=url, play_from=path.as_posix()) + await client.start() + + # we just want the first transcription + # and topic update messages + + marks = { + "SHOW_TRANSCRIPTION": False, + "UPDATE_TOPICS": False, + } + + async for rawmsg in client.get_reader(): + msg = json.loads(rawmsg) + cmd = msg["cmd"] + if cmd == "SHOW_TRANSCRIPTION": + assert "text" in msg + assert "want to share my incredible experience" in msg["text"] + elif cmd == "UPDATE_TOPICS": + assert "topics" in msg + assert "topic_test" in msg["topics"] + marks[cmd] = True + + # break if we have all the events we need + if all(marks.values()): + break + + # stop the server + await server.close() + await client.stop() + + +