From d320558cc9bfcdcb392b51b6cc53b51337dce3c8 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Tue, 1 Aug 2023 19:12:51 +0200 Subject: [PATCH] server/rtc: fix topic output --- server/reflector/llm/base.py | 12 ++++--- .../processors/transcript_topic_detector.py | 11 +++++-- server/reflector/processors/types.py | 15 +++++---- server/reflector/views/rtc_offer.py | 32 ++++++++++++++++++- 4 files changed, 57 insertions(+), 13 deletions(-) diff --git a/server/reflector/llm/base.py b/server/reflector/llm/base.py index 6b5dbdc9..031e38aa 100644 --- a/server/reflector/llm/base.py +++ b/server/reflector/llm/base.py @@ -55,9 +55,13 @@ class LLM: regex = r"```(json|javascript|)?(.*)```" matches = re.findall(regex, result.strip(), re.MULTILINE | re.DOTALL) - if not matches: - return result + if matches: + result = matches[0][1] + + else: + # maybe the prompt has been started with ```json + # so if text ends with ```, just remove it and use it as json + if result.endswith("```"): + result = result[:-3] - # we have a match, try to parse it - result = matches[0][1] return json.loads(result.strip()) diff --git a/server/reflector/processors/transcript_topic_detector.py b/server/reflector/processors/transcript_topic_detector.py index 31a88882..b602d61e 100644 --- a/server/reflector/processors/transcript_topic_detector.py +++ b/server/reflector/processors/transcript_topic_detector.py @@ -21,9 +21,10 @@ class TranscriptTopicDetectorProcessor(Processor): {input_text} ### Assistant: + """ - def __init__(self, min_transcript_length=25, **kwargs): + def __init__(self, min_transcript_length=100, **kwargs): super().__init__(**kwargs) self.transcript = None self.min_transcript_length = min_transcript_length @@ -43,6 +44,12 @@ class TranscriptTopicDetectorProcessor(Processor): return prompt = self.PROMPT.format(input_text=self.transcript.text) result = await self.llm.generate(prompt=prompt) - summary = TitleSummary(title=result["title"], summary=result["summary"]) + summary = TitleSummary( + title=result["title"], + summary=result["summary"], + timestamp=self.transcript.timestamp, + duration=self.transcript.duration, + transcript=self.transcript, + ) self.transcript = None await self.emit(summary) diff --git a/server/reflector/processors/types.py b/server/reflector/processors/types.py index 1a89c127..c4c840dd 100644 --- a/server/reflector/processors/types.py +++ b/server/reflector/processors/types.py @@ -21,12 +21,6 @@ class Word: end: float -@dataclass -class TitleSummary: - title: str - summary: str - - @dataclass class Transcript: text: str = "" @@ -63,3 +57,12 @@ class Transcript: Word(text=word.text, start=word.start, end=word.end) for word in self.words ] return Transcript(text=self.text, words=words) + + +@dataclass +class TitleSummary: + title: str + summary: str + timestamp: float + duration: float + transcript: Transcript diff --git a/server/reflector/views/rtc_offer.py b/server/reflector/views/rtc_offer.py index f462a37a..77007035 100644 --- a/server/reflector/views/rtc_offer.py +++ b/server/reflector/views/rtc_offer.py @@ -1,6 +1,11 @@ from fastapi import Request, APIRouter from pydantic import BaseModel -from reflector.models import TranscriptionContext, TranscriptionOutput +from reflector.models import ( + TranscriptionContext, + TranscriptionOutput, + TitleSummaryOutput, + IncrementalResult, +) from reflector.logger import logger from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack from json import loads, dumps @@ -67,6 +72,31 @@ async def rtc_offer(params: RtcOffer, request: Request): async def on_summary(summary: TitleSummary): ctx.logger.info("Summary", summary=summary) + # XXX doesnt work as expected, IncrementalResult is not serializable + # and previous implementation assume output of oobagooda + # result = TitleSummaryOutput( + # [ + # IncrementalResult( + # title=summary.title, + # desc=summary.summary, + # transcript=summary.transcript.text, + # timestamp=summary.timestamp, + # ) + # ] + # ) + result = { + "cmd": "UPDATE_TOPICS", + "topics": [ + { + "title": summary.title, + "timestamp": summary.timestamp, + "transcript": summary.transcript.text, + "desc": summary.summary, + } + ], + } + + ctx.data_channel.send(dumps(result)) # create a context for the whole rtc transaction # add a customised logger to the context