diff --git a/server/reflector/processors/base.py b/server/reflector/processors/base.py
index 7d11590d..692a490b 100644
--- a/server/reflector/processors/base.py
+++ b/server/reflector/processors/base.py
@@ -14,6 +14,7 @@ class Processor:
         if callback:
             self.on(callback)
         self.uid = uuid4().hex
+        self.flushed = False
         self.logger = (custom_logger or logger).bind(processor=self.__class__.__name__)

     def set_pipeline(self, pipeline: "Pipeline"):
@@ -65,6 +66,7 @@ class Processor:
         """
         # logger.debug(f"{self.__class__.__name__} push")
         try:
+            self.flushed = False
             return await self._push(data)
         except Exception:
             self.logger.exception("Error in push")
@@ -72,8 +74,12 @@ class Processor:
     async def flush(self):
         """
         Flush data to this processor
+        Works only one time, until another push is called
         """
+        if self.flushed:
+            return
         # logger.debug(f"{self.__class__.__name__} flush")
+        self.flushed = True
         return await self._flush()

     def describe(self, level=0):
diff --git a/server/reflector/stream_client.py b/server/reflector/stream_client.py
index 6b66ad45..b3e4d966 100644
--- a/server/reflector/stream_client.py
+++ b/server/reflector/stream_client.py
@@ -72,7 +72,7 @@ class StreamClient:
         async def on_connectionstatechange():
             self.logger.info(f"Connection state is {pc.connectionState}")
             if pc.connectionState == "failed":
-                await pc.close()
+                await self.stop()
                 self.pcs.discard(pc)

         @pc.on("track")
@@ -87,7 +87,7 @@ class StreamClient:
         self.pc.addTrack(audio)
         self.track_audio = audio

-        channel = pc.createDataChannel("data-channel")
+        self.channel = channel = pc.createDataChannel("data-channel")
         self.logger = self.logger.bind(channel=channel.label)
         self.logger.info("Created by local party")

diff --git a/server/reflector/views/rtc_offer.py b/server/reflector/views/rtc_offer.py
index 288153e6..3a8eb874 100644
--- a/server/reflector/views/rtc_offer.py
+++ b/server/reflector/views/rtc_offer.py
@@ -73,12 +73,16 @@ async def rtc_offer_base(
     # build pipeline callback
     async def on_transcript(transcript: Transcript):
         ctx.logger.info("Transcript", transcript=transcript)

-        result = {
-            "cmd": "SHOW_TRANSCRIPTION",
-            "text": transcript.text,
-        }
-        ctx.data_channel.send(dumps(result))
+        # send to RTC
+        if ctx.data_channel.readyState == "open":
+            result = {
+                "cmd": "SHOW_TRANSCRIPTION",
+                "text": transcript.text,
+            }
+            ctx.data_channel.send(dumps(result))
+
+        # send to callback (eg. websocket)
         if event_callback:
             await event_callback(
                 event=PipelineEvent.TRANSCRIPT,
@@ -86,9 +90,7 @@ async def rtc_offer_base(
                 data=transcript,
             )

-    async def on_topic(
-        summary: TitleSummary, event_callback=None, event_callback_args=None
-    ):
+    async def on_topic(summary: TitleSummary):
         # FIXME: make it incremental with the frontend, not send everything
         ctx.logger.info("Summary", summary=summary)
         ctx.topics.append(
@@ -99,28 +101,36 @@ async def rtc_offer_base(
                 "desc": summary.summary,
             }
         )

-        result = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics}
-        ctx.data_channel.send(dumps(result))
+        # send to RTC
+        if ctx.data_channel.readyState == "open":
+            result = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics}
+            ctx.data_channel.send(dumps(result))
+
+        # send to callback (eg. websocket)
         if event_callback:
             await event_callback(
                 event=PipelineEvent.TOPIC, args=event_callback_args, data=summary
             )

-    async def on_final_summary(
-        summary: FinalSummary, event_callback=None, event_callback_args=None
-    ):
+    async def on_final_summary(summary: FinalSummary):
         ctx.logger.info("FinalSummary", final_summary=summary)

-        result = {
-            "cmd": "DISPLAY_FINAL_SUMMARY",
-            "summary": summary.summary,
-            "duration": summary.duration,
-        }
-        ctx.data_channel.send(dumps(result))
+        # send to RTC
+        if ctx.data_channel.readyState == "open":
+            result = {
+                "cmd": "DISPLAY_FINAL_SUMMARY",
+                "summary": summary.summary,
+                "duration": summary.duration,
+            }
+            ctx.data_channel.send(dumps(result))
+
+        # send to callback (eg. websocket)
         if event_callback:
             await event_callback(
-                event=PipelineEvent.TOPIC, args=event_callback_args, data=summary
+                event=PipelineEvent.FINAL_SUMMARY,
+                args=event_callback_args,
+                data=summary,
             )

     # create a context for the whole rtc transaction
@@ -137,11 +147,11 @@ async def rtc_offer_base(
     # handle RTC peer connection
    pc = RTCPeerConnection()

-    async def flush_pipeline_and_quit():
-        ctx.logger.info("Flushing pipeline")
+    async def flush_pipeline_and_quit(close=True):
         await ctx.pipeline.flush()
-        ctx.logger.debug("Closing peer connection")
-        await pc.close()
+        if close:
+            ctx.logger.debug("Closing peer connection")
+            await pc.close()

     @pc.on("datachannel")
     def on_datachannel(channel):
@@ -164,6 +174,8 @@ async def rtc_offer_base(
         ctx.logger.info(f"Connection state: {pc.connectionState}")
         if pc.connectionState == "failed":
             await pc.close()
+        elif pc.connectionState == "closed":
+            await flush_pipeline_and_quit(close=False)

     @pc.on("track")
     def on_track(track):
diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py
index 120c3ff1..6cd265ce 100644
--- a/server/reflector/views/transcripts.py
+++ b/server/reflector/views/transcripts.py
@@ -21,6 +21,10 @@ def generate_transcript_name():
     return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"


+class TranscriptText(BaseModel):
+    text: str
+
+
 class TranscriptTopic(BaseModel):
     id: UUID = Field(default_factory=uuid4)
     title: str
@@ -29,6 +33,10 @@ class TranscriptTopic(BaseModel):
     timestamp: float


+class TranscriptFinalSummary(BaseModel):
+    summary: str
+
+
 class TranscriptEvent(BaseModel):
     event: str
     data: dict
@@ -45,9 +53,10 @@ class Transcript(BaseModel):
     topics: list[TranscriptTopic] = []
     events: list[TranscriptEvent] = []

-    def add_event(self, event: str, data):
-        self.events.append(TranscriptEvent(event=event, data=data))
-        return {"event": event, "data": data}
+    def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
+        ev = TranscriptEvent(event=event, data=data.model_dump())
+        self.events.append(ev)
+        return ev

     def upsert_topic(self, topic: TranscriptTopic):
         existing_topic = next((t for t in self.topics if t.id == topic.id), None)
@@ -219,7 +228,7 @@ async def transcript_events_websocket(transcript_id: UUID, websocket: WebSocket)

     # on first connection, send all events
     for event in transcript.events:
-        await websocket.send_json(event.model_dump())
+        await websocket.send_json(event.model_dump(mode="json"))

     # XXX if transcript is final (locked=True and status=ended)
     # XXX send a final event to the client and close the connection
@@ -254,24 +263,29 @@ async def handle_rtc_event(event: PipelineEvent, args, data):

     # FIXME don't do copy
     if event == PipelineEvent.TRANSCRIPT:
-        resp = transcript.add_event(event=event, data={
-            "text": data.text,
-        })
+        resp = transcript.add_event(event=event, data=TranscriptText(text=data.text))
+
     elif event == PipelineEvent.TOPIC:
         topic = TranscriptTopic(
             title=data.title,
             summary=data.summary,
-            transcript=data.transcript,
+            transcript=data.transcript.text,
             timestamp=data.timestamp,
         )
-        resp = transcript.add_event(event=event, data=topic.model_dump())
+        resp = transcript.add_event(event=event, data=topic)
         transcript.upsert_topic(topic)
+
+    elif event == PipelineEvent.FINAL_SUMMARY:
+        final_summary = TranscriptFinalSummary(summary=data.summary)
+        resp = transcript.add_event(event=event, data=final_summary)
+        transcript.summary = final_summary
+
     else:
         logger.warning(f"Unknown event: {event}")
         return

     # transmit to websocket clients
-    await ws_manager.send_json(transcript_id, resp)
+    await ws_manager.send_json(transcript_id, resp.model_dump(mode="json"))


 @router.post("/transcripts/{transcript_id}/record/webrtc")
diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py
index e6280c94..9724c7fe 100644
--- a/server/tests/test_transcripts_rtc_ws.py
+++ b/server/tests/test_transcripts_rtc_ws.py
@@ -4,6 +4,7 @@
 # FIXME try with locked session, RTC should not work

 import pytest
+import json
 from unittest.mock import patch
 from httpx import AsyncClient

@@ -61,7 +62,7 @@ async def dummy_llm():

     class TestLLM(LLM):
         async def _generate(self, prompt: str, **kwargs):
-            return {"text": "LLM RESULT"}
+            return json.dumps({"title": "LLM TITLE", "summary": "LLM SUMMARY"})

     with patch("reflector.llm.base.LLM.get_instance") as mock_llm:
         mock_llm.return_value = TestLLM()
@@ -132,6 +133,13 @@ async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm):
         if timeout < 0:
             raise TimeoutError("Timeout while waiting for RTC to end")

+    # XXX aiortc is long to close the connection
+    # instead of waiting a long time, we just send a STOP
+    client.channel.send(json.dumps({"cmd": "STOP"}))
+
+    # wait the processing to finish
+    await asyncio.sleep(2)
+
     await client.stop()

     # wait the processing to finish
@@ -141,10 +149,18 @@ async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm):
     websocket_task.cancel()

     # check events
-    print(events)
     assert len(events) > 0
     assert events[0]["event"] == "TRANSCRIPT"
     assert events[0]["data"]["text"] == "Hello world"
+    assert events[-2]["event"] == "TOPIC"
+    assert events[-2]["data"]["id"]
+    assert events[-2]["data"]["summary"] == "LLM SUMMARY"
+    assert events[-2]["data"]["transcript"].startswith("Hello world")
+    assert events[-2]["data"]["timestamp"] == 0.0
+
+    assert events[-1]["event"] == "FINAL_SUMMARY"
+    assert events[-1]["data"]["summary"] == "LLM SUMMARY"
+
     # stop server
     # server.stop()
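
Note on the flush guard in processors/base.py: the intent is that flush() invokes the subclass hook at most once until another push() re-arms it. A minimal self-contained sketch of that contract (the class and names below are illustrative; only the flushed/push/flush interplay mirrors the patch):

    import asyncio

    class FlushOnceDemo:
        """Stand-in for reflector's Processor, reduced to the flush-once logic."""

        def __init__(self):
            self.flushed = False
            self.flush_count = 0  # counts actual flushes, for demonstration

        async def push(self, data):
            # any push re-arms the next flush
            self.flushed = False

        async def flush(self):
            # idempotent until the next push
            if self.flushed:
                return
            self.flushed = True
            self.flush_count += 1

    async def main():
        p = FlushOnceDemo()
        await p.push("audio chunk")
        await p.flush()
        await p.flush()  # no-op: already flushed
        assert p.flush_count == 1
        await p.push("more audio")
        await p.flush()  # runs again: push re-armed it
        assert p.flush_count == 2

    asyncio.run(main())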
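
Note on the readyState checks added in rtc_offer.py: aiortc's RTCDataChannel refuses send() once the channel is no longer open (e.g. the peer already disconnected), so each callback now guards the send instead of assuming an open channel. A hedged sketch of how that guard could be factored into a helper (send_if_open is illustrative, not part of the patch):

    from json import dumps

    def send_if_open(channel, payload: dict) -> bool:
        """Send a JSON payload over a data channel only while it is open.

        Returns True if sent, False if the payload was dropped.
        """
        if channel.readyState != "open":
            return False
        channel.send(dumps(payload))
        return True

Each callback body would then reduce to e.g. send_if_open(ctx.data_channel, {"cmd": "UPDATE_TOPICS", "topics": ctx.topics}).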
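
Note on model_dump(mode="json") in transcripts.py: TranscriptEvent payloads carry values such as UUIDs that plain JSON encoding (and thus websocket.send_json) cannot handle; pydantic v2's mode="json" converts them to JSON-safe primitives first. A minimal illustration, reusing a reduced TranscriptTopic:

    from uuid import UUID, uuid4
    from pydantic import BaseModel, Field

    class TranscriptTopic(BaseModel):
        id: UUID = Field(default_factory=uuid4)
        title: str
        timestamp: float

    topic = TranscriptTopic(title="Intro", timestamp=0.0)
    print(topic.model_dump())             # {'id': UUID('...'), ...} -- not JSON-safe
    print(topic.model_dump(mode="json"))  # {'id': '<hex string>', ...} -- serializable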