mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
server: implement status update in model and websocket
This commit is contained in:
@@ -199,6 +199,7 @@ class TranscriptionContext:
|
||||
sorted_transcripts: dict
|
||||
data_channel: None # FIXME
|
||||
logger: None
|
||||
status: str
|
||||
|
||||
def __init__(self, logger):
|
||||
self.transcription_text = ""
|
||||
@@ -206,4 +207,5 @@ class TranscriptionContext:
|
||||
self.incremental_responses = []
|
||||
self.data_channel = None
|
||||
self.sorted_transcripts = SortedDict()
|
||||
self.status = "idle"
|
||||
self.logger = logger
|
||||
|
||||
@@ -52,10 +52,15 @@ class RtcOffer(BaseModel):
|
||||
type: str
|
||||
|
||||
|
||||
class StrValue(BaseModel):
|
||||
value: str
|
||||
|
||||
|
||||
class PipelineEvent(StrEnum):
|
||||
TRANSCRIPT = "TRANSCRIPT"
|
||||
TOPIC = "TOPIC"
|
||||
FINAL_SUMMARY = "FINAL_SUMMARY"
|
||||
STATUS = "STATUS"
|
||||
|
||||
|
||||
async def rtc_offer_base(
|
||||
@@ -70,6 +75,17 @@ async def rtc_offer_base(
|
||||
ctx = TranscriptionContext(logger=logger.bind(client=clientid))
|
||||
ctx.topics = []
|
||||
|
||||
async def update_status(status: str):
|
||||
changed = ctx.status != status
|
||||
if changed:
|
||||
ctx.status = status
|
||||
if event_callback:
|
||||
await event_callback(
|
||||
event=PipelineEvent.STATUS,
|
||||
args=event_callback_args,
|
||||
data=StrValue(value=status),
|
||||
)
|
||||
|
||||
# build pipeline callback
|
||||
async def on_transcript(transcript: Transcript):
|
||||
ctx.logger.info("Transcript", transcript=transcript)
|
||||
@@ -148,10 +164,12 @@ async def rtc_offer_base(
|
||||
pc = RTCPeerConnection()
|
||||
|
||||
async def flush_pipeline_and_quit(close=True):
|
||||
await update_status("processing")
|
||||
await ctx.pipeline.flush()
|
||||
if close:
|
||||
ctx.logger.debug("Closing peer connection")
|
||||
await pc.close()
|
||||
await update_status("ended")
|
||||
|
||||
@pc.on("datachannel")
|
||||
def on_datachannel(channel):
|
||||
@@ -181,6 +199,7 @@ async def rtc_offer_base(
|
||||
def on_track(track):
|
||||
ctx.logger.info(f"Track {track.kind} received")
|
||||
pc.addTrack(AudioStreamTrack(ctx, track))
|
||||
asyncio.get_event_loop().create_task(update_status("recording"))
|
||||
|
||||
await pc.setRemoteDescription(offer)
|
||||
|
||||
|
||||
@@ -280,6 +280,9 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
|
||||
resp = transcript.add_event(event=event, data=final_summary)
|
||||
transcript.summary = final_summary
|
||||
|
||||
elif event == PipelineEvent.STATUS:
|
||||
resp = transcript.add_event(event=event, data=data)
|
||||
|
||||
else:
|
||||
logger.warning(f"Unknown event: {event}")
|
||||
return
|
||||
|
||||
@@ -150,17 +150,36 @@ async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm):
|
||||
|
||||
# check events
|
||||
assert len(events) > 0
|
||||
assert events[0]["event"] == "TRANSCRIPT"
|
||||
assert events[0]["data"]["text"] == "Hello world"
|
||||
from pprint import pprint
|
||||
|
||||
assert events[-2]["event"] == "TOPIC"
|
||||
assert events[-2]["data"]["id"]
|
||||
assert events[-2]["data"]["summary"] == "LLM SUMMARY"
|
||||
assert events[-2]["data"]["transcript"].startswith("Hello world")
|
||||
assert events[-2]["data"]["timestamp"] == 0.0
|
||||
pprint(events)
|
||||
|
||||
assert events[-1]["event"] == "FINAL_SUMMARY"
|
||||
assert events[-1]["data"]["summary"] == "LLM SUMMARY"
|
||||
# get events list
|
||||
eventnames = [e["event"] for e in events]
|
||||
|
||||
# check events
|
||||
assert "TRANSCRIPT" in eventnames
|
||||
ev = events[eventnames.index("TRANSCRIPT")]
|
||||
assert ev["data"]["text"] == "Hello world"
|
||||
|
||||
assert "TOPIC" in eventnames
|
||||
ev = events[eventnames.index("TOPIC")]
|
||||
assert ev["data"]["id"]
|
||||
assert ev["data"]["summary"] == "LLM SUMMARY"
|
||||
assert ev["data"]["transcript"].startswith("Hello world")
|
||||
assert ev["data"]["timestamp"] == 0.0
|
||||
|
||||
assert "FINAL_SUMMARY" in eventnames
|
||||
ev = events[eventnames.index("FINAL_SUMMARY")]
|
||||
assert ev["data"]["summary"] == "LLM SUMMARY"
|
||||
|
||||
# check status order
|
||||
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
|
||||
assert statuses == ["recording", "processing", "ended"]
|
||||
|
||||
# ensure the last event received is ended
|
||||
assert events[-1]["event"] == "STATUS"
|
||||
assert events[-1]["data"]["value"] == "ended"
|
||||
|
||||
# stop server
|
||||
# server.stop()
|
||||
|
||||
Reference in New Issue
Block a user