api: implement first server API + tests

This commit is contained in:
Mathieu Virbel
2023-08-04 18:16:37 +02:00
parent 20767fde3f
commit 96f52c631a
8 changed files with 327 additions and 23 deletions

123
server/poetry.lock generated
View File

@@ -682,6 +682,78 @@ humanfriendly = ">=9.1"
[package.extras]
cron = ["capturer (>=2.4)"]
[[package]]
name = "coverage"
version = "7.2.7"
description = "Code coverage measurement for Python"
optional = false
python-versions = ">=3.7"
files = [
{file = "coverage-7.2.7-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d39b5b4f2a66ccae8b7263ac3c8170994b65266797fb96cbbfd3fb5b23921db8"},
{file = "coverage-7.2.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6d040ef7c9859bb11dfeb056ff5b3872436e3b5e401817d87a31e1750b9ae2fb"},
{file = "coverage-7.2.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ba90a9563ba44a72fda2e85302c3abc71c5589cea608ca16c22b9804262aaeb6"},
{file = "coverage-7.2.7-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e7d9405291c6928619403db1d10bd07888888ec1abcbd9748fdaa971d7d661b2"},
{file = "coverage-7.2.7-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:31563e97dae5598556600466ad9beea39fb04e0229e61c12eaa206e0aa202063"},
{file = "coverage-7.2.7-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:ebba1cd308ef115925421d3e6a586e655ca5a77b5bf41e02eb0e4562a111f2d1"},
{file = "coverage-7.2.7-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:cb017fd1b2603ef59e374ba2063f593abe0fc45f2ad9abdde5b4d83bd922a353"},
{file = "coverage-7.2.7-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d62a5c7dad11015c66fbb9d881bc4caa5b12f16292f857842d9d1871595f4495"},
{file = "coverage-7.2.7-cp310-cp310-win32.whl", hash = "sha256:ee57190f24fba796e36bb6d3aa8a8783c643d8fa9760c89f7a98ab5455fbf818"},
{file = "coverage-7.2.7-cp310-cp310-win_amd64.whl", hash = "sha256:f75f7168ab25dd93110c8a8117a22450c19976afbc44234cbf71481094c1b850"},
{file = "coverage-7.2.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:06a9a2be0b5b576c3f18f1a241f0473575c4a26021b52b2a85263a00f034d51f"},
{file = "coverage-7.2.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5baa06420f837184130752b7c5ea0808762083bf3487b5038d68b012e5937dbe"},
{file = "coverage-7.2.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdec9e8cbf13a5bf63290fc6013d216a4c7232efb51548594ca3631a7f13c3a3"},
{file = "coverage-7.2.7-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:52edc1a60c0d34afa421c9c37078817b2e67a392cab17d97283b64c5833f427f"},
{file = "coverage-7.2.7-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63426706118b7f5cf6bb6c895dc215d8a418d5952544042c8a2d9fe87fcf09cb"},
{file = "coverage-7.2.7-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:afb17f84d56068a7c29f5fa37bfd38d5aba69e3304af08ee94da8ed5b0865833"},
{file = "coverage-7.2.7-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:48c19d2159d433ccc99e729ceae7d5293fbffa0bdb94952d3579983d1c8c9d97"},
{file = "coverage-7.2.7-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0e1f928eaf5469c11e886fe0885ad2bf1ec606434e79842a879277895a50942a"},
{file = "coverage-7.2.7-cp311-cp311-win32.whl", hash = "sha256:33d6d3ea29d5b3a1a632b3c4e4f4ecae24ef170b0b9ee493883f2df10039959a"},
{file = "coverage-7.2.7-cp311-cp311-win_amd64.whl", hash = "sha256:5b7540161790b2f28143191f5f8ec02fb132660ff175b7747b95dcb77ac26562"},
{file = "coverage-7.2.7-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f2f67fe12b22cd130d34d0ef79206061bfb5eda52feb6ce0dba0644e20a03cf4"},
{file = "coverage-7.2.7-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a342242fe22407f3c17f4b499276a02b01e80f861f1682ad1d95b04018e0c0d4"},
{file = "coverage-7.2.7-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:171717c7cb6b453aebac9a2ef603699da237f341b38eebfee9be75d27dc38e01"},
{file = "coverage-7.2.7-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49969a9f7ffa086d973d91cec8d2e31080436ef0fb4a359cae927e742abfaaa6"},
{file = "coverage-7.2.7-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b46517c02ccd08092f4fa99f24c3b83d8f92f739b4657b0f146246a0ca6a831d"},
{file = "coverage-7.2.7-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:a3d33a6b3eae87ceaefa91ffdc130b5e8536182cd6dfdbfc1aa56b46ff8c86de"},
{file = "coverage-7.2.7-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:976b9c42fb2a43ebf304fa7d4a310e5f16cc99992f33eced91ef6f908bd8f33d"},
{file = "coverage-7.2.7-cp312-cp312-win32.whl", hash = "sha256:8de8bb0e5ad103888d65abef8bca41ab93721647590a3f740100cd65c3b00511"},
{file = "coverage-7.2.7-cp312-cp312-win_amd64.whl", hash = "sha256:9e31cb64d7de6b6f09702bb27c02d1904b3aebfca610c12772452c4e6c21a0d3"},
{file = "coverage-7.2.7-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:58c2ccc2f00ecb51253cbe5d8d7122a34590fac9646a960d1430d5b15321d95f"},
{file = "coverage-7.2.7-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d22656368f0e6189e24722214ed8d66b8022db19d182927b9a248a2a8a2f67eb"},
{file = "coverage-7.2.7-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a895fcc7b15c3fc72beb43cdcbdf0ddb7d2ebc959edac9cef390b0d14f39f8a9"},
{file = "coverage-7.2.7-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e84606b74eb7de6ff581a7915e2dab7a28a0517fbe1c9239eb227e1354064dcd"},
{file = "coverage-7.2.7-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:0a5f9e1dbd7fbe30196578ca36f3fba75376fb99888c395c5880b355e2875f8a"},
{file = "coverage-7.2.7-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:419bfd2caae268623dd469eff96d510a920c90928b60f2073d79f8fe2bbc5959"},
{file = "coverage-7.2.7-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:2aee274c46590717f38ae5e4650988d1af340fe06167546cc32fe2f58ed05b02"},
{file = "coverage-7.2.7-cp37-cp37m-win32.whl", hash = "sha256:61b9a528fb348373c433e8966535074b802c7a5d7f23c4f421e6c6e2f1697a6f"},
{file = "coverage-7.2.7-cp37-cp37m-win_amd64.whl", hash = "sha256:b1c546aca0ca4d028901d825015dc8e4d56aac4b541877690eb76490f1dc8ed0"},
{file = "coverage-7.2.7-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:54b896376ab563bd38453cecb813c295cf347cf5906e8b41d340b0321a5433e5"},
{file = "coverage-7.2.7-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3d376df58cc111dc8e21e3b6e24606b5bb5dee6024f46a5abca99124b2229ef5"},
{file = "coverage-7.2.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e330fc79bd7207e46c7d7fd2bb4af2963f5f635703925543a70b99574b0fea9"},
{file = "coverage-7.2.7-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1e9d683426464e4a252bf70c3498756055016f99ddaec3774bf368e76bbe02b6"},
{file = "coverage-7.2.7-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d13c64ee2d33eccf7437961b6ea7ad8673e2be040b4f7fd4fd4d4d28d9ccb1e"},
{file = "coverage-7.2.7-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b7aa5f8a41217360e600da646004f878250a0d6738bcdc11a0a39928d7dc2050"},
{file = "coverage-7.2.7-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:8fa03bce9bfbeeef9f3b160a8bed39a221d82308b4152b27d82d8daa7041fee5"},
{file = "coverage-7.2.7-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:245167dd26180ab4c91d5e1496a30be4cd721a5cf2abf52974f965f10f11419f"},
{file = "coverage-7.2.7-cp38-cp38-win32.whl", hash = "sha256:d2c2db7fd82e9b72937969bceac4d6ca89660db0a0967614ce2481e81a0b771e"},
{file = "coverage-7.2.7-cp38-cp38-win_amd64.whl", hash = "sha256:2e07b54284e381531c87f785f613b833569c14ecacdcb85d56b25c4622c16c3c"},
{file = "coverage-7.2.7-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:537891ae8ce59ef63d0123f7ac9e2ae0fc8b72c7ccbe5296fec45fd68967b6c9"},
{file = "coverage-7.2.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:06fb182e69f33f6cd1d39a6c597294cff3143554b64b9825d1dc69d18cc2fff2"},
{file = "coverage-7.2.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:201e7389591af40950a6480bd9edfa8ed04346ff80002cec1a66cac4549c1ad7"},
{file = "coverage-7.2.7-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f6951407391b639504e3b3be51b7ba5f3528adbf1a8ac3302b687ecababf929e"},
{file = "coverage-7.2.7-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f48351d66575f535669306aa7d6d6f71bc43372473b54a832222803eb956fd1"},
{file = "coverage-7.2.7-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b29019c76039dc3c0fd815c41392a044ce555d9bcdd38b0fb60fb4cd8e475ba9"},
{file = "coverage-7.2.7-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:81c13a1fc7468c40f13420732805a4c38a105d89848b7c10af65a90beff25250"},
{file = "coverage-7.2.7-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:975d70ab7e3c80a3fe86001d8751f6778905ec723f5b110aed1e450da9d4b7f2"},
{file = "coverage-7.2.7-cp39-cp39-win32.whl", hash = "sha256:7ee7d9d4822c8acc74a5e26c50604dff824710bc8de424904c0982e25c39c6cb"},
{file = "coverage-7.2.7-cp39-cp39-win_amd64.whl", hash = "sha256:eb393e5ebc85245347950143969b241d08b52b88a3dc39479822e073a1a8eb27"},
{file = "coverage-7.2.7-pp37.pp38.pp39-none-any.whl", hash = "sha256:b7b4c971f05e6ae490fef852c218b0e79d4e52f79ef0c8475566584a8fb3e01d"},
{file = "coverage-7.2.7.tar.gz", hash = "sha256:924d94291ca674905fe9481f12294eb11f2d3d3fd1adb20314ba89e94f44ed59"},
]
[package.extras]
toml = ["tomli"]
[[package]]
name = "cryptography"
version = "41.0.2"
@@ -1169,6 +1241,23 @@ cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"]
[[package]]
name = "httpx-ws"
version = "0.4.1"
description = "WebSockets support for HTTPX"
optional = false
python-versions = ">=3.8"
files = [
{file = "httpx_ws-0.4.1-py3-none-any.whl", hash = "sha256:01bdaeb66add8196485dc39912abd0a3e95b67c244aededc151156ac6adca850"},
{file = "httpx_ws-0.4.1.tar.gz", hash = "sha256:5f3e291e8fb99c89f994329d883e5679d02a0b5b12a1e414f7f8630c276b6744"},
]
[package.dependencies]
anyio = "*"
httpcore = ">=0.17.3,<0.18"
httpx = ">=0.23.1"
wsproto = "*"
[[package]]
name = "huggingface-hub"
version = "0.16.4"
@@ -1883,6 +1972,24 @@ pytest = ">=7.0.0"
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"]
[[package]]
name = "pytest-cov"
version = "4.1.0"
description = "Pytest plugin for measuring coverage."
optional = false
python-versions = ">=3.7"
files = [
{file = "pytest-cov-4.1.0.tar.gz", hash = "sha256:3904b13dfbfec47f003b8e77fd5b589cd11904a21ddf1ab38a64f204d6a10ef6"},
{file = "pytest_cov-4.1.0-py3-none-any.whl", hash = "sha256:6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a"},
]
[package.dependencies]
coverage = {version = ">=5.2.1", extras = ["toml"]}
pytest = ">=4.6"
[package.extras]
testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"]
[[package]]
name = "python-dateutil"
version = "2.8.2"
@@ -2537,6 +2644,20 @@ files = [
{file = "wrapt-1.15.0.tar.gz", hash = "sha256:d06730c6aed78cee4126234cf2d071e01b44b915e725a6cb439a879ec9754a3a"},
]
[[package]]
name = "wsproto"
version = "1.2.0"
description = "WebSockets state-machine based protocol implementation"
optional = false
python-versions = ">=3.7.0"
files = [
{file = "wsproto-1.2.0-py3-none-any.whl", hash = "sha256:b9acddd652b585d75b20477888c56642fdade28bdfd3579aa24a4d2c037dd736"},
{file = "wsproto-1.2.0.tar.gz", hash = "sha256:ad565f26ecb92588a3e43bc3d96164de84cd9902482b130d0ddbaa9664a85065"},
]
[package.dependencies]
h11 = ">=0.9.0,<1"
[[package]]
name = "yarl"
version = "1.9.2"
@@ -2627,4 +2748,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
content-hash = "a51d7d26b88683875685ede2298f0f02ab42b1f303657b47e0a5dee9be0dc9e6"
content-hash = "c984979825947f67fc42e4553d5ff347f2f9194e4acccdeacb67a406a332009a"

View File

@@ -35,9 +35,11 @@ stamina = "^23.1.0"
[tool.poetry.group.tests.dependencies]
pytest-cov = "^4.1.0"
pytest-aiohttp = "^1.0.4"
pytest-asyncio = "^0.21.1"
pytest = "^7.4.0"
httpx-ws = "^0.4.1"
[tool.poetry.group.aws.dependencies]
@@ -46,3 +48,13 @@ aioboto3 = "^11.2.0"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[tool.coverage.run]
source = ["reflector"]
[tool.pytest.ini_options]
addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v"
testpaths = ["tests"]
asyncio_mode = "auto"

View File

@@ -1,9 +1,8 @@
from dataclasses import dataclass
from pydantic import BaseModel
from pathlib import Path
@dataclass
class AudioFile:
class AudioFile(BaseModel):
path: Path
sample_rate: int
channels: int
@@ -14,15 +13,13 @@ class AudioFile:
self.path.unlink()
@dataclass
class Word:
class Word(BaseModel):
text: str
start: float
end: float
@dataclass
class Transcript:
class Transcript(BaseModel):
text: str = ""
words: list[Word] = None
@@ -59,8 +56,7 @@ class Transcript:
return Transcript(text=self.text, words=words)
@dataclass
class TitleSummary:
class TitleSummary(BaseModel):
title: str
summary: str
timestamp: float
@@ -75,7 +71,6 @@ class TitleSummary:
return f"{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
@dataclass
class FinalSummary:
class FinalSummary(BaseModel):
summary: str
duration: float

View File

@@ -3,7 +3,6 @@ import time
import uuid
import httpx
import pyaudio
import stamina
from aiortc import RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import MediaPlayer, MediaRelay
@@ -24,7 +23,6 @@ class StreamClient:
self.server_url = url
self.play_from = play_from
self.ping_pong = ping_pong
self.paudio = pyaudio.PyAudio()
self.pc = RTCPeerConnection()
@@ -87,6 +85,7 @@ class StreamClient:
self.logger.info(f"Track {track.kind} ended")
self.pc.addTrack(audio)
self.track_audio = audio
channel = pc.createDataChannel("data-channel")
self.logger = self.logger.bind(channel=channel.label)
@@ -142,3 +141,6 @@ class StreamClient:
coro = self.run_offer(self.pc, self.signaling)
task = asyncio.create_task(coro)
await task
def is_ended(self):
return self.track_audio is None or self.track_audio.readyState == "ended"

View File

@@ -3,7 +3,8 @@ from pydantic import BaseModel, Field
from uuid import UUID, uuid4
from datetime import datetime
from fastapi_pagination import Page, paginate
from .rtc_offer import rtc_offer, RtcOffer, PipelineEvent
from reflector.logger import logger
from .rtc_offer import rtc_offer_base, RtcOffer, PipelineEvent
import asyncio
from typing import Optional
@@ -46,6 +47,7 @@ class Transcript(BaseModel):
def add_event(self, event: str, data):
self.events.append(TranscriptEvent(event=event, data=data))
return {"event": event, "data": data}
def upsert_topic(self, topic: TranscriptTopic):
existing_topic = next((t for t in self.topics if t.id == topic.id), None)
@@ -239,14 +241,37 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
# OFC the current implementation is not good,
# but it's just a POC before persistence. It won't query the
# transcript from the database for each event.
print(f"Event: {event}", args, data)
# print(f"Event: {event}", args, data)
transcript_id = args
transcript = transcripts_controller.get_by_id(transcript_id)
if not transcript:
return
transcript.add_event(event=event, data=data)
if event == PipelineEvent.TOPIC:
transcript.upsert_topic(TranscriptTopic(**data))
# event send to websocket clients may not be the same as the event
# received from the pipeline. For example, the pipeline will send
# a TRANSCRIPT event with all words, but this is not what we want
# to send to the websocket client.
# FIXME don't do copy
if event == PipelineEvent.TRANSCRIPT:
resp = transcript.add_event(event=event, data={
"text": data.text,
})
elif event == PipelineEvent.TOPIC:
topic = TranscriptTopic(
title=data.title,
summary=data.summary,
transcript=data.transcript,
timestamp=data.timestamp,
)
resp = transcript.add_event(event=event, data=topic.model_dump())
transcript.upsert_topic(topic)
else:
logger.warning(f"Unknown event: {event}")
return
# transmit to websocket clients
await ws_manager.send_json(transcript_id, resp)
@router.post("/transcripts/{transcript_id}/record/webrtc")
@@ -261,9 +286,9 @@ async def transcript_record_webrtc(
raise HTTPException(status_code=400, detail="Transcript is locked")
# FIXME do not allow multiple recording at the same time
return await rtc_offer(
return await rtc_offer_base(
params,
request,
event_callback=transcript.handle_event,
event_callback=handle_rtc_event,
event_callback_args=transcript_id,
)

Binary file not shown.

View File

@@ -61,9 +61,9 @@ async def test_transcripts_list():
assert "testxx1" in names
assert "testxx2" in names
@pytest.mark.asyncio
async def test_transcript_delete():
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "testdel1"})
assert response.status_code == 200
@@ -76,4 +76,3 @@ async def test_transcript_delete():
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 404

View File

@@ -0,0 +1,150 @@
# === further tests
# FIXME test status of transcript
# FIXME test websocket connection after RTC is finished still send the full events
# FIXME try with locked session, RTC should not work
import pytest
from unittest.mock import patch
from httpx import AsyncClient
from reflector.app import app
from uvicorn import Config, Server
import threading
import asyncio
from pathlib import Path
from httpx_ws import aconnect_ws
class ThreadedUvicorn:
def __init__(self, config: Config):
self.server = Server(config)
self.thread = threading.Thread(daemon=True, target=self.server.run)
async def start(self):
self.thread.start()
while not self.server.started:
await asyncio.sleep(0.1)
def stop(self):
if self.thread.is_alive():
self.server.should_exit = True
while self.thread.is_alive():
continue
@pytest.fixture
async def dummy_transcript():
from reflector.processors.audio_transcript import AudioTranscriptProcessor
from reflector.processors.types import AudioFile, Transcript, Word
class TestAudioTranscriptProcessor(AudioTranscriptProcessor):
async def _transcript(self, data: AudioFile):
return Transcript(
text="Hello world",
words=[
Word(start=0.0, end=1.0, text="Hello"),
Word(start=1.0, end=2.0, text="world"),
],
)
with patch(
"reflector.processors.audio_transcript_auto"
".AudioTranscriptAutoProcessor.get_instance"
) as mock_audio:
mock_audio.return_value = TestAudioTranscriptProcessor()
yield
@pytest.fixture
async def dummy_llm():
from reflector.llm.base import LLM
class TestLLM(LLM):
async def _generate(self, prompt: str, **kwargs):
return {"text": "LLM RESULT"}
with patch("reflector.llm.base.LLM.get_instance") as mock_llm:
mock_llm.return_value = TestLLM()
yield
@pytest.mark.asyncio
async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm):
# goal: start the server, exchange RTC, receive websocket events
# because of that, we need to start the server in a thread
# to be able to connect with aiortc
# start server
host = "127.0.0.1"
port = 1255
base_url = f"http://{host}:{port}/v1"
config = Config(app=app, host=host, port=port)
server = ThreadedUvicorn(config)
await server.start()
# create a transcript
ac = AsyncClient(base_url=base_url)
response = await ac.post("/transcripts", json={"name": "Test RTC"})
assert response.status_code == 200
tid = response.json()["id"]
# create a websocket connection as a task
events = []
async def websocket_task():
print("Test websocket: TASK STARTED")
async with aconnect_ws(f"{base_url}/transcripts/{tid}/events") as ws:
print("Test websocket: CONNECTED")
try:
while True:
msg = await ws.receive_json()
print(f"Test websocket: JSON {msg}")
if msg is None:
break
events.append(msg)
except Exception as e:
print(f"Test websocket: EXCEPTION {e}")
finally:
ws.close()
print("Test websocket: DISCONNECTED")
websocket_task = asyncio.get_event_loop().create_task(websocket_task())
# create stream client
import argparse
from reflector.stream_client import StreamClient
from aiortc.contrib.signaling import add_signaling_arguments, create_signaling
parser = argparse.ArgumentParser()
add_signaling_arguments(parser)
args = parser.parse_args(["-s", "tcp-socket"])
signaling = create_signaling(args)
url = f"{base_url}/transcripts/{tid}/record/webrtc"
path = Path(__file__).parent / "records" / "test_short.wav"
client = StreamClient(signaling, url=url, play_from=path.as_posix())
await client.start()
timeout = 20
while not client.is_ended():
await asyncio.sleep(1)
timeout -= 1
if timeout < 0:
raise TimeoutError("Timeout while waiting for RTC to end")
await client.stop()
# wait the processing to finish
await asyncio.sleep(2)
# stop websocket task
websocket_task.cancel()
# check events
print(events)
assert len(events) > 0
assert events[0]["event"] == "TRANSCRIPT"
assert events[0]["data"]["text"] == "Hello world"
# stop server
# server.stop()