server: fixes latest implementation details on rtc offer and fastapi

This commit is contained in:
Mathieu Virbel
2023-08-01 20:09:05 +02:00
parent d320558cc9
commit 74d2974ed2
7 changed files with 105 additions and 72 deletions

View File

@@ -1,10 +1,10 @@
import asyncio
from fastapi import Request, APIRouter
from reflector.events import subscribers_shutdown
from pydantic import BaseModel
from reflector.models import (
TranscriptionContext,
TranscriptionOutput,
TitleSummaryOutput,
IncrementalResult,
)
from reflector.logger import logger
from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack
@@ -17,10 +17,15 @@ from reflector.processors import (
AudioTranscriptAutoProcessor,
TranscriptLinerProcessor,
TranscriptTopicDetectorProcessor,
TranscriptFinalSummaryProcessor,
Transcript,
TitleSummary,
FinalSummary,
)
sessions = []
router = APIRouter()
class AudioStreamTrack(MediaStreamTrack):
"""
@@ -49,10 +54,6 @@ class RtcOffer(BaseModel):
type: str
sessions = []
router = APIRouter()
@router.post("/offer")
async def rtc_offer(params: RtcOffer, request: Request):
# build an rtc session
@@ -62,40 +63,38 @@ async def rtc_offer(params: RtcOffer, request: Request):
peername = request.client
clientid = f"{peername[0]}:{peername[1]}"
ctx = TranscriptionContext(logger=logger.bind(client=clientid))
ctx.topics = []
# build pipeline callback
async def on_transcript(transcript: Transcript):
ctx.logger.info("Transcript", transcript=transcript)
cmd = TranscriptionOutput(transcript.text)
# FIXME: send the result to the client async way
ctx.data_channel.send(dumps(cmd.get_result()))
async def on_summary(summary: TitleSummary):
ctx.logger.info("Summary", summary=summary)
# XXX doesn't work as expected: IncrementalResult is not serializable,
# and the previous implementation assumes the output of oobabooga
# result = TitleSummaryOutput(
# [
# IncrementalResult(
# title=summary.title,
# desc=summary.summary,
# transcript=summary.transcript.text,
# timestamp=summary.timestamp,
# )
# ]
# )
result = {
"cmd": "UPDATE_TOPICS",
"topics": [
{
"title": summary.title,
"timestamp": summary.timestamp,
"transcript": summary.transcript.text,
"desc": summary.summary,
}
],
"cmd": "SHOW_TRANSCRIPTION",
"text": transcript.text,
}
ctx.data_channel.send(dumps(result))
async def on_topic(summary: TitleSummary):
    """
    Pipeline callback: record a detected topic and push the topic list.

    The new topic is appended to `ctx.topics`, then the whole accumulated
    list is sent over `ctx.data_channel` as an ``UPDATE_TOPICS`` command.
    """
    # FIXME: send topics incrementally to the frontend instead of
    # re-sending the full list on every update
    ctx.logger.info("Summary", summary=summary)

    topic = {
        "title": summary.title,
        "timestamp": summary.timestamp,
        "transcript": summary.transcript.text,
        "desc": summary.summary,
    }
    ctx.topics.append(topic)

    payload = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics}
    ctx.data_channel.send(dumps(payload))
async def on_final_summary(summary: FinalSummary):
    """
    Pipeline callback: forward the final summary to the client.

    Sends a ``DISPLAY_FINAL_SUMMARY`` command carrying the summary text and
    its duration over `ctx.data_channel`.
    """
    ctx.logger.info("FinalSummary", final_summary=summary)

    payload = {
        "cmd": "DISPLAY_FINAL_SUMMARY",
        "summary": summary.summary,
        "duration": summary.duration,
    }
    message = dumps(payload)
    ctx.data_channel.send(message)
# create a context for the whole rtc transaction
@@ -105,15 +104,19 @@ async def rtc_offer(params: RtcOffer, request: Request):
AudioMergeProcessor(),
AudioTranscriptAutoProcessor.as_threaded(callback=on_transcript),
TranscriptLinerProcessor(),
TranscriptTopicDetectorProcessor.as_threaded(callback=on_summary),
# FinalSummaryProcessor.as_threaded(
# filename=result_fn, callback=on_final_summary
# ),
TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
TranscriptFinalSummaryProcessor.as_threaded(callback=on_final_summary),
)
# handle RTC peer connection
pc = RTCPeerConnection()
async def flush_pipeline_and_quit():
    """
    Flush the processing pipeline, then close the peer connection.

    The connection is closed in a ``finally`` clause so it is released even
    when flushing the pipeline raises; the flush error still propagates to
    whoever awaits (or inspects) this coroutine's task.
    """
    ctx.logger.info("Flushing pipeline")
    try:
        await ctx.pipeline.flush()
    finally:
        # Always release the peer connection, otherwise a failed flush
        # would leave it open.
        ctx.logger.debug("Closing peer connection")
        await pc.close()
@pc.on("datachannel")
def on_datachannel(channel):
ctx.data_channel = channel
@@ -124,8 +127,8 @@ async def rtc_offer(params: RtcOffer, request: Request):
def on_message(message: str):
ctx.logger.info(f"Message: {message}")
if loads(message)["cmd"] == "STOP":
# FIXME: flush the pipeline
pass
ctx.logger.debug("STOP command received")
asyncio.get_event_loop().create_task(flush_pipeline_and_quit())
if isinstance(message, str) and message.startswith("ping"):
channel.send("pong" + message[4:])
@@ -148,3 +151,12 @@ async def rtc_offer(params: RtcOffer, request: Request):
sessions.append(pc)
return RtcOffer(sdp=pc.localDescription.sdp, type=pc.localDescription.type)
@subscribers_shutdown.append
async def rtc_clean_sessions():
    """
    Close every RTC peer connection registered in `sessions`.

    The coroutine is handed to `subscribers_shutdown.append` by the
    decorator. A failure while closing one connection is logged and does
    not prevent the remaining connections from being closed; the session
    list is emptied afterwards in all cases.
    """
    # NOTE(review): if `subscribers_shutdown` is a plain list, `append`
    # returns None, so the module-level name `rtc_clean_sessions` is bound
    # to None and the coroutine is only reachable through that list --
    # confirm this is intended.
    logger.info("Closing all RTC sessions")
    # Iterate over a snapshot so the loop is unaffected if the shared list
    # changes while we are awaiting a close.
    for pc in list(sessions):
        logger.debug(f"Closing session {pc}")
        try:
            await pc.close()
        except Exception:
            # Keep going: one broken session must not leave the others open.
            logger.exception(f"Failed to close session {pc}")
    sessions.clear()