diff --git a/server/poetry.lock b/server/poetry.lock index b49d8df5..cf60f351 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -682,6 +682,78 @@ humanfriendly = ">=9.1" [package.extras] cron = ["capturer (>=2.4)"] +[[package]] +name = "coverage" +version = "7.2.7" +description = "Code coverage measurement for Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "coverage-7.2.7-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d39b5b4f2a66ccae8b7263ac3c8170994b65266797fb96cbbfd3fb5b23921db8"}, + {file = "coverage-7.2.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6d040ef7c9859bb11dfeb056ff5b3872436e3b5e401817d87a31e1750b9ae2fb"}, + {file = "coverage-7.2.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ba90a9563ba44a72fda2e85302c3abc71c5589cea608ca16c22b9804262aaeb6"}, + {file = "coverage-7.2.7-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e7d9405291c6928619403db1d10bd07888888ec1abcbd9748fdaa971d7d661b2"}, + {file = "coverage-7.2.7-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:31563e97dae5598556600466ad9beea39fb04e0229e61c12eaa206e0aa202063"}, + {file = "coverage-7.2.7-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:ebba1cd308ef115925421d3e6a586e655ca5a77b5bf41e02eb0e4562a111f2d1"}, + {file = "coverage-7.2.7-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:cb017fd1b2603ef59e374ba2063f593abe0fc45f2ad9abdde5b4d83bd922a353"}, + {file = "coverage-7.2.7-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d62a5c7dad11015c66fbb9d881bc4caa5b12f16292f857842d9d1871595f4495"}, + {file = "coverage-7.2.7-cp310-cp310-win32.whl", hash = "sha256:ee57190f24fba796e36bb6d3aa8a8783c643d8fa9760c89f7a98ab5455fbf818"}, + {file = "coverage-7.2.7-cp310-cp310-win_amd64.whl", hash = "sha256:f75f7168ab25dd93110c8a8117a22450c19976afbc44234cbf71481094c1b850"}, + {file = "coverage-7.2.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:06a9a2be0b5b576c3f18f1a241f0473575c4a26021b52b2a85263a00f034d51f"}, + {file = "coverage-7.2.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5baa06420f837184130752b7c5ea0808762083bf3487b5038d68b012e5937dbe"}, + {file = "coverage-7.2.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdec9e8cbf13a5bf63290fc6013d216a4c7232efb51548594ca3631a7f13c3a3"}, + {file = "coverage-7.2.7-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:52edc1a60c0d34afa421c9c37078817b2e67a392cab17d97283b64c5833f427f"}, + {file = "coverage-7.2.7-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63426706118b7f5cf6bb6c895dc215d8a418d5952544042c8a2d9fe87fcf09cb"}, + {file = "coverage-7.2.7-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:afb17f84d56068a7c29f5fa37bfd38d5aba69e3304af08ee94da8ed5b0865833"}, + {file = "coverage-7.2.7-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:48c19d2159d433ccc99e729ceae7d5293fbffa0bdb94952d3579983d1c8c9d97"}, + {file = "coverage-7.2.7-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0e1f928eaf5469c11e886fe0885ad2bf1ec606434e79842a879277895a50942a"}, + {file = "coverage-7.2.7-cp311-cp311-win32.whl", hash = "sha256:33d6d3ea29d5b3a1a632b3c4e4f4ecae24ef170b0b9ee493883f2df10039959a"}, + {file = "coverage-7.2.7-cp311-cp311-win_amd64.whl", hash = "sha256:5b7540161790b2f28143191f5f8ec02fb132660ff175b7747b95dcb77ac26562"}, + {file = "coverage-7.2.7-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f2f67fe12b22cd130d34d0ef79206061bfb5eda52feb6ce0dba0644e20a03cf4"}, + {file = "coverage-7.2.7-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a342242fe22407f3c17f4b499276a02b01e80f861f1682ad1d95b04018e0c0d4"}, + {file = "coverage-7.2.7-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:171717c7cb6b453aebac9a2ef603699da237f341b38eebfee9be75d27dc38e01"}, + {file = "coverage-7.2.7-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49969a9f7ffa086d973d91cec8d2e31080436ef0fb4a359cae927e742abfaaa6"}, + {file = "coverage-7.2.7-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b46517c02ccd08092f4fa99f24c3b83d8f92f739b4657b0f146246a0ca6a831d"}, + {file = "coverage-7.2.7-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:a3d33a6b3eae87ceaefa91ffdc130b5e8536182cd6dfdbfc1aa56b46ff8c86de"}, + {file = "coverage-7.2.7-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:976b9c42fb2a43ebf304fa7d4a310e5f16cc99992f33eced91ef6f908bd8f33d"}, + {file = "coverage-7.2.7-cp312-cp312-win32.whl", hash = "sha256:8de8bb0e5ad103888d65abef8bca41ab93721647590a3f740100cd65c3b00511"}, + {file = "coverage-7.2.7-cp312-cp312-win_amd64.whl", hash = "sha256:9e31cb64d7de6b6f09702bb27c02d1904b3aebfca610c12772452c4e6c21a0d3"}, + {file = "coverage-7.2.7-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:58c2ccc2f00ecb51253cbe5d8d7122a34590fac9646a960d1430d5b15321d95f"}, + {file = "coverage-7.2.7-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d22656368f0e6189e24722214ed8d66b8022db19d182927b9a248a2a8a2f67eb"}, + {file = "coverage-7.2.7-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a895fcc7b15c3fc72beb43cdcbdf0ddb7d2ebc959edac9cef390b0d14f39f8a9"}, + {file = "coverage-7.2.7-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e84606b74eb7de6ff581a7915e2dab7a28a0517fbe1c9239eb227e1354064dcd"}, + {file = "coverage-7.2.7-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:0a5f9e1dbd7fbe30196578ca36f3fba75376fb99888c395c5880b355e2875f8a"}, + {file = "coverage-7.2.7-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:419bfd2caae268623dd469eff96d510a920c90928b60f2073d79f8fe2bbc5959"}, + {file = "coverage-7.2.7-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:2aee274c46590717f38ae5e4650988d1af340fe06167546cc32fe2f58ed05b02"}, + {file = "coverage-7.2.7-cp37-cp37m-win32.whl", hash = "sha256:61b9a528fb348373c433e8966535074b802c7a5d7f23c4f421e6c6e2f1697a6f"}, + {file = "coverage-7.2.7-cp37-cp37m-win_amd64.whl", hash = "sha256:b1c546aca0ca4d028901d825015dc8e4d56aac4b541877690eb76490f1dc8ed0"}, + {file = "coverage-7.2.7-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:54b896376ab563bd38453cecb813c295cf347cf5906e8b41d340b0321a5433e5"}, + {file = "coverage-7.2.7-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3d376df58cc111dc8e21e3b6e24606b5bb5dee6024f46a5abca99124b2229ef5"}, + {file = "coverage-7.2.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e330fc79bd7207e46c7d7fd2bb4af2963f5f635703925543a70b99574b0fea9"}, + {file = "coverage-7.2.7-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1e9d683426464e4a252bf70c3498756055016f99ddaec3774bf368e76bbe02b6"}, + {file = "coverage-7.2.7-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d13c64ee2d33eccf7437961b6ea7ad8673e2be040b4f7fd4fd4d4d28d9ccb1e"}, + {file = "coverage-7.2.7-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b7aa5f8a41217360e600da646004f878250a0d6738bcdc11a0a39928d7dc2050"}, + {file = "coverage-7.2.7-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:8fa03bce9bfbeeef9f3b160a8bed39a221d82308b4152b27d82d8daa7041fee5"}, + {file = "coverage-7.2.7-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:245167dd26180ab4c91d5e1496a30be4cd721a5cf2abf52974f965f10f11419f"}, + {file = "coverage-7.2.7-cp38-cp38-win32.whl", hash = "sha256:d2c2db7fd82e9b72937969bceac4d6ca89660db0a0967614ce2481e81a0b771e"}, + {file = "coverage-7.2.7-cp38-cp38-win_amd64.whl", hash = "sha256:2e07b54284e381531c87f785f613b833569c14ecacdcb85d56b25c4622c16c3c"}, + {file = "coverage-7.2.7-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:537891ae8ce59ef63d0123f7ac9e2ae0fc8b72c7ccbe5296fec45fd68967b6c9"}, + {file = "coverage-7.2.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:06fb182e69f33f6cd1d39a6c597294cff3143554b64b9825d1dc69d18cc2fff2"}, + {file = "coverage-7.2.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:201e7389591af40950a6480bd9edfa8ed04346ff80002cec1a66cac4549c1ad7"}, + {file = "coverage-7.2.7-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f6951407391b639504e3b3be51b7ba5f3528adbf1a8ac3302b687ecababf929e"}, + {file = "coverage-7.2.7-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f48351d66575f535669306aa7d6d6f71bc43372473b54a832222803eb956fd1"}, + {file = "coverage-7.2.7-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b29019c76039dc3c0fd815c41392a044ce555d9bcdd38b0fb60fb4cd8e475ba9"}, + {file = "coverage-7.2.7-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:81c13a1fc7468c40f13420732805a4c38a105d89848b7c10af65a90beff25250"}, + {file = "coverage-7.2.7-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:975d70ab7e3c80a3fe86001d8751f6778905ec723f5b110aed1e450da9d4b7f2"}, + {file = "coverage-7.2.7-cp39-cp39-win32.whl", hash = "sha256:7ee7d9d4822c8acc74a5e26c50604dff824710bc8de424904c0982e25c39c6cb"}, + {file = "coverage-7.2.7-cp39-cp39-win_amd64.whl", hash = "sha256:eb393e5ebc85245347950143969b241d08b52b88a3dc39479822e073a1a8eb27"}, + {file = "coverage-7.2.7-pp37.pp38.pp39-none-any.whl", hash = "sha256:b7b4c971f05e6ae490fef852c218b0e79d4e52f79ef0c8475566584a8fb3e01d"}, + {file = "coverage-7.2.7.tar.gz", hash = "sha256:924d94291ca674905fe9481f12294eb11f2d3d3fd1adb20314ba89e94f44ed59"}, +] + +[package.extras] +toml = ["tomli"] + [[package]] name = "cryptography" version = "41.0.2" @@ -1169,6 +1241,23 @@ cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] +[[package]] +name = "httpx-ws" +version = "0.4.1" +description = "WebSockets support for HTTPX" +optional = false +python-versions = ">=3.8" +files = [ + {file = "httpx_ws-0.4.1-py3-none-any.whl", hash = "sha256:01bdaeb66add8196485dc39912abd0a3e95b67c244aededc151156ac6adca850"}, + {file = "httpx_ws-0.4.1.tar.gz", hash = "sha256:5f3e291e8fb99c89f994329d883e5679d02a0b5b12a1e414f7f8630c276b6744"}, +] + +[package.dependencies] +anyio = "*" +httpcore = ">=0.17.3,<0.18" +httpx = ">=0.23.1" +wsproto = "*" + [[package]] name = "huggingface-hub" version = "0.16.4" @@ -1883,6 +1972,24 @@ pytest = ">=7.0.0" docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"] +[[package]] +name = "pytest-cov" +version = "4.1.0" +description = "Pytest plugin for measuring coverage." +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-cov-4.1.0.tar.gz", hash = "sha256:3904b13dfbfec47f003b8e77fd5b589cd11904a21ddf1ab38a64f204d6a10ef6"}, + {file = "pytest_cov-4.1.0-py3-none-any.whl", hash = "sha256:6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a"}, +] + +[package.dependencies] +coverage = {version = ">=5.2.1", extras = ["toml"]} +pytest = ">=4.6" + +[package.extras] +testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"] + [[package]] name = "python-dateutil" version = "2.8.2" @@ -2537,6 +2644,20 @@ files = [ {file = "wrapt-1.15.0.tar.gz", hash = "sha256:d06730c6aed78cee4126234cf2d071e01b44b915e725a6cb439a879ec9754a3a"}, ] +[[package]] +name = "wsproto" +version = "1.2.0" +description = "WebSockets state-machine based protocol implementation" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "wsproto-1.2.0-py3-none-any.whl", hash = "sha256:b9acddd652b585d75b20477888c56642fdade28bdfd3579aa24a4d2c037dd736"}, + {file = "wsproto-1.2.0.tar.gz", hash = "sha256:ad565f26ecb92588a3e43bc3d96164de84cd9902482b130d0ddbaa9664a85065"}, +] + +[package.dependencies] +h11 = ">=0.9.0,<1" + [[package]] name = "yarl" version = "1.9.2" @@ -2627,4 +2748,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "a51d7d26b88683875685ede2298f0f02ab42b1f303657b47e0a5dee9be0dc9e6" +content-hash = "c984979825947f67fc42e4553d5ff347f2f9194e4acccdeacb67a406a332009a" diff --git a/server/pyproject.toml b/server/pyproject.toml index b4eb307a..da53c298 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -35,9 +35,11 @@ stamina = "^23.1.0" [tool.poetry.group.tests.dependencies] +pytest-cov = "^4.1.0" pytest-aiohttp = "^1.0.4" pytest-asyncio = "^0.21.1" pytest = "^7.4.0" +httpx-ws = "^0.4.1" [tool.poetry.group.aws.dependencies] @@ -46,3 +48,13 @@ aioboto3 = "^11.2.0" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" + +[tool.coverage.run] +source = ["reflector"] + +[tool.pytest.ini_options] +addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v" +testpaths = ["tests"] +asyncio_mode = "auto" + + diff --git a/server/reflector/processors/types.py b/server/reflector/processors/types.py index 353c8aa4..bdf98b7a 100644 --- a/server/reflector/processors/types.py +++ b/server/reflector/processors/types.py @@ -1,9 +1,8 @@ -from dataclasses import dataclass +from pydantic import BaseModel from pathlib import Path -@dataclass -class AudioFile: +class AudioFile(BaseModel): path: Path sample_rate: int channels: int @@ -14,15 +13,13 @@ class AudioFile: self.path.unlink() -@dataclass -class Word: +class Word(BaseModel): text: str start: float end: float -@dataclass -class Transcript: +class Transcript(BaseModel): text: str = "" words: list[Word] = None @@ -59,8 +56,7 @@ class Transcript: return Transcript(text=self.text, words=words) -@dataclass -class TitleSummary: +class TitleSummary(BaseModel): title: str summary: str timestamp: float @@ -75,7 +71,6 @@ class TitleSummary: return f"{minutes:02d}:{seconds:02d}.{milliseconds:03d}" -@dataclass -class FinalSummary: +class FinalSummary(BaseModel): summary: str duration: float diff --git a/server/reflector/stream_client.py b/server/reflector/stream_client.py index 912bc514..6b66ad45 100644 --- a/server/reflector/stream_client.py +++ b/server/reflector/stream_client.py @@ -3,7 +3,6 @@ import time import uuid import httpx -import pyaudio import stamina from aiortc import RTCPeerConnection, RTCSessionDescription from aiortc.contrib.media import MediaPlayer, MediaRelay @@ -24,7 +23,6 @@ class StreamClient: self.server_url = url self.play_from = play_from self.ping_pong = ping_pong - self.paudio = pyaudio.PyAudio() self.pc = RTCPeerConnection() @@ -87,6 +85,7 @@ class StreamClient: self.logger.info(f"Track {track.kind} ended") self.pc.addTrack(audio) + self.track_audio = audio channel = pc.createDataChannel("data-channel") self.logger = self.logger.bind(channel=channel.label) @@ -142,3 +141,6 @@ class StreamClient: coro = self.run_offer(self.pc, self.signaling) task = asyncio.create_task(coro) await task + + def is_ended(self): + return self.track_audio is None or self.track_audio.readyState == "ended" diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 332a960b..120c3ff1 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -3,7 +3,8 @@ from pydantic import BaseModel, Field from uuid import UUID, uuid4 from datetime import datetime from fastapi_pagination import Page, paginate -from .rtc_offer import rtc_offer, RtcOffer, PipelineEvent +from reflector.logger import logger +from .rtc_offer import rtc_offer_base, RtcOffer, PipelineEvent import asyncio from typing import Optional @@ -46,6 +47,7 @@ class Transcript(BaseModel): def add_event(self, event: str, data): self.events.append(TranscriptEvent(event=event, data=data)) + return {"event": event, "data": data} def upsert_topic(self, topic: TranscriptTopic): existing_topic = next((t for t in self.topics if t.id == topic.id), None) @@ -239,14 +241,37 @@ async def handle_rtc_event(event: PipelineEvent, args, data): # OFC the current implementation is not good, # but it's just a POC before persistence. It won't query the # transcript from the database for each event. - print(f"Event: {event}", args, data) + # print(f"Event: {event}", args, data) transcript_id = args transcript = transcripts_controller.get_by_id(transcript_id) if not transcript: return - transcript.add_event(event=event, data=data) - if event == PipelineEvent.TOPIC: - transcript.upsert_topic(TranscriptTopic(**data)) + + # event send to websocket clients may not be the same as the event + # received from the pipeline. For example, the pipeline will send + # a TRANSCRIPT event with all words, but this is not what we want + # to send to the websocket client. + + # FIXME don't do copy + if event == PipelineEvent.TRANSCRIPT: + resp = transcript.add_event(event=event, data={ + "text": data.text, + }) + elif event == PipelineEvent.TOPIC: + topic = TranscriptTopic( + title=data.title, + summary=data.summary, + transcript=data.transcript, + timestamp=data.timestamp, + ) + resp = transcript.add_event(event=event, data=topic.model_dump()) + transcript.upsert_topic(topic) + else: + logger.warning(f"Unknown event: {event}") + return + + # transmit to websocket clients + await ws_manager.send_json(transcript_id, resp) @router.post("/transcripts/{transcript_id}/record/webrtc") @@ -261,9 +286,9 @@ async def transcript_record_webrtc( raise HTTPException(status_code=400, detail="Transcript is locked") # FIXME do not allow multiple recording at the same time - return await rtc_offer( + return await rtc_offer_base( params, request, - event_callback=transcript.handle_event, + event_callback=handle_rtc_event, event_callback_args=transcript_id, ) diff --git a/server/tests/records/test_short.wav b/server/tests/records/test_short.wav new file mode 100644 index 00000000..ca3026c9 Binary files /dev/null and b/server/tests/records/test_short.wav differ diff --git a/server/tests/test_transcripts.py b/server/tests/test_transcripts.py index 58ab8393..77cb4b23 100644 --- a/server/tests/test_transcripts.py +++ b/server/tests/test_transcripts.py @@ -61,9 +61,9 @@ async def test_transcripts_list(): assert "testxx1" in names assert "testxx2" in names + @pytest.mark.asyncio async def test_transcript_delete(): - async with AsyncClient(app=app, base_url="http://test/v1") as ac: response = await ac.post("/transcripts", json={"name": "testdel1"}) assert response.status_code == 200 @@ -76,4 +76,3 @@ async def test_transcript_delete(): response = await ac.get(f"/transcripts/{tid}") assert response.status_code == 404 - diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py new file mode 100644 index 00000000..e6280c94 --- /dev/null +++ b/server/tests/test_transcripts_rtc_ws.py @@ -0,0 +1,150 @@ +# === further tests +# FIXME test status of transcript +# FIXME test websocket connection after RTC is finished still send the full events +# FIXME try with locked session, RTC should not work + +import pytest +from unittest.mock import patch +from httpx import AsyncClient + +from reflector.app import app +from uvicorn import Config, Server +import threading +import asyncio +from pathlib import Path +from httpx_ws import aconnect_ws + + +class ThreadedUvicorn: + def __init__(self, config: Config): + self.server = Server(config) + self.thread = threading.Thread(daemon=True, target=self.server.run) + + async def start(self): + self.thread.start() + while not self.server.started: + await asyncio.sleep(0.1) + + def stop(self): + if self.thread.is_alive(): + self.server.should_exit = True + while self.thread.is_alive(): + continue + + +@pytest.fixture +async def dummy_transcript(): + from reflector.processors.audio_transcript import AudioTranscriptProcessor + from reflector.processors.types import AudioFile, Transcript, Word + + class TestAudioTranscriptProcessor(AudioTranscriptProcessor): + async def _transcript(self, data: AudioFile): + return Transcript( + text="Hello world", + words=[ + Word(start=0.0, end=1.0, text="Hello"), + Word(start=1.0, end=2.0, text="world"), + ], + ) + + with patch( + "reflector.processors.audio_transcript_auto" + ".AudioTranscriptAutoProcessor.get_instance" + ) as mock_audio: + mock_audio.return_value = TestAudioTranscriptProcessor() + yield + + +@pytest.fixture +async def dummy_llm(): + from reflector.llm.base import LLM + + class TestLLM(LLM): + async def _generate(self, prompt: str, **kwargs): + return {"text": "LLM RESULT"} + + with patch("reflector.llm.base.LLM.get_instance") as mock_llm: + mock_llm.return_value = TestLLM() + yield + + +@pytest.mark.asyncio +async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm): + # goal: start the server, exchange RTC, receive websocket events + # because of that, we need to start the server in a thread + # to be able to connect with aiortc + + # start server + host = "127.0.0.1" + port = 1255 + base_url = f"http://{host}:{port}/v1" + config = Config(app=app, host=host, port=port) + server = ThreadedUvicorn(config) + await server.start() + + # create a transcript + ac = AsyncClient(base_url=base_url) + response = await ac.post("/transcripts", json={"name": "Test RTC"}) + assert response.status_code == 200 + tid = response.json()["id"] + + # create a websocket connection as a task + events = [] + + async def websocket_task(): + print("Test websocket: TASK STARTED") + async with aconnect_ws(f"{base_url}/transcripts/{tid}/events") as ws: + print("Test websocket: CONNECTED") + try: + while True: + msg = await ws.receive_json() + print(f"Test websocket: JSON {msg}") + if msg is None: + break + events.append(msg) + except Exception as e: + print(f"Test websocket: EXCEPTION {e}") + finally: + ws.close() + print("Test websocket: DISCONNECTED") + + websocket_task = asyncio.get_event_loop().create_task(websocket_task()) + + # create stream client + import argparse + from reflector.stream_client import StreamClient + from aiortc.contrib.signaling import add_signaling_arguments, create_signaling + + parser = argparse.ArgumentParser() + add_signaling_arguments(parser) + args = parser.parse_args(["-s", "tcp-socket"]) + signaling = create_signaling(args) + + url = f"{base_url}/transcripts/{tid}/record/webrtc" + path = Path(__file__).parent / "records" / "test_short.wav" + client = StreamClient(signaling, url=url, play_from=path.as_posix()) + await client.start() + + timeout = 20 + while not client.is_ended(): + await asyncio.sleep(1) + timeout -= 1 + if timeout < 0: + raise TimeoutError("Timeout while waiting for RTC to end") + + await client.stop() + + # wait the processing to finish + await asyncio.sleep(2) + + # stop websocket task + websocket_task.cancel() + + # check events + print(events) + assert len(events) > 0 + assert events[0]["event"] == "TRANSCRIPT" + assert events[0]["data"]["text"] == "Hello world" + + # stop server + # server.stop()