server: implement FINAL_SUMMARY for websocket + update tests and fix flush

This commit is contained in:
Mathieu Virbel
2023-08-08 19:32:20 +02:00
parent 93564bfd89
commit 7f807c8f5f
5 changed files with 86 additions and 38 deletions

View File

@@ -14,6 +14,7 @@ class Processor:
if callback: if callback:
self.on(callback) self.on(callback)
self.uid = uuid4().hex self.uid = uuid4().hex
self.flushed = False
self.logger = (custom_logger or logger).bind(processor=self.__class__.__name__) self.logger = (custom_logger or logger).bind(processor=self.__class__.__name__)
def set_pipeline(self, pipeline: "Pipeline"): def set_pipeline(self, pipeline: "Pipeline"):
@@ -65,6 +66,7 @@ class Processor:
""" """
# logger.debug(f"{self.__class__.__name__} push") # logger.debug(f"{self.__class__.__name__} push")
try: try:
self.flushed = False
return await self._push(data) return await self._push(data)
except Exception: except Exception:
self.logger.exception("Error in push") self.logger.exception("Error in push")
@@ -72,8 +74,12 @@ class Processor:
async def flush(self): async def flush(self):
""" """
Flush data to this processor Flush data to this processor
Works only once: further calls are ignored until push is called again
""" """
if self.flushed:
return
# logger.debug(f"{self.__class__.__name__} flush") # logger.debug(f"{self.__class__.__name__} flush")
self.flushed = True
return await self._flush() return await self._flush()
def describe(self, level=0): def describe(self, level=0):

View File

@@ -72,7 +72,7 @@ class StreamClient:
async def on_connectionstatechange(): async def on_connectionstatechange():
self.logger.info(f"Connection state is {pc.connectionState}") self.logger.info(f"Connection state is {pc.connectionState}")
if pc.connectionState == "failed": if pc.connectionState == "failed":
await pc.close() await self.stop()
self.pcs.discard(pc) self.pcs.discard(pc)
@pc.on("track") @pc.on("track")
@@ -87,7 +87,7 @@ class StreamClient:
self.pc.addTrack(audio) self.pc.addTrack(audio)
self.track_audio = audio self.track_audio = audio
channel = pc.createDataChannel("data-channel") self.channel = channel = pc.createDataChannel("data-channel")
self.logger = self.logger.bind(channel=channel.label) self.logger = self.logger.bind(channel=channel.label)
self.logger.info("Created by local party") self.logger.info("Created by local party")

View File

@@ -73,12 +73,16 @@ async def rtc_offer_base(
# build pipeline callback # build pipeline callback
async def on_transcript(transcript: Transcript): async def on_transcript(transcript: Transcript):
ctx.logger.info("Transcript", transcript=transcript) ctx.logger.info("Transcript", transcript=transcript)
result = {
"cmd": "SHOW_TRANSCRIPTION",
"text": transcript.text,
}
ctx.data_channel.send(dumps(result))
# send to RTC
if ctx.data_channel.readyState == "open":
result = {
"cmd": "SHOW_TRANSCRIPTION",
"text": transcript.text,
}
ctx.data_channel.send(dumps(result))
# send to callback (eg. websocket)
if event_callback: if event_callback:
await event_callback( await event_callback(
event=PipelineEvent.TRANSCRIPT, event=PipelineEvent.TRANSCRIPT,
@@ -86,9 +90,7 @@ async def rtc_offer_base(
data=transcript, data=transcript,
) )
async def on_topic( async def on_topic(summary: TitleSummary):
summary: TitleSummary, event_callback=None, event_callback_args=None
):
# FIXME: make it incremental with the frontend, not send everything # FIXME: make it incremental with the frontend, not send everything
ctx.logger.info("Summary", summary=summary) ctx.logger.info("Summary", summary=summary)
ctx.topics.append( ctx.topics.append(
@@ -99,28 +101,36 @@ async def rtc_offer_base(
"desc": summary.summary, "desc": summary.summary,
} }
) )
result = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics}
ctx.data_channel.send(dumps(result))
# send to RTC
if ctx.data_channel.readyState == "open":
result = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics}
ctx.data_channel.send(dumps(result))
# send to callback (eg. websocket)
if event_callback: if event_callback:
await event_callback( await event_callback(
event=PipelineEvent.TOPIC, args=event_callback_args, data=summary event=PipelineEvent.TOPIC, args=event_callback_args, data=summary
) )
async def on_final_summary( async def on_final_summary(summary: FinalSummary):
summary: FinalSummary, event_callback=None, event_callback_args=None
):
ctx.logger.info("FinalSummary", final_summary=summary) ctx.logger.info("FinalSummary", final_summary=summary)
result = {
"cmd": "DISPLAY_FINAL_SUMMARY",
"summary": summary.summary,
"duration": summary.duration,
}
ctx.data_channel.send(dumps(result))
# send to RTC
if ctx.data_channel.readyState == "open":
result = {
"cmd": "DISPLAY_FINAL_SUMMARY",
"summary": summary.summary,
"duration": summary.duration,
}
ctx.data_channel.send(dumps(result))
# send to callback (eg. websocket)
if event_callback: if event_callback:
await event_callback( await event_callback(
event=PipelineEvent.TOPIC, args=event_callback_args, data=summary event=PipelineEvent.FINAL_SUMMARY,
args=event_callback_args,
data=summary,
) )
# create a context for the whole rtc transaction # create a context for the whole rtc transaction
@@ -137,11 +147,11 @@ async def rtc_offer_base(
# handle RTC peer connection # handle RTC peer connection
pc = RTCPeerConnection() pc = RTCPeerConnection()
async def flush_pipeline_and_quit(): async def flush_pipeline_and_quit(close=True):
ctx.logger.info("Flushing pipeline")
await ctx.pipeline.flush() await ctx.pipeline.flush()
ctx.logger.debug("Closing peer connection") if close:
await pc.close() ctx.logger.debug("Closing peer connection")
await pc.close()
@pc.on("datachannel") @pc.on("datachannel")
def on_datachannel(channel): def on_datachannel(channel):
@@ -164,6 +174,8 @@ async def rtc_offer_base(
ctx.logger.info(f"Connection state: {pc.connectionState}") ctx.logger.info(f"Connection state: {pc.connectionState}")
if pc.connectionState == "failed": if pc.connectionState == "failed":
await pc.close() await pc.close()
elif pc.connectionState == "closed":
await flush_pipeline_and_quit(close=False)
@pc.on("track") @pc.on("track")
def on_track(track): def on_track(track):

View File

@@ -21,6 +21,10 @@ def generate_transcript_name():
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}" return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
class TranscriptText(BaseModel):
text: str
class TranscriptTopic(BaseModel): class TranscriptTopic(BaseModel):
id: UUID = Field(default_factory=uuid4) id: UUID = Field(default_factory=uuid4)
title: str title: str
@@ -29,6 +33,10 @@ class TranscriptTopic(BaseModel):
timestamp: float timestamp: float
class TranscriptFinalSummary(BaseModel):
summary: str
class TranscriptEvent(BaseModel): class TranscriptEvent(BaseModel):
event: str event: str
data: dict data: dict
@@ -45,9 +53,10 @@ class Transcript(BaseModel):
topics: list[TranscriptTopic] = [] topics: list[TranscriptTopic] = []
events: list[TranscriptEvent] = [] events: list[TranscriptEvent] = []
def add_event(self, event: str, data): def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
self.events.append(TranscriptEvent(event=event, data=data)) ev = TranscriptEvent(event=event, data=data.model_dump())
return {"event": event, "data": data} self.events.append(ev)
return ev
def upsert_topic(self, topic: TranscriptTopic): def upsert_topic(self, topic: TranscriptTopic):
existing_topic = next((t for t in self.topics if t.id == topic.id), None) existing_topic = next((t for t in self.topics if t.id == topic.id), None)
@@ -219,7 +228,7 @@ async def transcript_events_websocket(transcript_id: UUID, websocket: WebSocket)
# on first connection, send all events # on first connection, send all events
for event in transcript.events: for event in transcript.events:
await websocket.send_json(event.model_dump()) await websocket.send_json(event.model_dump(mode="json"))
# XXX if transcript is final (locked=True and status=ended) # XXX if transcript is final (locked=True and status=ended)
# XXX send a final event to the client and close the connection # XXX send a final event to the client and close the connection
@@ -254,24 +263,29 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
# FIXME don't do copy # FIXME don't do copy
if event == PipelineEvent.TRANSCRIPT: if event == PipelineEvent.TRANSCRIPT:
resp = transcript.add_event(event=event, data={ resp = transcript.add_event(event=event, data=TranscriptText(text=data.text))
"text": data.text,
})
elif event == PipelineEvent.TOPIC: elif event == PipelineEvent.TOPIC:
topic = TranscriptTopic( topic = TranscriptTopic(
title=data.title, title=data.title,
summary=data.summary, summary=data.summary,
transcript=data.transcript, transcript=data.transcript.text,
timestamp=data.timestamp, timestamp=data.timestamp,
) )
resp = transcript.add_event(event=event, data=topic.model_dump()) resp = transcript.add_event(event=event, data=topic)
transcript.upsert_topic(topic) transcript.upsert_topic(topic)
elif event == PipelineEvent.FINAL_SUMMARY:
final_summary = TranscriptFinalSummary(summary=data.summary)
resp = transcript.add_event(event=event, data=final_summary)
transcript.summary = final_summary
else: else:
logger.warning(f"Unknown event: {event}") logger.warning(f"Unknown event: {event}")
return return
# transmit to websocket clients # transmit to websocket clients
await ws_manager.send_json(transcript_id, resp) await ws_manager.send_json(transcript_id, resp.model_dump(mode="json"))
@router.post("/transcripts/{transcript_id}/record/webrtc") @router.post("/transcripts/{transcript_id}/record/webrtc")

View File

@@ -4,6 +4,7 @@
# FIXME try with locked session, RTC should not work # FIXME try with locked session, RTC should not work
import pytest import pytest
import json
from unittest.mock import patch from unittest.mock import patch
from httpx import AsyncClient from httpx import AsyncClient
@@ -61,7 +62,7 @@ async def dummy_llm():
class TestLLM(LLM): class TestLLM(LLM):
async def _generate(self, prompt: str, **kwargs): async def _generate(self, prompt: str, **kwargs):
return {"text": "LLM RESULT"} return json.dumps({"title": "LLM TITLE", "summary": "LLM SUMMARY"})
with patch("reflector.llm.base.LLM.get_instance") as mock_llm: with patch("reflector.llm.base.LLM.get_instance") as mock_llm:
mock_llm.return_value = TestLLM() mock_llm.return_value = TestLLM()
@@ -132,6 +133,13 @@ async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm):
if timeout < 0: if timeout < 0:
raise TimeoutError("Timeout while waiting for RTC to end") raise TimeoutError("Timeout while waiting for RTC to end")
# XXX aiortc is slow to close the connection
# instead of waiting a long time, we just send a STOP
client.channel.send(json.dumps({"cmd": "STOP"}))
# wait for the processing to finish
await asyncio.sleep(2)
await client.stop() await client.stop()
# wait the processing to finish # wait the processing to finish
@@ -141,10 +149,18 @@ async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm):
websocket_task.cancel() websocket_task.cancel()
# check events # check events
print(events)
assert len(events) > 0 assert len(events) > 0
assert events[0]["event"] == "TRANSCRIPT" assert events[0]["event"] == "TRANSCRIPT"
assert events[0]["data"]["text"] == "Hello world" assert events[0]["data"]["text"] == "Hello world"
assert events[-2]["event"] == "TOPIC"
assert events[-2]["data"]["id"]
assert events[-2]["data"]["summary"] == "LLM SUMMARY"
assert events[-2]["data"]["transcript"].startswith("Hello world")
assert events[-2]["data"]["timestamp"] == 0.0
assert events[-1]["event"] == "FINAL_SUMMARY"
assert events[-1]["data"]["summary"] == "LLM SUMMARY"
# stop server # stop server
# server.stop() # server.stop()