server: implement FINAL_SUMMARY for websocket + update tests and fix flush

This commit is contained in:
Mathieu Virbel
2023-08-08 19:32:20 +02:00
parent 93564bfd89
commit 7f807c8f5f
5 changed files with 86 additions and 38 deletions

View File

@@ -14,6 +14,7 @@ class Processor:
if callback:
self.on(callback)
self.uid = uuid4().hex
self.flushed = False
self.logger = (custom_logger or logger).bind(processor=self.__class__.__name__)
def set_pipeline(self, pipeline: "Pipeline"):
@@ -65,6 +66,7 @@ class Processor:
"""
# logger.debug(f"{self.__class__.__name__} push")
try:
self.flushed = False
return await self._push(data)
except Exception:
self.logger.exception("Error in push")
@@ -72,8 +74,12 @@ class Processor:
async def flush(self):
"""
Flush data to this processor
Works only one time, until another push is called
"""
if self.flushed:
return
# logger.debug(f"{self.__class__.__name__} flush")
self.flushed = True
return await self._flush()
def describe(self, level=0):

View File

@@ -72,7 +72,7 @@ class StreamClient:
async def on_connectionstatechange():
self.logger.info(f"Connection state is {pc.connectionState}")
if pc.connectionState == "failed":
await pc.close()
await self.stop()
self.pcs.discard(pc)
@pc.on("track")
@@ -87,7 +87,7 @@ class StreamClient:
self.pc.addTrack(audio)
self.track_audio = audio
channel = pc.createDataChannel("data-channel")
self.channel = channel = pc.createDataChannel("data-channel")
self.logger = self.logger.bind(channel=channel.label)
self.logger.info("Created by local party")

View File

@@ -73,12 +73,16 @@ async def rtc_offer_base(
# build pipeline callback
async def on_transcript(transcript: Transcript):
ctx.logger.info("Transcript", transcript=transcript)
result = {
"cmd": "SHOW_TRANSCRIPTION",
"text": transcript.text,
}
ctx.data_channel.send(dumps(result))
# send to RTC
if ctx.data_channel.readyState == "open":
result = {
"cmd": "SHOW_TRANSCRIPTION",
"text": transcript.text,
}
ctx.data_channel.send(dumps(result))
# send to callback (eg. websocket)
if event_callback:
await event_callback(
event=PipelineEvent.TRANSCRIPT,
@@ -86,9 +90,7 @@ async def rtc_offer_base(
data=transcript,
)
async def on_topic(
summary: TitleSummary, event_callback=None, event_callback_args=None
):
async def on_topic(summary: TitleSummary):
# FIXME: make it incremental with the frontend, not send everything
ctx.logger.info("Summary", summary=summary)
ctx.topics.append(
@@ -99,28 +101,36 @@ async def rtc_offer_base(
"desc": summary.summary,
}
)
result = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics}
ctx.data_channel.send(dumps(result))
# send to RTC
if ctx.data_channel.readyState == "open":
result = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics}
ctx.data_channel.send(dumps(result))
# send to callback (eg. websocket)
if event_callback:
await event_callback(
event=PipelineEvent.TOPIC, args=event_callback_args, data=summary
)
async def on_final_summary(
summary: FinalSummary, event_callback=None, event_callback_args=None
):
async def on_final_summary(summary: FinalSummary):
ctx.logger.info("FinalSummary", final_summary=summary)
result = {
"cmd": "DISPLAY_FINAL_SUMMARY",
"summary": summary.summary,
"duration": summary.duration,
}
ctx.data_channel.send(dumps(result))
# send to RTC
if ctx.data_channel.readyState == "open":
result = {
"cmd": "DISPLAY_FINAL_SUMMARY",
"summary": summary.summary,
"duration": summary.duration,
}
ctx.data_channel.send(dumps(result))
# send to callback (eg. websocket)
if event_callback:
await event_callback(
event=PipelineEvent.TOPIC, args=event_callback_args, data=summary
event=PipelineEvent.FINAL_SUMMARY,
args=event_callback_args,
data=summary,
)
# create a context for the whole rtc transaction
@@ -137,11 +147,11 @@ async def rtc_offer_base(
# handle RTC peer connection
pc = RTCPeerConnection()
async def flush_pipeline_and_quit():
ctx.logger.info("Flushing pipeline")
async def flush_pipeline_and_quit(close=True):
await ctx.pipeline.flush()
ctx.logger.debug("Closing peer connection")
await pc.close()
if close:
ctx.logger.debug("Closing peer connection")
await pc.close()
@pc.on("datachannel")
def on_datachannel(channel):
@@ -164,6 +174,8 @@ async def rtc_offer_base(
ctx.logger.info(f"Connection state: {pc.connectionState}")
if pc.connectionState == "failed":
await pc.close()
elif pc.connectionState == "closed":
await flush_pipeline_and_quit(close=False)
@pc.on("track")
def on_track(track):

View File

@@ -21,6 +21,10 @@ def generate_transcript_name():
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
class TranscriptText(BaseModel):
text: str
class TranscriptTopic(BaseModel):
id: UUID = Field(default_factory=uuid4)
title: str
@@ -29,6 +33,10 @@ class TranscriptTopic(BaseModel):
timestamp: float
class TranscriptFinalSummary(BaseModel):
summary: str
class TranscriptEvent(BaseModel):
event: str
data: dict
@@ -45,9 +53,10 @@ class Transcript(BaseModel):
topics: list[TranscriptTopic] = []
events: list[TranscriptEvent] = []
def add_event(self, event: str, data):
self.events.append(TranscriptEvent(event=event, data=data))
return {"event": event, "data": data}
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
ev = TranscriptEvent(event=event, data=data.model_dump())
self.events.append(ev)
return ev
def upsert_topic(self, topic: TranscriptTopic):
existing_topic = next((t for t in self.topics if t.id == topic.id), None)
@@ -219,7 +228,7 @@ async def transcript_events_websocket(transcript_id: UUID, websocket: WebSocket)
# on first connection, send all events
for event in transcript.events:
await websocket.send_json(event.model_dump())
await websocket.send_json(event.model_dump(mode="json"))
# XXX if transcript is final (locked=True and status=ended)
# XXX send a final event to the client and close the connection
@@ -254,24 +263,29 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
# FIXME don't do copy
if event == PipelineEvent.TRANSCRIPT:
resp = transcript.add_event(event=event, data={
"text": data.text,
})
resp = transcript.add_event(event=event, data=TranscriptText(text=data.text))
elif event == PipelineEvent.TOPIC:
topic = TranscriptTopic(
title=data.title,
summary=data.summary,
transcript=data.transcript,
transcript=data.transcript.text,
timestamp=data.timestamp,
)
resp = transcript.add_event(event=event, data=topic.model_dump())
resp = transcript.add_event(event=event, data=topic)
transcript.upsert_topic(topic)
elif event == PipelineEvent.FINAL_SUMMARY:
final_summary = TranscriptFinalSummary(summary=data.summary)
resp = transcript.add_event(event=event, data=final_summary)
transcript.summary = final_summary
else:
logger.warning(f"Unknown event: {event}")
return
# transmit to websocket clients
await ws_manager.send_json(transcript_id, resp)
await ws_manager.send_json(transcript_id, resp.model_dump(mode="json"))
@router.post("/transcripts/{transcript_id}/record/webrtc")

View File

@@ -4,6 +4,7 @@
# FIXME try with locked session, RTC should not work
import pytest
import json
from unittest.mock import patch
from httpx import AsyncClient
@@ -61,7 +62,7 @@ async def dummy_llm():
class TestLLM(LLM):
async def _generate(self, prompt: str, **kwargs):
return {"text": "LLM RESULT"}
return json.dumps({"title": "LLM TITLE", "summary": "LLM SUMMARY"})
with patch("reflector.llm.base.LLM.get_instance") as mock_llm:
mock_llm.return_value = TestLLM()
@@ -132,6 +133,13 @@ async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm):
if timeout < 0:
raise TimeoutError("Timeout while waiting for RTC to end")
# XXX aiortc is long to close the connection
# instead of waiting a long time, we just send a STOP
client.channel.send(json.dumps({"cmd": "STOP"}))
# wait the processing to finish
await asyncio.sleep(2)
await client.stop()
# wait the processing to finish
@@ -141,10 +149,18 @@ async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm):
websocket_task.cancel()
# check events
print(events)
assert len(events) > 0
assert events[0]["event"] == "TRANSCRIPT"
assert events[0]["data"]["text"] == "Hello world"
assert events[-2]["event"] == "TOPIC"
assert events[-2]["data"]["id"]
assert events[-2]["data"]["summary"] == "LLM SUMMARY"
assert events[-2]["data"]["transcript"].startswith("Hello world")
assert events[-2]["data"]["timestamp"] == 0.0
assert events[-1]["event"] == "FINAL_SUMMARY"
assert events[-1]["data"]["summary"] == "LLM SUMMARY"
# stop server
# server.stop()