mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-21 04:39:06 +00:00
server: implement status update in model and websocket
This commit is contained in:
@@ -199,6 +199,7 @@ class TranscriptionContext:
|
|||||||
sorted_transcripts: dict
|
sorted_transcripts: dict
|
||||||
data_channel: None # FIXME
|
data_channel: None # FIXME
|
||||||
logger: None
|
logger: None
|
||||||
|
status: str
|
||||||
|
|
||||||
def __init__(self, logger):
|
def __init__(self, logger):
|
||||||
self.transcription_text = ""
|
self.transcription_text = ""
|
||||||
@@ -206,4 +207,5 @@ class TranscriptionContext:
|
|||||||
self.incremental_responses = []
|
self.incremental_responses = []
|
||||||
self.data_channel = None
|
self.data_channel = None
|
||||||
self.sorted_transcripts = SortedDict()
|
self.sorted_transcripts = SortedDict()
|
||||||
|
self.status = "idle"
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
|||||||
@@ -52,10 +52,15 @@ class RtcOffer(BaseModel):
|
|||||||
type: str
|
type: str
|
||||||
|
|
||||||
|
|
||||||
|
class StrValue(BaseModel):
|
||||||
|
value: str
|
||||||
|
|
||||||
|
|
||||||
class PipelineEvent(StrEnum):
|
class PipelineEvent(StrEnum):
|
||||||
TRANSCRIPT = "TRANSCRIPT"
|
TRANSCRIPT = "TRANSCRIPT"
|
||||||
TOPIC = "TOPIC"
|
TOPIC = "TOPIC"
|
||||||
FINAL_SUMMARY = "FINAL_SUMMARY"
|
FINAL_SUMMARY = "FINAL_SUMMARY"
|
||||||
|
STATUS = "STATUS"
|
||||||
|
|
||||||
|
|
||||||
async def rtc_offer_base(
|
async def rtc_offer_base(
|
||||||
@@ -70,6 +75,17 @@ async def rtc_offer_base(
|
|||||||
ctx = TranscriptionContext(logger=logger.bind(client=clientid))
|
ctx = TranscriptionContext(logger=logger.bind(client=clientid))
|
||||||
ctx.topics = []
|
ctx.topics = []
|
||||||
|
|
||||||
|
async def update_status(status: str):
|
||||||
|
changed = ctx.status != status
|
||||||
|
if changed:
|
||||||
|
ctx.status = status
|
||||||
|
if event_callback:
|
||||||
|
await event_callback(
|
||||||
|
event=PipelineEvent.STATUS,
|
||||||
|
args=event_callback_args,
|
||||||
|
data=StrValue(value=status),
|
||||||
|
)
|
||||||
|
|
||||||
# build pipeline callback
|
# build pipeline callback
|
||||||
async def on_transcript(transcript: Transcript):
|
async def on_transcript(transcript: Transcript):
|
||||||
ctx.logger.info("Transcript", transcript=transcript)
|
ctx.logger.info("Transcript", transcript=transcript)
|
||||||
@@ -148,10 +164,12 @@ async def rtc_offer_base(
|
|||||||
pc = RTCPeerConnection()
|
pc = RTCPeerConnection()
|
||||||
|
|
||||||
async def flush_pipeline_and_quit(close=True):
|
async def flush_pipeline_and_quit(close=True):
|
||||||
|
await update_status("processing")
|
||||||
await ctx.pipeline.flush()
|
await ctx.pipeline.flush()
|
||||||
if close:
|
if close:
|
||||||
ctx.logger.debug("Closing peer connection")
|
ctx.logger.debug("Closing peer connection")
|
||||||
await pc.close()
|
await pc.close()
|
||||||
|
await update_status("ended")
|
||||||
|
|
||||||
@pc.on("datachannel")
|
@pc.on("datachannel")
|
||||||
def on_datachannel(channel):
|
def on_datachannel(channel):
|
||||||
@@ -181,6 +199,7 @@ async def rtc_offer_base(
|
|||||||
def on_track(track):
|
def on_track(track):
|
||||||
ctx.logger.info(f"Track {track.kind} received")
|
ctx.logger.info(f"Track {track.kind} received")
|
||||||
pc.addTrack(AudioStreamTrack(ctx, track))
|
pc.addTrack(AudioStreamTrack(ctx, track))
|
||||||
|
asyncio.get_event_loop().create_task(update_status("recording"))
|
||||||
|
|
||||||
await pc.setRemoteDescription(offer)
|
await pc.setRemoteDescription(offer)
|
||||||
|
|
||||||
|
|||||||
@@ -280,6 +280,9 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
|
|||||||
resp = transcript.add_event(event=event, data=final_summary)
|
resp = transcript.add_event(event=event, data=final_summary)
|
||||||
transcript.summary = final_summary
|
transcript.summary = final_summary
|
||||||
|
|
||||||
|
elif event == PipelineEvent.STATUS:
|
||||||
|
resp = transcript.add_event(event=event, data=data)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Unknown event: {event}")
|
logger.warning(f"Unknown event: {event}")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -150,17 +150,36 @@ async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm):
|
|||||||
|
|
||||||
# check events
|
# check events
|
||||||
assert len(events) > 0
|
assert len(events) > 0
|
||||||
assert events[0]["event"] == "TRANSCRIPT"
|
from pprint import pprint
|
||||||
assert events[0]["data"]["text"] == "Hello world"
|
|
||||||
|
|
||||||
assert events[-2]["event"] == "TOPIC"
|
pprint(events)
|
||||||
assert events[-2]["data"]["id"]
|
|
||||||
assert events[-2]["data"]["summary"] == "LLM SUMMARY"
|
|
||||||
assert events[-2]["data"]["transcript"].startswith("Hello world")
|
|
||||||
assert events[-2]["data"]["timestamp"] == 0.0
|
|
||||||
|
|
||||||
assert events[-1]["event"] == "FINAL_SUMMARY"
|
# get events list
|
||||||
assert events[-1]["data"]["summary"] == "LLM SUMMARY"
|
eventnames = [e["event"] for e in events]
|
||||||
|
|
||||||
|
# check events
|
||||||
|
assert "TRANSCRIPT" in eventnames
|
||||||
|
ev = events[eventnames.index("TRANSCRIPT")]
|
||||||
|
assert ev["data"]["text"] == "Hello world"
|
||||||
|
|
||||||
|
assert "TOPIC" in eventnames
|
||||||
|
ev = events[eventnames.index("TOPIC")]
|
||||||
|
assert ev["data"]["id"]
|
||||||
|
assert ev["data"]["summary"] == "LLM SUMMARY"
|
||||||
|
assert ev["data"]["transcript"].startswith("Hello world")
|
||||||
|
assert ev["data"]["timestamp"] == 0.0
|
||||||
|
|
||||||
|
assert "FINAL_SUMMARY" in eventnames
|
||||||
|
ev = events[eventnames.index("FINAL_SUMMARY")]
|
||||||
|
assert ev["data"]["summary"] == "LLM SUMMARY"
|
||||||
|
|
||||||
|
# check status order
|
||||||
|
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
|
||||||
|
assert statuses == ["recording", "processing", "ended"]
|
||||||
|
|
||||||
|
# ensure the last event received is ended
|
||||||
|
assert events[-1]["event"] == "STATUS"
|
||||||
|
assert events[-1]["data"]["value"] == "ended"
|
||||||
|
|
||||||
# stop server
|
# stop server
|
||||||
# server.stop()
|
# server.stop()
|
||||||
|
|||||||
Reference in New Issue
Block a user