mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
server: implement FINAL_SUMMARY for websocket + update tests and fix flush
This commit is contained in:
@@ -14,6 +14,7 @@ class Processor:
|
|||||||
if callback:
|
if callback:
|
||||||
self.on(callback)
|
self.on(callback)
|
||||||
self.uid = uuid4().hex
|
self.uid = uuid4().hex
|
||||||
|
self.flushed = False
|
||||||
self.logger = (custom_logger or logger).bind(processor=self.__class__.__name__)
|
self.logger = (custom_logger or logger).bind(processor=self.__class__.__name__)
|
||||||
|
|
||||||
def set_pipeline(self, pipeline: "Pipeline"):
|
def set_pipeline(self, pipeline: "Pipeline"):
|
||||||
@@ -65,6 +66,7 @@ class Processor:
|
|||||||
"""
|
"""
|
||||||
# logger.debug(f"{self.__class__.__name__} push")
|
# logger.debug(f"{self.__class__.__name__} push")
|
||||||
try:
|
try:
|
||||||
|
self.flushed = False
|
||||||
return await self._push(data)
|
return await self._push(data)
|
||||||
except Exception:
|
except Exception:
|
||||||
self.logger.exception("Error in push")
|
self.logger.exception("Error in push")
|
||||||
@@ -72,8 +74,12 @@ class Processor:
|
|||||||
async def flush(self):
|
async def flush(self):
|
||||||
"""
|
"""
|
||||||
Flush data to this processor
|
Flush data to this processor
|
||||||
|
Works only one time, until another push is called
|
||||||
"""
|
"""
|
||||||
|
if self.flushed:
|
||||||
|
return
|
||||||
# logger.debug(f"{self.__class__.__name__} flush")
|
# logger.debug(f"{self.__class__.__name__} flush")
|
||||||
|
self.flushed = True
|
||||||
return await self._flush()
|
return await self._flush()
|
||||||
|
|
||||||
def describe(self, level=0):
|
def describe(self, level=0):
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ class StreamClient:
|
|||||||
async def on_connectionstatechange():
|
async def on_connectionstatechange():
|
||||||
self.logger.info(f"Connection state is {pc.connectionState}")
|
self.logger.info(f"Connection state is {pc.connectionState}")
|
||||||
if pc.connectionState == "failed":
|
if pc.connectionState == "failed":
|
||||||
await pc.close()
|
await self.stop()
|
||||||
self.pcs.discard(pc)
|
self.pcs.discard(pc)
|
||||||
|
|
||||||
@pc.on("track")
|
@pc.on("track")
|
||||||
@@ -87,7 +87,7 @@ class StreamClient:
|
|||||||
self.pc.addTrack(audio)
|
self.pc.addTrack(audio)
|
||||||
self.track_audio = audio
|
self.track_audio = audio
|
||||||
|
|
||||||
channel = pc.createDataChannel("data-channel")
|
self.channel = channel = pc.createDataChannel("data-channel")
|
||||||
self.logger = self.logger.bind(channel=channel.label)
|
self.logger = self.logger.bind(channel=channel.label)
|
||||||
self.logger.info("Created by local party")
|
self.logger.info("Created by local party")
|
||||||
|
|
||||||
|
|||||||
@@ -73,12 +73,16 @@ async def rtc_offer_base(
|
|||||||
# build pipeline callback
|
# build pipeline callback
|
||||||
async def on_transcript(transcript: Transcript):
|
async def on_transcript(transcript: Transcript):
|
||||||
ctx.logger.info("Transcript", transcript=transcript)
|
ctx.logger.info("Transcript", transcript=transcript)
|
||||||
result = {
|
|
||||||
"cmd": "SHOW_TRANSCRIPTION",
|
|
||||||
"text": transcript.text,
|
|
||||||
}
|
|
||||||
ctx.data_channel.send(dumps(result))
|
|
||||||
|
|
||||||
|
# send to RTC
|
||||||
|
if ctx.data_channel.readyState == "open":
|
||||||
|
result = {
|
||||||
|
"cmd": "SHOW_TRANSCRIPTION",
|
||||||
|
"text": transcript.text,
|
||||||
|
}
|
||||||
|
ctx.data_channel.send(dumps(result))
|
||||||
|
|
||||||
|
# send to callback (eg. websocket)
|
||||||
if event_callback:
|
if event_callback:
|
||||||
await event_callback(
|
await event_callback(
|
||||||
event=PipelineEvent.TRANSCRIPT,
|
event=PipelineEvent.TRANSCRIPT,
|
||||||
@@ -86,9 +90,7 @@ async def rtc_offer_base(
|
|||||||
data=transcript,
|
data=transcript,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def on_topic(
|
async def on_topic(summary: TitleSummary):
|
||||||
summary: TitleSummary, event_callback=None, event_callback_args=None
|
|
||||||
):
|
|
||||||
# FIXME: make it incremental with the frontend, not send everything
|
# FIXME: make it incremental with the frontend, not send everything
|
||||||
ctx.logger.info("Summary", summary=summary)
|
ctx.logger.info("Summary", summary=summary)
|
||||||
ctx.topics.append(
|
ctx.topics.append(
|
||||||
@@ -99,28 +101,36 @@ async def rtc_offer_base(
|
|||||||
"desc": summary.summary,
|
"desc": summary.summary,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
result = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics}
|
|
||||||
ctx.data_channel.send(dumps(result))
|
|
||||||
|
|
||||||
|
# send to RTC
|
||||||
|
if ctx.data_channel.readyState == "open":
|
||||||
|
result = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics}
|
||||||
|
ctx.data_channel.send(dumps(result))
|
||||||
|
|
||||||
|
# send to callback (eg. websocket)
|
||||||
if event_callback:
|
if event_callback:
|
||||||
await event_callback(
|
await event_callback(
|
||||||
event=PipelineEvent.TOPIC, args=event_callback_args, data=summary
|
event=PipelineEvent.TOPIC, args=event_callback_args, data=summary
|
||||||
)
|
)
|
||||||
|
|
||||||
async def on_final_summary(
|
async def on_final_summary(summary: FinalSummary):
|
||||||
summary: FinalSummary, event_callback=None, event_callback_args=None
|
|
||||||
):
|
|
||||||
ctx.logger.info("FinalSummary", final_summary=summary)
|
ctx.logger.info("FinalSummary", final_summary=summary)
|
||||||
result = {
|
|
||||||
"cmd": "DISPLAY_FINAL_SUMMARY",
|
|
||||||
"summary": summary.summary,
|
|
||||||
"duration": summary.duration,
|
|
||||||
}
|
|
||||||
ctx.data_channel.send(dumps(result))
|
|
||||||
|
|
||||||
|
# send to RTC
|
||||||
|
if ctx.data_channel.readyState == "open":
|
||||||
|
result = {
|
||||||
|
"cmd": "DISPLAY_FINAL_SUMMARY",
|
||||||
|
"summary": summary.summary,
|
||||||
|
"duration": summary.duration,
|
||||||
|
}
|
||||||
|
ctx.data_channel.send(dumps(result))
|
||||||
|
|
||||||
|
# send to callback (eg. websocket)
|
||||||
if event_callback:
|
if event_callback:
|
||||||
await event_callback(
|
await event_callback(
|
||||||
event=PipelineEvent.TOPIC, args=event_callback_args, data=summary
|
event=PipelineEvent.FINAL_SUMMARY,
|
||||||
|
args=event_callback_args,
|
||||||
|
data=summary,
|
||||||
)
|
)
|
||||||
|
|
||||||
# create a context for the whole rtc transaction
|
# create a context for the whole rtc transaction
|
||||||
@@ -137,11 +147,11 @@ async def rtc_offer_base(
|
|||||||
# handle RTC peer connection
|
# handle RTC peer connection
|
||||||
pc = RTCPeerConnection()
|
pc = RTCPeerConnection()
|
||||||
|
|
||||||
async def flush_pipeline_and_quit():
|
async def flush_pipeline_and_quit(close=True):
|
||||||
ctx.logger.info("Flushing pipeline")
|
|
||||||
await ctx.pipeline.flush()
|
await ctx.pipeline.flush()
|
||||||
ctx.logger.debug("Closing peer connection")
|
if close:
|
||||||
await pc.close()
|
ctx.logger.debug("Closing peer connection")
|
||||||
|
await pc.close()
|
||||||
|
|
||||||
@pc.on("datachannel")
|
@pc.on("datachannel")
|
||||||
def on_datachannel(channel):
|
def on_datachannel(channel):
|
||||||
@@ -164,6 +174,8 @@ async def rtc_offer_base(
|
|||||||
ctx.logger.info(f"Connection state: {pc.connectionState}")
|
ctx.logger.info(f"Connection state: {pc.connectionState}")
|
||||||
if pc.connectionState == "failed":
|
if pc.connectionState == "failed":
|
||||||
await pc.close()
|
await pc.close()
|
||||||
|
elif pc.connectionState == "closed":
|
||||||
|
await flush_pipeline_and_quit(close=False)
|
||||||
|
|
||||||
@pc.on("track")
|
@pc.on("track")
|
||||||
def on_track(track):
|
def on_track(track):
|
||||||
|
|||||||
@@ -21,6 +21,10 @@ def generate_transcript_name():
|
|||||||
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
|
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptText(BaseModel):
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
class TranscriptTopic(BaseModel):
|
class TranscriptTopic(BaseModel):
|
||||||
id: UUID = Field(default_factory=uuid4)
|
id: UUID = Field(default_factory=uuid4)
|
||||||
title: str
|
title: str
|
||||||
@@ -29,6 +33,10 @@ class TranscriptTopic(BaseModel):
|
|||||||
timestamp: float
|
timestamp: float
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptFinalSummary(BaseModel):
|
||||||
|
summary: str
|
||||||
|
|
||||||
|
|
||||||
class TranscriptEvent(BaseModel):
|
class TranscriptEvent(BaseModel):
|
||||||
event: str
|
event: str
|
||||||
data: dict
|
data: dict
|
||||||
@@ -45,9 +53,10 @@ class Transcript(BaseModel):
|
|||||||
topics: list[TranscriptTopic] = []
|
topics: list[TranscriptTopic] = []
|
||||||
events: list[TranscriptEvent] = []
|
events: list[TranscriptEvent] = []
|
||||||
|
|
||||||
def add_event(self, event: str, data):
|
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
|
||||||
self.events.append(TranscriptEvent(event=event, data=data))
|
ev = TranscriptEvent(event=event, data=data.model_dump())
|
||||||
return {"event": event, "data": data}
|
self.events.append(ev)
|
||||||
|
return ev
|
||||||
|
|
||||||
def upsert_topic(self, topic: TranscriptTopic):
|
def upsert_topic(self, topic: TranscriptTopic):
|
||||||
existing_topic = next((t for t in self.topics if t.id == topic.id), None)
|
existing_topic = next((t for t in self.topics if t.id == topic.id), None)
|
||||||
@@ -219,7 +228,7 @@ async def transcript_events_websocket(transcript_id: UUID, websocket: WebSocket)
|
|||||||
|
|
||||||
# on first connection, send all events
|
# on first connection, send all events
|
||||||
for event in transcript.events:
|
for event in transcript.events:
|
||||||
await websocket.send_json(event.model_dump())
|
await websocket.send_json(event.model_dump(mode="json"))
|
||||||
|
|
||||||
# XXX if transcript is final (locked=True and status=ended)
|
# XXX if transcript is final (locked=True and status=ended)
|
||||||
# XXX send a final event to the client and close the connection
|
# XXX send a final event to the client and close the connection
|
||||||
@@ -254,24 +263,29 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
|
|||||||
|
|
||||||
# FIXME don't do copy
|
# FIXME don't do copy
|
||||||
if event == PipelineEvent.TRANSCRIPT:
|
if event == PipelineEvent.TRANSCRIPT:
|
||||||
resp = transcript.add_event(event=event, data={
|
resp = transcript.add_event(event=event, data=TranscriptText(text=data.text))
|
||||||
"text": data.text,
|
|
||||||
})
|
|
||||||
elif event == PipelineEvent.TOPIC:
|
elif event == PipelineEvent.TOPIC:
|
||||||
topic = TranscriptTopic(
|
topic = TranscriptTopic(
|
||||||
title=data.title,
|
title=data.title,
|
||||||
summary=data.summary,
|
summary=data.summary,
|
||||||
transcript=data.transcript,
|
transcript=data.transcript.text,
|
||||||
timestamp=data.timestamp,
|
timestamp=data.timestamp,
|
||||||
)
|
)
|
||||||
resp = transcript.add_event(event=event, data=topic.model_dump())
|
resp = transcript.add_event(event=event, data=topic)
|
||||||
transcript.upsert_topic(topic)
|
transcript.upsert_topic(topic)
|
||||||
|
|
||||||
|
elif event == PipelineEvent.FINAL_SUMMARY:
|
||||||
|
final_summary = TranscriptFinalSummary(summary=data.summary)
|
||||||
|
resp = transcript.add_event(event=event, data=final_summary)
|
||||||
|
transcript.summary = final_summary
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Unknown event: {event}")
|
logger.warning(f"Unknown event: {event}")
|
||||||
return
|
return
|
||||||
|
|
||||||
# transmit to websocket clients
|
# transmit to websocket clients
|
||||||
await ws_manager.send_json(transcript_id, resp)
|
await ws_manager.send_json(transcript_id, resp.model_dump(mode="json"))
|
||||||
|
|
||||||
|
|
||||||
@router.post("/transcripts/{transcript_id}/record/webrtc")
|
@router.post("/transcripts/{transcript_id}/record/webrtc")
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
# FIXME try with locked session, RTC should not work
|
# FIXME try with locked session, RTC should not work
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import json
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
|
|
||||||
@@ -61,7 +62,7 @@ async def dummy_llm():
|
|||||||
|
|
||||||
class TestLLM(LLM):
|
class TestLLM(LLM):
|
||||||
async def _generate(self, prompt: str, **kwargs):
|
async def _generate(self, prompt: str, **kwargs):
|
||||||
return {"text": "LLM RESULT"}
|
return json.dumps({"title": "LLM TITLE", "summary": "LLM SUMMARY"})
|
||||||
|
|
||||||
with patch("reflector.llm.base.LLM.get_instance") as mock_llm:
|
with patch("reflector.llm.base.LLM.get_instance") as mock_llm:
|
||||||
mock_llm.return_value = TestLLM()
|
mock_llm.return_value = TestLLM()
|
||||||
@@ -132,6 +133,13 @@ async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm):
|
|||||||
if timeout < 0:
|
if timeout < 0:
|
||||||
raise TimeoutError("Timeout while waiting for RTC to end")
|
raise TimeoutError("Timeout while waiting for RTC to end")
|
||||||
|
|
||||||
|
# XXX aiortc is long to close the connection
|
||||||
|
# instead of waiting a long time, we just send a STOP
|
||||||
|
client.channel.send(json.dumps({"cmd": "STOP"}))
|
||||||
|
|
||||||
|
# wait the processing to finish
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
await client.stop()
|
await client.stop()
|
||||||
|
|
||||||
# wait the processing to finish
|
# wait the processing to finish
|
||||||
@@ -141,10 +149,18 @@ async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm):
|
|||||||
websocket_task.cancel()
|
websocket_task.cancel()
|
||||||
|
|
||||||
# check events
|
# check events
|
||||||
print(events)
|
|
||||||
assert len(events) > 0
|
assert len(events) > 0
|
||||||
assert events[0]["event"] == "TRANSCRIPT"
|
assert events[0]["event"] == "TRANSCRIPT"
|
||||||
assert events[0]["data"]["text"] == "Hello world"
|
assert events[0]["data"]["text"] == "Hello world"
|
||||||
|
|
||||||
|
assert events[-2]["event"] == "TOPIC"
|
||||||
|
assert events[-2]["data"]["id"]
|
||||||
|
assert events[-2]["data"]["summary"] == "LLM SUMMARY"
|
||||||
|
assert events[-2]["data"]["transcript"].startswith("Hello world")
|
||||||
|
assert events[-2]["data"]["timestamp"] == 0.0
|
||||||
|
|
||||||
|
assert events[-1]["event"] == "FINAL_SUMMARY"
|
||||||
|
assert events[-1]["data"]["summary"] == "LLM SUMMARY"
|
||||||
|
|
||||||
# stop server
|
# stop server
|
||||||
# server.stop()
|
# server.stop()
|
||||||
|
|||||||
Reference in New Issue
Block a user