server/rtc: fix topic output

This commit is contained in:
Mathieu Virbel
2023-08-01 19:12:51 +02:00
parent 99c9ba3e6b
commit d320558cc9
4 changed files with 57 additions and 13 deletions

View File

@@ -55,9 +55,13 @@ class LLM:
regex = r"```(json|javascript|)?(.*)```" regex = r"```(json|javascript|)?(.*)```"
matches = re.findall(regex, result.strip(), re.MULTILINE | re.DOTALL) matches = re.findall(regex, result.strip(), re.MULTILINE | re.DOTALL)
if not matches: if matches:
return result
# we have a match, try to parse it
result = matches[0][1] result = matches[0][1]
else:
# maybe the prompt has been started with ```json
# so if text ends with ```, just remove it and use it as json
if result.endswith("```"):
result = result[:-3]
return json.loads(result.strip()) return json.loads(result.strip())

View File

@@ -21,9 +21,10 @@ class TranscriptTopicDetectorProcessor(Processor):
{input_text} {input_text}
### Assistant: ### Assistant:
""" """
def __init__(self, min_transcript_length=25, **kwargs): def __init__(self, min_transcript_length=100, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.transcript = None self.transcript = None
self.min_transcript_length = min_transcript_length self.min_transcript_length = min_transcript_length
@@ -43,6 +44,12 @@ class TranscriptTopicDetectorProcessor(Processor):
return return
prompt = self.PROMPT.format(input_text=self.transcript.text) prompt = self.PROMPT.format(input_text=self.transcript.text)
result = await self.llm.generate(prompt=prompt) result = await self.llm.generate(prompt=prompt)
summary = TitleSummary(title=result["title"], summary=result["summary"]) summary = TitleSummary(
title=result["title"],
summary=result["summary"],
timestamp=self.transcript.timestamp,
duration=self.transcript.duration,
transcript=self.transcript,
)
self.transcript = None self.transcript = None
await self.emit(summary) await self.emit(summary)

View File

@@ -21,12 +21,6 @@ class Word:
end: float end: float
@dataclass
class TitleSummary:
title: str
summary: str
@dataclass @dataclass
class Transcript: class Transcript:
text: str = "" text: str = ""
@@ -63,3 +57,12 @@ class Transcript:
Word(text=word.text, start=word.start, end=word.end) for word in self.words Word(text=word.text, start=word.start, end=word.end) for word in self.words
] ]
return Transcript(text=self.text, words=words) return Transcript(text=self.text, words=words)
@dataclass
class TitleSummary:
title: str
summary: str
timestamp: float
duration: float
transcript: Transcript

View File

@@ -1,6 +1,11 @@
from fastapi import Request, APIRouter from fastapi import Request, APIRouter
from pydantic import BaseModel from pydantic import BaseModel
from reflector.models import TranscriptionContext, TranscriptionOutput from reflector.models import (
TranscriptionContext,
TranscriptionOutput,
TitleSummaryOutput,
IncrementalResult,
)
from reflector.logger import logger from reflector.logger import logger
from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack
from json import loads, dumps from json import loads, dumps
@@ -67,6 +72,31 @@ async def rtc_offer(params: RtcOffer, request: Request):
async def on_summary(summary: TitleSummary): async def on_summary(summary: TitleSummary):
ctx.logger.info("Summary", summary=summary) ctx.logger.info("Summary", summary=summary)
# XXX doesn't work as expected: IncrementalResult is not serializable,
# and the previous implementation assumes the output of oobabooga
# result = TitleSummaryOutput(
# [
# IncrementalResult(
# title=summary.title,
# desc=summary.summary,
# transcript=summary.transcript.text,
# timestamp=summary.timestamp,
# )
# ]
# )
result = {
"cmd": "UPDATE_TOPICS",
"topics": [
{
"title": summary.title,
"timestamp": summary.timestamp,
"transcript": summary.transcript.text,
"desc": summary.summary,
}
],
}
ctx.data_channel.send(dumps(result))
# create a context for the whole rtc transaction # create a context for the whole rtc transaction
# add a customised logger to the context # add a customised logger to the context