server: implement status update in model and websocket

This commit is contained in:
Mathieu Virbel
2023-08-09 11:21:48 +02:00
parent 7f807c8f5f
commit a9e0c9aa03
4 changed files with 52 additions and 9 deletions

View File

@@ -199,6 +199,7 @@ class TranscriptionContext:
sorted_transcripts: dict
data_channel: None # FIXME
logger: None
status: str
def __init__(self, logger):
self.transcription_text = ""
@@ -206,4 +207,5 @@ class TranscriptionContext:
self.incremental_responses = []
self.data_channel = None
self.sorted_transcripts = SortedDict()
self.status = "idle"
self.logger = logger

View File

@@ -52,10 +52,15 @@ class RtcOffer(BaseModel):
type: str
class StrValue(BaseModel):
value: str
class PipelineEvent(StrEnum):
TRANSCRIPT = "TRANSCRIPT"
TOPIC = "TOPIC"
FINAL_SUMMARY = "FINAL_SUMMARY"
STATUS = "STATUS"
async def rtc_offer_base(
@@ -70,6 +75,17 @@ async def rtc_offer_base(
ctx = TranscriptionContext(logger=logger.bind(client=clientid))
ctx.topics = []
async def update_status(status: str):
changed = ctx.status != status
if changed:
ctx.status = status
if event_callback:
await event_callback(
event=PipelineEvent.STATUS,
args=event_callback_args,
data=StrValue(value=status),
)
# build pipeline callback
async def on_transcript(transcript: Transcript):
ctx.logger.info("Transcript", transcript=transcript)
@@ -148,10 +164,12 @@ async def rtc_offer_base(
pc = RTCPeerConnection()
async def flush_pipeline_and_quit(close=True):
await update_status("processing")
await ctx.pipeline.flush()
if close:
ctx.logger.debug("Closing peer connection")
await pc.close()
await update_status("ended")
@pc.on("datachannel")
def on_datachannel(channel):
@@ -181,6 +199,7 @@ async def rtc_offer_base(
def on_track(track):
ctx.logger.info(f"Track {track.kind} received")
pc.addTrack(AudioStreamTrack(ctx, track))
asyncio.get_event_loop().create_task(update_status("recording"))
await pc.setRemoteDescription(offer)

View File

@@ -280,6 +280,9 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
resp = transcript.add_event(event=event, data=final_summary)
transcript.summary = final_summary
elif event == PipelineEvent.STATUS:
resp = transcript.add_event(event=event, data=data)
else:
logger.warning(f"Unknown event: {event}")
return

View File

@@ -150,17 +150,36 @@ async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm):
# check events
assert len(events) > 0
assert events[0]["event"] == "TRANSCRIPT"
assert events[0]["data"]["text"] == "Hello world"
from pprint import pprint
assert events[-2]["event"] == "TOPIC"
assert events[-2]["data"]["id"]
assert events[-2]["data"]["summary"] == "LLM SUMMARY"
assert events[-2]["data"]["transcript"].startswith("Hello world")
assert events[-2]["data"]["timestamp"] == 0.0
pprint(events)
assert events[-1]["event"] == "FINAL_SUMMARY"
assert events[-1]["data"]["summary"] == "LLM SUMMARY"
# get events list
eventnames = [e["event"] for e in events]
# check events
assert "TRANSCRIPT" in eventnames
ev = events[eventnames.index("TRANSCRIPT")]
assert ev["data"]["text"] == "Hello world"
assert "TOPIC" in eventnames
ev = events[eventnames.index("TOPIC")]
assert ev["data"]["id"]
assert ev["data"]["summary"] == "LLM SUMMARY"
assert ev["data"]["transcript"].startswith("Hello world")
assert ev["data"]["timestamp"] == 0.0
assert "FINAL_SUMMARY" in eventnames
ev = events[eventnames.index("FINAL_SUMMARY")]
assert ev["data"]["summary"] == "LLM SUMMARY"
# check status order
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
assert statuses == ["recording", "processing", "ended"]
# ensure the last event received is ended
assert events[-1]["event"] == "STATUS"
assert events[-1]["data"]["value"] == "ended"
# stop server
# server.stop()