mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-21 04:39:06 +00:00
server/rtc: fix topic output
This commit is contained in:
@@ -55,9 +55,13 @@ class LLM:
|
|||||||
|
|
||||||
regex = r"```(json|javascript|)?(.*)```"
|
regex = r"```(json|javascript|)?(.*)```"
|
||||||
matches = re.findall(regex, result.strip(), re.MULTILINE | re.DOTALL)
|
matches = re.findall(regex, result.strip(), re.MULTILINE | re.DOTALL)
|
||||||
if not matches:
|
if matches:
|
||||||
return result
|
|
||||||
|
|
||||||
# we have a match, try to parse it
|
|
||||||
result = matches[0][1]
|
result = matches[0][1]
|
||||||
|
|
||||||
|
else:
|
||||||
|
# maybe the prompt has been started with ```json
|
||||||
|
# so if text ends with ```, just remove it and use it as json
|
||||||
|
if result.endswith("```"):
|
||||||
|
result = result[:-3]
|
||||||
|
|
||||||
return json.loads(result.strip())
|
return json.loads(result.strip())
|
||||||
|
|||||||
@@ -21,9 +21,10 @@ class TranscriptTopicDetectorProcessor(Processor):
|
|||||||
{input_text}
|
{input_text}
|
||||||
|
|
||||||
### Assistant:
|
### Assistant:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, min_transcript_length=25, **kwargs):
|
def __init__(self, min_transcript_length=100, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.transcript = None
|
self.transcript = None
|
||||||
self.min_transcript_length = min_transcript_length
|
self.min_transcript_length = min_transcript_length
|
||||||
@@ -43,6 +44,12 @@ class TranscriptTopicDetectorProcessor(Processor):
|
|||||||
return
|
return
|
||||||
prompt = self.PROMPT.format(input_text=self.transcript.text)
|
prompt = self.PROMPT.format(input_text=self.transcript.text)
|
||||||
result = await self.llm.generate(prompt=prompt)
|
result = await self.llm.generate(prompt=prompt)
|
||||||
summary = TitleSummary(title=result["title"], summary=result["summary"])
|
summary = TitleSummary(
|
||||||
|
title=result["title"],
|
||||||
|
summary=result["summary"],
|
||||||
|
timestamp=self.transcript.timestamp,
|
||||||
|
duration=self.transcript.duration,
|
||||||
|
transcript=self.transcript,
|
||||||
|
)
|
||||||
self.transcript = None
|
self.transcript = None
|
||||||
await self.emit(summary)
|
await self.emit(summary)
|
||||||
|
|||||||
@@ -21,12 +21,6 @@ class Word:
|
|||||||
end: float
|
end: float
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TitleSummary:
|
|
||||||
title: str
|
|
||||||
summary: str
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Transcript:
|
class Transcript:
|
||||||
text: str = ""
|
text: str = ""
|
||||||
@@ -63,3 +57,12 @@ class Transcript:
|
|||||||
Word(text=word.text, start=word.start, end=word.end) for word in self.words
|
Word(text=word.text, start=word.start, end=word.end) for word in self.words
|
||||||
]
|
]
|
||||||
return Transcript(text=self.text, words=words)
|
return Transcript(text=self.text, words=words)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TitleSummary:
|
||||||
|
title: str
|
||||||
|
summary: str
|
||||||
|
timestamp: float
|
||||||
|
duration: float
|
||||||
|
transcript: Transcript
|
||||||
|
|||||||
@@ -1,6 +1,11 @@
|
|||||||
from fastapi import Request, APIRouter
|
from fastapi import Request, APIRouter
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from reflector.models import TranscriptionContext, TranscriptionOutput
|
from reflector.models import (
|
||||||
|
TranscriptionContext,
|
||||||
|
TranscriptionOutput,
|
||||||
|
TitleSummaryOutput,
|
||||||
|
IncrementalResult,
|
||||||
|
)
|
||||||
from reflector.logger import logger
|
from reflector.logger import logger
|
||||||
from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack
|
from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack
|
||||||
from json import loads, dumps
|
from json import loads, dumps
|
||||||
@@ -67,6 +72,31 @@ async def rtc_offer(params: RtcOffer, request: Request):
|
|||||||
|
|
||||||
async def on_summary(summary: TitleSummary):
|
async def on_summary(summary: TitleSummary):
|
||||||
ctx.logger.info("Summary", summary=summary)
|
ctx.logger.info("Summary", summary=summary)
|
||||||
|
# XXX doesnt work as expected, IncrementalResult is not serializable
|
||||||
|
# and previous implementation assume output of oobagooda
|
||||||
|
# result = TitleSummaryOutput(
|
||||||
|
# [
|
||||||
|
# IncrementalResult(
|
||||||
|
# title=summary.title,
|
||||||
|
# desc=summary.summary,
|
||||||
|
# transcript=summary.transcript.text,
|
||||||
|
# timestamp=summary.timestamp,
|
||||||
|
# )
|
||||||
|
# ]
|
||||||
|
# )
|
||||||
|
result = {
|
||||||
|
"cmd": "UPDATE_TOPICS",
|
||||||
|
"topics": [
|
||||||
|
{
|
||||||
|
"title": summary.title,
|
||||||
|
"timestamp": summary.timestamp,
|
||||||
|
"transcript": summary.transcript.text,
|
||||||
|
"desc": summary.summary,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.data_channel.send(dumps(result))
|
||||||
|
|
||||||
# create a context for the whole rtc transaction
|
# create a context for the whole rtc transaction
|
||||||
# add a customised logger to the context
|
# add a customised logger to the context
|
||||||
|
|||||||
Reference in New Issue
Block a user