fastapi: implement server with the same backward compatibility as before

2023-07-29 15:59:25 +02:00
parent 3908c1ca53
commit 224afc6f28
5 changed files with 419 additions and 16 deletions

server/reflector/app.py Normal file

@@ -0,0 +1,15 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from reflector.views.rtc_offer import router as rtc_offer_router

# build app
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# register views
app.include_router(rtc_offer_router)


@@ -80,8 +80,7 @@ class Processor:
         if callback:
             self.on(callback)
         self.uid = uuid4().hex
-        self.logger = (custom_logger or logger).bind(
-            processor=self.__class__.__name__)
+        self.logger = (custom_logger or logger).bind(processor=self.__class__.__name__)
 
     def set_pipeline(self, pipeline: "Pipeline"):
         self.logger = self.logger.bind(pipeline=pipeline.uid)
@@ -107,6 +106,9 @@ class Processor:
         """
         Register a callback to be called when data is emitted
         """
+        # ensure callback is asynchronous
+        if not asyncio.iscoroutinefunction(callback):
+            raise ValueError("Callback must be a coroutine function")
         self._callbacks.append(callback)
 
     def off(self, callback):
@@ -127,7 +129,10 @@ class Processor:
         The function returns the output of type `OUTPUT_TYPE`
         """
         # logger.debug(f"{self.__class__.__name__} push")
-        return await self._push(data)
+        try:
+            return await self._push(data)
+        except Exception:
+            self.logger.exception("Error in push")
 
     async def flush(self):
         """
@@ -463,7 +468,6 @@ class Pipeline(Processor):
        logger.info("")


class FinalSummaryProcessor(Processor):
    """
    Assemble all summary into a line-based json

server/reflector/views/rtc_offer.py Normal file

@@ -0,0 +1,121 @@
from fastapi import Request, APIRouter
from pydantic import BaseModel
from reflector.models import TranscriptionContext, TranscriptionOutput
from reflector.logger import logger
from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack
from json import loads, dumps
import av

from reflector.processors import (
    Pipeline,
    AudioChunkerProcessor,
    AudioMergeProcessor,
    AudioAutoTranscriptProcessor,
    TranscriptLineProcessor,
    TitleSummaryProcessor,
    # FinalSummaryProcessor,
    Transcript,
    TitleSummary,
)


class AudioStreamTrack(MediaStreamTrack):
    """
    An audio stream track.
    """

    kind = "audio"

    def __init__(self, ctx: TranscriptionContext, track):
        super().__init__()
        self.ctx = ctx
        self.track = track

    async def recv(self) -> av.audio.frame.AudioFrame:
        ctx = self.ctx
        frame = await self.track.recv()
        try:
            await ctx.pipeline.push(frame)
        except Exception as e:
            ctx.logger.error("Pipeline error", error=e)
        return frame


class RtcOffer(BaseModel):
    sdp: str
    type: str


sessions = []
router = APIRouter()
@router.post("/offer")
async def rtc_offer(params: RtcOffer, request: Request):
# build an rtc session
offer = RTCSessionDescription(sdp=params.sdp, type=params.type)
# client identification
peername = request.client
clientid = f"{peername[0]}:{peername[1]}"
ctx = TranscriptionContext(logger=logger.bind(client=clientid))
# build pipeline callback
async def on_transcript(transcript: Transcript):
ctx.logger.info("Transcript", transcript=transcript)
cmd = TranscriptionOutput(transcript.text)
# FIXME: send the result to the client async way
ctx.data_channel.send(dumps(cmd.get_result()))
async def on_summary(summary: TitleSummary):
ctx.logger.info("Summary", summary=summary)
# create a context for the whole rtc transaction
# add a customised logger to the context
ctx.pipeline = Pipeline(
AudioChunkerProcessor(),
AudioMergeProcessor(),
AudioAutoTranscriptProcessor.as_threaded(),
TranscriptLineProcessor(callback=on_transcript),
TitleSummaryProcessor.as_threaded(callback=on_summary),
# FinalSummaryProcessor.as_threaded(
# filename=result_fn, callback=on_final_summary
# ),
)
# handle RTC peer connection
pc = RTCPeerConnection()
@pc.on("datachannel")
def on_datachannel(channel):
ctx.data_channel = channel
ctx.logger = ctx.logger.bind(channel=channel.label)
ctx.logger.info("Channel created by remote party")
@channel.on("message")
def on_message(message: str):
ctx.logger.info(f"Message: {message}")
if loads(message)["cmd"] == "STOP":
# FIXME: flush the pipeline
pass
if isinstance(message, str) and message.startswith("ping"):
channel.send("pong" + message[4:])
@pc.on("connectionstatechange")
async def on_connectionstatechange():
ctx.logger.info(f"Connection state: {pc.connectionState}")
if pc.connectionState == "failed":
await pc.close()
@pc.on("track")
def on_track(track):
ctx.logger.info(f"Track {track.kind} received")
pc.addTrack(AudioStreamTrack(ctx, track))
await pc.setRemoteDescription(offer)
answer = await pc.createAnswer()
await pc.setLocalDescription(answer)
sessions.append(pc)
return RtcOffer(sdp=pc.localDescription.sdp, type=pc.localDescription.type)