mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
server/rtc: fix topic output
This commit is contained in:
@@ -55,9 +55,13 @@ class LLM:
|
||||
|
||||
regex = r"```(json|javascript|)?(.*)```"
|
||||
matches = re.findall(regex, result.strip(), re.MULTILINE | re.DOTALL)
|
||||
if not matches:
|
||||
return result
|
||||
if matches:
|
||||
result = matches[0][1]
|
||||
|
||||
else:
|
||||
# maybe the prompt has been started with ```json
|
||||
# so if text ends with ```, just remove it and use it as json
|
||||
if result.endswith("```"):
|
||||
result = result[:-3]
|
||||
|
||||
# we have a match, try to parse it
|
||||
result = matches[0][1]
|
||||
return json.loads(result.strip())
|
||||
|
||||
@@ -21,9 +21,10 @@ class TranscriptTopicDetectorProcessor(Processor):
|
||||
{input_text}
|
||||
|
||||
### Assistant:
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, min_transcript_length=25, **kwargs):
|
||||
def __init__(self, min_transcript_length=100, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.transcript = None
|
||||
self.min_transcript_length = min_transcript_length
|
||||
@@ -43,6 +44,12 @@ class TranscriptTopicDetectorProcessor(Processor):
|
||||
return
|
||||
prompt = self.PROMPT.format(input_text=self.transcript.text)
|
||||
result = await self.llm.generate(prompt=prompt)
|
||||
summary = TitleSummary(title=result["title"], summary=result["summary"])
|
||||
summary = TitleSummary(
|
||||
title=result["title"],
|
||||
summary=result["summary"],
|
||||
timestamp=self.transcript.timestamp,
|
||||
duration=self.transcript.duration,
|
||||
transcript=self.transcript,
|
||||
)
|
||||
self.transcript = None
|
||||
await self.emit(summary)
|
||||
|
||||
@@ -21,12 +21,6 @@ class Word:
|
||||
end: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class TitleSummary:
|
||||
title: str
|
||||
summary: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Transcript:
|
||||
text: str = ""
|
||||
@@ -63,3 +57,12 @@ class Transcript:
|
||||
Word(text=word.text, start=word.start, end=word.end) for word in self.words
|
||||
]
|
||||
return Transcript(text=self.text, words=words)
|
||||
|
||||
|
||||
@dataclass
class TitleSummary:
    """Title and summary produced for a transcript, plus its source data.

    Besides the generated text, the record carries the transcript it was
    derived from together with that transcript's timestamp and duration.
    """

    # Generated headline for the transcript.
    title: str
    # Generated summary text for the transcript.
    summary: str
    # Copied from the source transcript's ``timestamp`` by the producer.
    timestamp: float
    # Copied from the source transcript's ``duration`` by the producer.
    duration: float
    # The transcript the title/summary were generated from.
    transcript: Transcript
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
from fastapi import Request, APIRouter
|
||||
from pydantic import BaseModel
|
||||
from reflector.models import TranscriptionContext, TranscriptionOutput
|
||||
from reflector.models import (
|
||||
TranscriptionContext,
|
||||
TranscriptionOutput,
|
||||
TitleSummaryOutput,
|
||||
IncrementalResult,
|
||||
)
|
||||
from reflector.logger import logger
|
||||
from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack
|
||||
from json import loads, dumps
|
||||
@@ -67,6 +72,31 @@ async def rtc_offer(params: RtcOffer, request: Request):
|
||||
|
||||
async def on_summary(summary: TitleSummary):
    """Forward a title/summary result to the client over the data channel.

    Logs the summary, wraps it in an ``UPDATE_TOPICS`` command carrying a
    single topic entry, and sends the JSON-encoded message through
    ``ctx.data_channel``.
    """
    ctx.logger.info("Summary", summary=summary)

    # NOTE: the typed TitleSummaryOutput / IncrementalResult models are
    # deliberately not used here: IncrementalResult is not serializable, and
    # the previous implementation assumed oobabooga-shaped output. A plain
    # dict is built instead so it can be passed straight to ``dumps``.
    topic = {
        "title": summary.title,
        "timestamp": summary.timestamp,
        "transcript": summary.transcript.text,
        "desc": summary.summary,
    }
    message = {"cmd": "UPDATE_TOPICS", "topics": [topic]}

    ctx.data_channel.send(dumps(message))
|
||||
|
||||
# create a context for the whole rtc transaction
|
||||
# add a customised logger to the context
|
||||
|
||||
Reference in New Issue
Block a user