From a9e0c9aa035fb115e9a8259b1b1b3af9686b00bf Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Wed, 9 Aug 2023 11:21:48 +0200 Subject: [PATCH] server: implement status update in model and websocket --- server/reflector/models.py | 2 ++ server/reflector/views/rtc_offer.py | 19 +++++++++++++ server/reflector/views/transcripts.py | 3 ++ server/tests/test_transcripts_rtc_ws.py | 37 +++++++++++++++++++------ 4 files changed, 52 insertions(+), 9 deletions(-) diff --git a/server/reflector/models.py b/server/reflector/models.py index af04ade4..d1aaaa1e 100644 --- a/server/reflector/models.py +++ b/server/reflector/models.py @@ -199,6 +199,7 @@ class TranscriptionContext: sorted_transcripts: dict data_channel: None # FIXME logger: None + status: str def __init__(self, logger): self.transcription_text = "" @@ -206,4 +207,5 @@ class TranscriptionContext: self.incremental_responses = [] self.data_channel = None self.sorted_transcripts = SortedDict() + self.status = "idle" self.logger = logger diff --git a/server/reflector/views/rtc_offer.py b/server/reflector/views/rtc_offer.py index 3a8eb874..cbc0a4dc 100644 --- a/server/reflector/views/rtc_offer.py +++ b/server/reflector/views/rtc_offer.py @@ -52,10 +52,15 @@ class RtcOffer(BaseModel): type: str +class StrValue(BaseModel): + value: str + + class PipelineEvent(StrEnum): TRANSCRIPT = "TRANSCRIPT" TOPIC = "TOPIC" FINAL_SUMMARY = "FINAL_SUMMARY" + STATUS = "STATUS" async def rtc_offer_base( @@ -70,6 +75,17 @@ async def rtc_offer_base( ctx = TranscriptionContext(logger=logger.bind(client=clientid)) ctx.topics = [] + async def update_status(status: str): + changed = ctx.status != status + if changed: + ctx.status = status + if event_callback: + await event_callback( + event=PipelineEvent.STATUS, + args=event_callback_args, + data=StrValue(value=status), + ) + # build pipeline callback async def on_transcript(transcript: Transcript): ctx.logger.info("Transcript", transcript=transcript) @@ -148,10 +164,12 @@ async def rtc_offer_base( pc = RTCPeerConnection() async def flush_pipeline_and_quit(close=True): + await update_status("processing") await ctx.pipeline.flush() if close: ctx.logger.debug("Closing peer connection") await pc.close() + await update_status("ended") @pc.on("datachannel") def on_datachannel(channel): @@ -181,6 +199,7 @@ async def rtc_offer_base( def on_track(track): ctx.logger.info(f"Track {track.kind} received") pc.addTrack(AudioStreamTrack(ctx, track)) + asyncio.get_event_loop().create_task(update_status("recording")) await pc.setRemoteDescription(offer) diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 6cd265ce..8f0cf832 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -280,6 +280,9 @@ async def handle_rtc_event(event: PipelineEvent, args, data): resp = transcript.add_event(event=event, data=final_summary) transcript.summary = final_summary + elif event == PipelineEvent.STATUS: + resp = transcript.add_event(event=event, data=data) + else: logger.warning(f"Unknown event: {event}") return diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py index 9724c7fe..113e07cf 100644 --- a/server/tests/test_transcripts_rtc_ws.py +++ b/server/tests/test_transcripts_rtc_ws.py @@ -150,17 +150,36 @@ async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm): # check events assert len(events) > 0 - assert events[0]["event"] == "TRANSCRIPT" - assert events[0]["data"]["text"] == "Hello world" + from pprint import pprint - assert events[-2]["event"] == "TOPIC" - assert events[-2]["data"]["id"] - assert events[-2]["data"]["summary"] == "LLM SUMMARY" - assert events[-2]["data"]["transcript"].startswith("Hello world") - assert events[-2]["data"]["timestamp"] == 0.0 + pprint(events) - assert events[-1]["event"] == "FINAL_SUMMARY" - assert events[-1]["data"]["summary"] == "LLM SUMMARY" + # get events list + eventnames = [e["event"] for e in events] + + # check events + assert "TRANSCRIPT" in eventnames + ev = events[eventnames.index("TRANSCRIPT")] + assert ev["data"]["text"] == "Hello world" + + assert "TOPIC" in eventnames + ev = events[eventnames.index("TOPIC")] + assert ev["data"]["id"] + assert ev["data"]["summary"] == "LLM SUMMARY" + assert ev["data"]["transcript"].startswith("Hello world") + assert ev["data"]["timestamp"] == 0.0 + + assert "FINAL_SUMMARY" in eventnames + ev = events[eventnames.index("FINAL_SUMMARY")] + assert ev["data"]["summary"] == "LLM SUMMARY" + + # check status order + statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"] + assert statuses == ["recording", "processing", "ended"] + + # ensure the last event received is ended + assert events[-1]["event"] == "STATUS" + assert events[-1]["data"]["value"] == "ended" # stop server # server.stop()