Mirror of https://github.com/Monadical-SAS/reflector.git, synced 2025-12-20 20:29:06 +00:00

server: implement FINAL_SUMMARY for websocket + update tests and fix flush
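
In short: pipeline events are now persisted as typed TranscriptEvent records and relayed to websocket clients as JSON, with the new FINAL_SUMMARY event joining TRANSCRIPT and TOPIC. A rough sketch of the payload shapes, reconstructed from the models and test assertions in this diff (values mirror the test fixtures and are illustrative only):

# Sketch: shapes of the JSON events a websocket client receives after this change,
# based on TranscriptText / TranscriptTopic / TranscriptFinalSummary below (values made up).
events = [
    {"event": "TRANSCRIPT", "data": {"text": "Hello world"}},
    {
        "event": "TOPIC",
        "data": {
            "id": "some-uuid",  # TranscriptTopic.id, serialized via model_dump(mode="json")
            "title": "LLM TITLE",
            "summary": "LLM SUMMARY",
            "transcript": "Hello world ...",
            "timestamp": 0.0,
        },
    },
    {"event": "FINAL_SUMMARY", "data": {"summary": "LLM SUMMARY"}},
]
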
@@ -14,6 +14,7 @@ class Processor:
         if callback:
             self.on(callback)
         self.uid = uuid4().hex
+        self.flushed = False
         self.logger = (custom_logger or logger).bind(processor=self.__class__.__name__)
 
     def set_pipeline(self, pipeline: "Pipeline"):
@@ -65,6 +66,7 @@ class Processor:
         """
         # logger.debug(f"{self.__class__.__name__} push")
         try:
+            self.flushed = False
            return await self._push(data)
         except Exception:
             self.logger.exception("Error in push")
@@ -72,8 +74,12 @@ class Processor:
     async def flush(self):
         """
         Flush data to this processor
+        Works only one time, until another push is called
         """
+        if self.flushed:
+            return
         # logger.debug(f"{self.__class__.__name__} flush")
+        self.flushed = True
         return await self._flush()
 
     def describe(self, level=0):
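
Condensed, the flush fix above is a one-shot guard: pushing data re-arms flushing, and flush() runs at most once until the next push(). A minimal sketch of the pattern (not the real Processor class; _push/_flush stand in for the per-processor hooks):

class FlushOnce:
    def __init__(self):
        self.flushed = False

    async def push(self, data):
        self.flushed = False  # new data re-arms flush
        return await self._push(data)

    async def flush(self):
        if self.flushed:  # already flushed since the last push
            return
        self.flushed = True
        return await self._flush()

    # stand-ins for the real per-processor hooks
    async def _push(self, data):
        return data

    async def _flush(self):
        return None
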
@@ -72,7 +72,7 @@ class StreamClient:
         async def on_connectionstatechange():
             self.logger.info(f"Connection state is {pc.connectionState}")
             if pc.connectionState == "failed":
-                await pc.close()
+                await self.stop()
                 self.pcs.discard(pc)
 
         @pc.on("track")
@@ -87,7 +87,7 @@ class StreamClient:
         self.pc.addTrack(audio)
         self.track_audio = audio
 
-        channel = pc.createDataChannel("data-channel")
+        self.channel = channel = pc.createDataChannel("data-channel")
         self.logger = self.logger.bind(channel=channel.label)
         self.logger.info("Created by local party")
 
@@ -73,12 +73,16 @@ async def rtc_offer_base(
     # build pipeline callback
     async def on_transcript(transcript: Transcript):
         ctx.logger.info("Transcript", transcript=transcript)
+
+        # send to RTC
         if ctx.data_channel.readyState == "open":
             result = {
                 "cmd": "SHOW_TRANSCRIPTION",
                 "text": transcript.text,
             }
             ctx.data_channel.send(dumps(result))
+
+        # send to callback (eg. websocket)
         if event_callback:
             await event_callback(
                 event=PipelineEvent.TRANSCRIPT,
@@ -86,9 +90,7 @@ async def rtc_offer_base(
                 data=transcript,
             )
 
-    async def on_topic(
-        summary: TitleSummary, event_callback=None, event_callback_args=None
-    ):
+    async def on_topic(summary: TitleSummary):
         # FIXME: make it incremental with the frontend, not send everything
         ctx.logger.info("Summary", summary=summary)
         ctx.topics.append(
@@ -99,18 +101,23 @@ async def rtc_offer_base(
                 "desc": summary.summary,
             }
         )
 
+        # send to RTC
         if ctx.data_channel.readyState == "open":
             result = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics}
             ctx.data_channel.send(dumps(result))
+
+        # send to callback (eg. websocket)
+        if event_callback:
+            await event_callback(
+                event=PipelineEvent.TOPIC, args=event_callback_args, data=summary
+            )
 
-    async def on_final_summary(
-        summary: FinalSummary, event_callback=None, event_callback_args=None
-    ):
+    async def on_final_summary(summary: FinalSummary):
         ctx.logger.info("FinalSummary", final_summary=summary)
 
         # send to RTC
         if ctx.data_channel.readyState == "open":
             result = {
                 "cmd": "DISPLAY_FINAL_SUMMARY",
                 "summary": summary.summary,
@@ -118,9 +125,12 @@ async def rtc_offer_base(
             }
             ctx.data_channel.send(dumps(result))
+
         # send to callback (eg. websocket)
         if event_callback:
             await event_callback(
-                event=PipelineEvent.TOPIC, args=event_callback_args, data=summary
+                event=PipelineEvent.FINAL_SUMMARY,
+                args=event_callback_args,
+                data=summary,
             )
 
     # create a context for the whole rtc transaction
@@ -137,9 +147,9 @@ async def rtc_offer_base(
     # handle RTC peer connection
     pc = RTCPeerConnection()
 
-    async def flush_pipeline_and_quit():
-        ctx.logger.info("Flushing pipeline")
+    async def flush_pipeline_and_quit(close=True):
         await ctx.pipeline.flush()
+        if close:
             ctx.logger.debug("Closing peer connection")
             await pc.close()
 
@@ -164,6 +174,8 @@ async def rtc_offer_base(
         ctx.logger.info(f"Connection state: {pc.connectionState}")
         if pc.connectionState == "failed":
             await pc.close()
+        elif pc.connectionState == "closed":
+            await flush_pipeline_and_quit(close=False)
 
     @pc.on("track")
     def on_track(track):
@@ -21,6 +21,10 @@ def generate_transcript_name():
     return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
 
 
+class TranscriptText(BaseModel):
+    text: str
+
+
 class TranscriptTopic(BaseModel):
     id: UUID = Field(default_factory=uuid4)
     title: str
@@ -29,6 +33,10 @@ class TranscriptTopic(BaseModel):
     timestamp: float
 
 
+class TranscriptFinalSummary(BaseModel):
+    summary: str
+
+
 class TranscriptEvent(BaseModel):
     event: str
     data: dict
@@ -45,9 +53,10 @@ class Transcript(BaseModel):
     topics: list[TranscriptTopic] = []
     events: list[TranscriptEvent] = []
 
-    def add_event(self, event: str, data):
-        self.events.append(TranscriptEvent(event=event, data=data))
-        return {"event": event, "data": data}
+    def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
+        ev = TranscriptEvent(event=event, data=data.model_dump())
+        self.events.append(ev)
+        return ev
 
     def upsert_topic(self, topic: TranscriptTopic):
         existing_topic = next((t for t in self.topics if t.id == topic.id), None)
@@ -219,7 +228,7 @@ async def transcript_events_websocket(transcript_id: UUID, websocket: WebSocket)
 
     # on first connection, send all events
     for event in transcript.events:
-        await websocket.send_json(event.model_dump())
+        await websocket.send_json(event.model_dump(mode="json"))
 
     # XXX if transcript is final (locked=True and status=ended)
     # XXX send a final event to the client and close the connection
@@ -254,24 +263,29 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
 
     # FIXME don't do copy
     if event == PipelineEvent.TRANSCRIPT:
-        resp = transcript.add_event(event=event, data={
-            "text": data.text,
-        })
+        resp = transcript.add_event(event=event, data=TranscriptText(text=data.text))
+
     elif event == PipelineEvent.TOPIC:
         topic = TranscriptTopic(
             title=data.title,
             summary=data.summary,
-            transcript=data.transcript,
+            transcript=data.transcript.text,
             timestamp=data.timestamp,
         )
-        resp = transcript.add_event(event=event, data=topic.model_dump())
+        resp = transcript.add_event(event=event, data=topic)
         transcript.upsert_topic(topic)
 
+    elif event == PipelineEvent.FINAL_SUMMARY:
+        final_summary = TranscriptFinalSummary(summary=data.summary)
+        resp = transcript.add_event(event=event, data=final_summary)
+        transcript.summary = final_summary
+
     else:
         logger.warning(f"Unknown event: {event}")
         return
 
     # transmit to websocket clients
-    await ws_manager.send_json(transcript_id, resp)
+    await ws_manager.send_json(transcript_id, resp.model_dump(mode="json"))
 
+
 @router.post("/transcripts/{transcript_id}/record/webrtc")
@@ -4,6 +4,7 @@
 # FIXME try with locked session, RTC should not work
 
 import pytest
+import json
 from unittest.mock import patch
 from httpx import AsyncClient
 
@@ -61,7 +62,7 @@ async def dummy_llm():
 
     class TestLLM(LLM):
         async def _generate(self, prompt: str, **kwargs):
-            return {"text": "LLM RESULT"}
+            return json.dumps({"title": "LLM TITLE", "summary": "LLM SUMMARY"})
 
     with patch("reflector.llm.base.LLM.get_instance") as mock_llm:
         mock_llm.return_value = TestLLM()
@@ -132,6 +133,13 @@ async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm):
         if timeout < 0:
             raise TimeoutError("Timeout while waiting for RTC to end")
+
+    # XXX aiortc is long to close the connection
+    # instead of waiting a long time, we just send a STOP
+    client.channel.send(json.dumps({"cmd": "STOP"}))
+
+    # wait the processing to finish
+    await asyncio.sleep(2)
 
     await client.stop()
 
     # wait the processing to finish
@@ -141,10 +149,18 @@ async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm):
     websocket_task.cancel()
 
     # check events
     print(events)
     assert len(events) > 0
     assert events[0]["event"] == "TRANSCRIPT"
     assert events[0]["data"]["text"] == "Hello world"
 
+    assert events[-2]["event"] == "TOPIC"
+    assert events[-2]["data"]["id"]
+    assert events[-2]["data"]["summary"] == "LLM SUMMARY"
+    assert events[-2]["data"]["transcript"].startswith("Hello world")
+    assert events[-2]["data"]["timestamp"] == 0.0
+
+    assert events[-1]["event"] == "FINAL_SUMMARY"
+    assert events[-1]["data"]["summary"] == "LLM SUMMARY"
 
     # stop server
     # server.stop()
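
Taken together with the test above, a websocket client sees TRANSCRIPT events while audio streams, then TOPIC, then a single FINAL_SUMMARY. A minimal consumer sketch using the third-party websockets package; the URL is a placeholder, since the actual events route is not part of this diff:

import asyncio
import json

import websockets  # third-party: pip install websockets


async def follow_transcript(url: str) -> None:
    async with websockets.connect(url) as ws:
        async for raw in ws:
            event = json.loads(raw)
            if event["event"] == "TRANSCRIPT":
                print("partial:", event["data"]["text"])
            elif event["event"] == "TOPIC":
                print("topic:", event["data"]["title"])
            elif event["event"] == "FINAL_SUMMARY":
                print("final summary:", event["data"]["summary"])
                break  # FINAL_SUMMARY is the last event of a session


# asyncio.run(follow_transcript("ws://localhost:8000/<transcript-events-route>"))  # placeholder URL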