From 8397ca8c294f3186fe0552f93439cb23f1d07b49 Mon Sep 17 00:00:00 2001 From: Jose B Date: Tue, 8 Aug 2023 05:42:23 -0500 Subject: [PATCH 1/5] disable undesired features --- www/app/components/record.js | 6 +++++- www/app/components/webrtc.js | 3 ++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/www/app/components/record.js b/www/app/components/record.js index d1a71093..f683f2b7 100644 --- a/www/app/components/record.js +++ b/www/app/components/record.js @@ -97,7 +97,11 @@ export default function Recorder(props) { document.getElementById("play-btn").disabled = false; } else { const stream = await navigator.mediaDevices.getUserMedia({ - audio: { deviceId }, + audio: { + deviceId, + noiseSuppression: false, + echoCancellation: false, + }, }); await record.startRecording(stream); props.setStream(stream); diff --git a/www/app/components/webrtc.js b/www/app/components/webrtc.js index 62bebb0d..f78608cd 100644 --- a/www/app/components/webrtc.js +++ b/www/app/components/webrtc.js @@ -2,7 +2,8 @@ import { useEffect, useState } from "react"; import Peer from "simple-peer"; // allow customization of the WebRTC server URL from env -const WEBRTC_SERVER_URL = process.env.NEXT_PUBLIC_WEBRTC_SERVER_URL || "http://127.0.0.1:1250/offer"; +const WEBRTC_SERVER_URL = + process.env.NEXT_PUBLIC_WEBRTC_SERVER_URL || "http://127.0.0.1:1250/offer"; const useWebRTC = (stream) => { const [data, setData] = useState({ From 93564bfd894f32419d8761dbbbc8f9eec45809e0 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Tue, 8 Aug 2023 18:31:39 +0200 Subject: [PATCH 2/5] server: fix stamina missing for old server --- server/poetry.lock | 36 +++++++++++++++++++++++++++++++++++- server/pyproject.toml | 1 + 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/server/poetry.lock b/server/poetry.lock index 17ac3efa..b0ded524 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -2199,6 +2199,26 @@ files = [ {file = "sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88"}, ] +[[package]] +name = "stamina" +version = "23.1.0" +description = "Production-grade retries made easy." +optional = false +python-versions = ">=3.8" +files = [ + {file = "stamina-23.1.0-py3-none-any.whl", hash = "sha256:850de8c2c2469aabf42a4c02e7372eaa12c2eced78f2bfa34162b8676c2846e5"}, + {file = "stamina-23.1.0.tar.gz", hash = "sha256:b16ce3d52d658aa75db813fc6a6661b770abfea915f72cda48e325f2a7854786"}, +] + +[package.dependencies] +tenacity = "*" + +[package.extras] +dev = ["nox", "prometheus-client", "stamina[tests,typing]", "structlog", "tomli"] +docs = ["furo", "myst-parser", "prometheus-client", "sphinx", "sphinx-notfound-page", "structlog"] +tests = ["pytest", "pytest-asyncio"] +typing = ["mypy (>=1.4)"] + [[package]] name = "starlette" version = "0.27.0" @@ -2247,6 +2267,20 @@ files = [ [package.dependencies] mpmath = ">=0.19" +[[package]] +name = "tenacity" +version = "8.2.2" +description = "Retry code until it succeeds" +optional = false +python-versions = ">=3.6" +files = [ + {file = "tenacity-8.2.2-py3-none-any.whl", hash = "sha256:2f277afb21b851637e8f52e6a613ff08734c347dc19ade928e519d7d2d8569b0"}, + {file = "tenacity-8.2.2.tar.gz", hash = "sha256:43af037822bd0029025877f3b2d97cc4d7bb0c2991000a3d59d71517c5c969e0"}, +] + +[package.extras] +doc = ["reno", "sphinx", "tornado (>=4.5)"] + [[package]] name = "tokenizers" version = "0.13.3" @@ -2732,4 +2766,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "0c81bfdc623dc7a55ac16a0948bfb5b2d9391abd32bad0e665b0251169c7f7de" +content-hash = "75afc46634677cd9afdf2ae66b320a8eaaa36d360d0ba187e5974b90810df44f" diff --git a/server/pyproject.toml b/server/pyproject.toml index 332df82b..039e1f5a 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -27,6 +27,7 @@ fastapi-pagination = "^0.12.6" [tool.poetry.group.dev.dependencies] black = "^23.7.0" +stamina = "^23.1.0" [tool.poetry.group.client.dependencies] From 7f807c8f5f40de62fbcf61c01fffb7245e8c85ec Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Tue, 8 Aug 2023 19:32:20 +0200 Subject: [PATCH 3/5] server: implement FINAL_SUMMARY for websocket + update tests and fix flush --- server/reflector/processors/base.py | 6 +++ server/reflector/stream_client.py | 4 +- server/reflector/views/rtc_offer.py | 60 +++++++++++++++---------- server/reflector/views/transcripts.py | 34 +++++++++----- server/tests/test_transcripts_rtc_ws.py | 20 ++++++++- 5 files changed, 86 insertions(+), 38 deletions(-) diff --git a/server/reflector/processors/base.py b/server/reflector/processors/base.py index 7d11590d..692a490b 100644 --- a/server/reflector/processors/base.py +++ b/server/reflector/processors/base.py @@ -14,6 +14,7 @@ class Processor: if callback: self.on(callback) self.uid = uuid4().hex + self.flushed = False self.logger = (custom_logger or logger).bind(processor=self.__class__.__name__) def set_pipeline(self, pipeline: "Pipeline"): @@ -65,6 +66,7 @@ class Processor: """ # logger.debug(f"{self.__class__.__name__} push") try: + self.flushed = False return await self._push(data) except Exception: self.logger.exception("Error in push") @@ -72,8 +74,12 @@ class Processor: async def flush(self): """ Flush data to this processor + Works only one time, until another push is called """ + if self.flushed: + return # logger.debug(f"{self.__class__.__name__} flush") + self.flushed = True return await self._flush() def describe(self, level=0): diff --git a/server/reflector/stream_client.py b/server/reflector/stream_client.py index 6b66ad45..b3e4d966 100644 --- a/server/reflector/stream_client.py +++ b/server/reflector/stream_client.py @@ -72,7 +72,7 @@ class StreamClient: async def on_connectionstatechange(): self.logger.info(f"Connection state is {pc.connectionState}") if pc.connectionState == "failed": - await pc.close() + await self.stop() self.pcs.discard(pc) @pc.on("track") @@ -87,7 +87,7 @@ class StreamClient: self.pc.addTrack(audio) self.track_audio = audio - channel = pc.createDataChannel("data-channel") + self.channel = channel = pc.createDataChannel("data-channel") self.logger = self.logger.bind(channel=channel.label) self.logger.info("Created by local party") diff --git a/server/reflector/views/rtc_offer.py b/server/reflector/views/rtc_offer.py index 288153e6..3a8eb874 100644 --- a/server/reflector/views/rtc_offer.py +++ b/server/reflector/views/rtc_offer.py @@ -73,12 +73,16 @@ async def rtc_offer_base( # build pipeline callback async def on_transcript(transcript: Transcript): ctx.logger.info("Transcript", transcript=transcript) - result = { - "cmd": "SHOW_TRANSCRIPTION", - "text": transcript.text, - } - ctx.data_channel.send(dumps(result)) + # send to RTC + if ctx.data_channel.readyState == "open": + result = { + "cmd": "SHOW_TRANSCRIPTION", + "text": transcript.text, + } + ctx.data_channel.send(dumps(result)) + + # send to callback (eg. websocket) if event_callback: await event_callback( event=PipelineEvent.TRANSCRIPT, @@ -86,9 +90,7 @@ async def rtc_offer_base( data=transcript, ) - async def on_topic( - summary: TitleSummary, event_callback=None, event_callback_args=None - ): + async def on_topic(summary: TitleSummary): # FIXME: make it incremental with the frontend, not send everything ctx.logger.info("Summary", summary=summary) ctx.topics.append( @@ -99,28 +101,36 @@ async def rtc_offer_base( "desc": summary.summary, } ) - result = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics} - ctx.data_channel.send(dumps(result)) + # send to RTC + if ctx.data_channel.readyState == "open": + result = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics} + ctx.data_channel.send(dumps(result)) + + # send to callback (eg. websocket) if event_callback: await event_callback( event=PipelineEvent.TOPIC, args=event_callback_args, data=summary ) - async def on_final_summary( - summary: FinalSummary, event_callback=None, event_callback_args=None - ): + async def on_final_summary(summary: FinalSummary): ctx.logger.info("FinalSummary", final_summary=summary) - result = { - "cmd": "DISPLAY_FINAL_SUMMARY", - "summary": summary.summary, - "duration": summary.duration, - } - ctx.data_channel.send(dumps(result)) + # send to RTC + if ctx.data_channel.readyState == "open": + result = { + "cmd": "DISPLAY_FINAL_SUMMARY", + "summary": summary.summary, + "duration": summary.duration, + } + ctx.data_channel.send(dumps(result)) + + # send to callback (eg. websocket) if event_callback: await event_callback( - event=PipelineEvent.TOPIC, args=event_callback_args, data=summary + event=PipelineEvent.FINAL_SUMMARY, + args=event_callback_args, + data=summary, ) # create a context for the whole rtc transaction @@ -137,11 +147,11 @@ async def rtc_offer_base( # handle RTC peer connection pc = RTCPeerConnection() - async def flush_pipeline_and_quit(): - ctx.logger.info("Flushing pipeline") + async def flush_pipeline_and_quit(close=True): await ctx.pipeline.flush() - ctx.logger.debug("Closing peer connection") - await pc.close() + if close: + ctx.logger.debug("Closing peer connection") + await pc.close() @pc.on("datachannel") def on_datachannel(channel): @@ -164,6 +174,8 @@ async def rtc_offer_base( ctx.logger.info(f"Connection state: {pc.connectionState}") if pc.connectionState == "failed": await pc.close() + elif pc.connectionState == "closed": + await flush_pipeline_and_quit(close=False) @pc.on("track") def on_track(track): diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 120c3ff1..6cd265ce 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -21,6 +21,10 @@ def generate_transcript_name(): return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}" +class TranscriptText(BaseModel): + text: str + + class TranscriptTopic(BaseModel): id: UUID = Field(default_factory=uuid4) title: str @@ -29,6 +33,10 @@ class TranscriptTopic(BaseModel): timestamp: float +class TranscriptFinalSummary(BaseModel): + summary: str + + class TranscriptEvent(BaseModel): event: str data: dict @@ -45,9 +53,10 @@ class Transcript(BaseModel): topics: list[TranscriptTopic] = [] events: list[TranscriptEvent] = [] - def add_event(self, event: str, data): - self.events.append(TranscriptEvent(event=event, data=data)) - return {"event": event, "data": data} + def add_event(self, event: str, data: BaseModel) -> TranscriptEvent: + ev = TranscriptEvent(event=event, data=data.model_dump()) + self.events.append(ev) + return ev def upsert_topic(self, topic: TranscriptTopic): existing_topic = next((t for t in self.topics if t.id == topic.id), None) @@ -219,7 +228,7 @@ async def transcript_events_websocket(transcript_id: UUID, websocket: WebSocket) # on first connection, send all events for event in transcript.events: - await websocket.send_json(event.model_dump()) + await websocket.send_json(event.model_dump(mode="json")) # XXX if transcript is final (locked=True and status=ended) # XXX send a final event to the client and close the connection @@ -254,24 +263,29 @@ async def handle_rtc_event(event: PipelineEvent, args, data): # FIXME don't do copy if event == PipelineEvent.TRANSCRIPT: - resp = transcript.add_event(event=event, data={ - "text": data.text, - }) + resp = transcript.add_event(event=event, data=TranscriptText(text=data.text)) + elif event == PipelineEvent.TOPIC: topic = TranscriptTopic( title=data.title, summary=data.summary, - transcript=data.transcript, + transcript=data.transcript.text, timestamp=data.timestamp, ) - resp = transcript.add_event(event=event, data=topic.model_dump()) + resp = transcript.add_event(event=event, data=topic) transcript.upsert_topic(topic) + + elif event == PipelineEvent.FINAL_SUMMARY: + final_summary = TranscriptFinalSummary(summary=data.summary) + resp = transcript.add_event(event=event, data=final_summary) + transcript.summary = final_summary + else: logger.warning(f"Unknown event: {event}") return # transmit to websocket clients - await ws_manager.send_json(transcript_id, resp) + await ws_manager.send_json(transcript_id, resp.model_dump(mode="json")) @router.post("/transcripts/{transcript_id}/record/webrtc") diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py index e6280c94..9724c7fe 100644 --- a/server/tests/test_transcripts_rtc_ws.py +++ b/server/tests/test_transcripts_rtc_ws.py @@ -4,6 +4,7 @@ # FIXME try with locked session, RTC should not work import pytest +import json from unittest.mock import patch from httpx import AsyncClient @@ -61,7 +62,7 @@ async def dummy_llm(): class TestLLM(LLM): async def _generate(self, prompt: str, **kwargs): - return {"text": "LLM RESULT"} + return json.dumps({"title": "LLM TITLE", "summary": "LLM SUMMARY"}) with patch("reflector.llm.base.LLM.get_instance") as mock_llm: mock_llm.return_value = TestLLM() @@ -132,6 +133,13 @@ async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm): if timeout < 0: raise TimeoutError("Timeout while waiting for RTC to end") + # XXX aiortc is long to close the connection + # instead of waiting a long time, we just send a STOP + client.channel.send(json.dumps({"cmd": "STOP"})) + + # wait the processing to finish + await asyncio.sleep(2) + await client.stop() # wait the processing to finish @@ -141,10 +149,18 @@ async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm): websocket_task.cancel() # check events - print(events) assert len(events) > 0 assert events[0]["event"] == "TRANSCRIPT" assert events[0]["data"]["text"] == "Hello world" + assert events[-2]["event"] == "TOPIC" + assert events[-2]["data"]["id"] + assert events[-2]["data"]["summary"] == "LLM SUMMARY" + assert events[-2]["data"]["transcript"].startswith("Hello world") + assert events[-2]["data"]["timestamp"] == 0.0 + + assert events[-1]["event"] == "FINAL_SUMMARY" + assert events[-1]["data"]["summary"] == "LLM SUMMARY" + # stop server # server.stop() From a9e0c9aa035fb115e9a8259b1b1b3af9686b00bf Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Wed, 9 Aug 2023 11:21:48 +0200 Subject: [PATCH 4/5] server: implement status update in model and websocket --- server/reflector/models.py | 2 ++ server/reflector/views/rtc_offer.py | 19 +++++++++++++ server/reflector/views/transcripts.py | 3 ++ server/tests/test_transcripts_rtc_ws.py | 37 +++++++++++++++++++------ 4 files changed, 52 insertions(+), 9 deletions(-) diff --git a/server/reflector/models.py b/server/reflector/models.py index af04ade4..d1aaaa1e 100644 --- a/server/reflector/models.py +++ b/server/reflector/models.py @@ -199,6 +199,7 @@ class TranscriptionContext: sorted_transcripts: dict data_channel: None # FIXME logger: None + status: str def __init__(self, logger): self.transcription_text = "" @@ -206,4 +207,5 @@ class TranscriptionContext: self.incremental_responses = [] self.data_channel = None self.sorted_transcripts = SortedDict() + self.status = "idle" self.logger = logger diff --git a/server/reflector/views/rtc_offer.py b/server/reflector/views/rtc_offer.py index 3a8eb874..cbc0a4dc 100644 --- a/server/reflector/views/rtc_offer.py +++ b/server/reflector/views/rtc_offer.py @@ -52,10 +52,15 @@ class RtcOffer(BaseModel): type: str +class StrValue(BaseModel): + value: str + + class PipelineEvent(StrEnum): TRANSCRIPT = "TRANSCRIPT" TOPIC = "TOPIC" FINAL_SUMMARY = "FINAL_SUMMARY" + STATUS = "STATUS" async def rtc_offer_base( @@ -70,6 +75,17 @@ async def rtc_offer_base( ctx = TranscriptionContext(logger=logger.bind(client=clientid)) ctx.topics = [] + async def update_status(status: str): + changed = ctx.status != status + if changed: + ctx.status = status + if event_callback: + await event_callback( + event=PipelineEvent.STATUS, + args=event_callback_args, + data=StrValue(value=status), + ) + # build pipeline callback async def on_transcript(transcript: Transcript): ctx.logger.info("Transcript", transcript=transcript) @@ -148,10 +164,12 @@ async def rtc_offer_base( pc = RTCPeerConnection() async def flush_pipeline_and_quit(close=True): + await update_status("processing") await ctx.pipeline.flush() if close: ctx.logger.debug("Closing peer connection") await pc.close() + await update_status("ended") @pc.on("datachannel") def on_datachannel(channel): @@ -181,6 +199,7 @@ async def rtc_offer_base( def on_track(track): ctx.logger.info(f"Track {track.kind} received") pc.addTrack(AudioStreamTrack(ctx, track)) + asyncio.get_event_loop().create_task(update_status("recording")) await pc.setRemoteDescription(offer) diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 6cd265ce..8f0cf832 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -280,6 +280,9 @@ async def handle_rtc_event(event: PipelineEvent, args, data): resp = transcript.add_event(event=event, data=final_summary) transcript.summary = final_summary + elif event == PipelineEvent.STATUS: + resp = transcript.add_event(event=event, data=data) + else: logger.warning(f"Unknown event: {event}") return diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py index 9724c7fe..113e07cf 100644 --- a/server/tests/test_transcripts_rtc_ws.py +++ b/server/tests/test_transcripts_rtc_ws.py @@ -150,17 +150,36 @@ async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm): # check events assert len(events) > 0 - assert events[0]["event"] == "TRANSCRIPT" - assert events[0]["data"]["text"] == "Hello world" + from pprint import pprint - assert events[-2]["event"] == "TOPIC" - assert events[-2]["data"]["id"] - assert events[-2]["data"]["summary"] == "LLM SUMMARY" - assert events[-2]["data"]["transcript"].startswith("Hello world") - assert events[-2]["data"]["timestamp"] == 0.0 + pprint(events) - assert events[-1]["event"] == "FINAL_SUMMARY" - assert events[-1]["data"]["summary"] == "LLM SUMMARY" + # get events list + eventnames = [e["event"] for e in events] + + # check events + assert "TRANSCRIPT" in eventnames + ev = events[eventnames.index("TRANSCRIPT")] + assert ev["data"]["text"] == "Hello world" + + assert "TOPIC" in eventnames + ev = events[eventnames.index("TOPIC")] + assert ev["data"]["id"] + assert ev["data"]["summary"] == "LLM SUMMARY" + assert ev["data"]["transcript"].startswith("Hello world") + assert ev["data"]["timestamp"] == 0.0 + + assert "FINAL_SUMMARY" in eventnames + ev = events[eventnames.index("FINAL_SUMMARY")] + assert ev["data"]["summary"] == "LLM SUMMARY" + + # check status order + statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"] + assert statuses == ["recording", "processing", "ended"] + + # ensure the last event received is ended + assert events[-1]["event"] == "STATUS" + assert events[-1]["data"]["value"] == "ended" # stop server # server.stop() From 26e34aec2ddb8b6a0b09c8a4907cbcfd6ab708fb Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Wed, 9 Aug 2023 11:23:28 +0200 Subject: [PATCH 5/5] server: ensure transcript status model is updated + tests --- server/reflector/views/transcripts.py | 1 + server/tests/test_transcripts_rtc_ws.py | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 8f0cf832..dabd2f9d 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -282,6 +282,7 @@ async def handle_rtc_event(event: PipelineEvent, args, data): elif event == PipelineEvent.STATUS: resp = transcript.add_event(event=event, data=data) + transcript.status = data.value else: logger.warning(f"Unknown event: {event}") diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py index 113e07cf..70ee209b 100644 --- a/server/tests/test_transcripts_rtc_ws.py +++ b/server/tests/test_transcripts_rtc_ws.py @@ -183,3 +183,8 @@ async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm): # stop server # server.stop() + + # check that transcript status in model is updated + resp = await ac.get(f"/transcripts/{tid}") + assert resp.status_code == 200 + assert resp.json()["status"] == "ended"