mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-22 05:09:05 +00:00
fastapi: implement server with same back compatibility as before
This commit is contained in:
15
server/reflector/app.py
Normal file
15
server/reflector/app.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from reflector.views.rtc_offer import router as rtc_offer_router
|
||||
|
||||
# build app
|
||||
app = FastAPI()
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# register views
|
||||
app.include_router(rtc_offer_router)
|
||||
@@ -80,8 +80,7 @@ class Processor:
|
||||
if callback:
|
||||
self.on(callback)
|
||||
self.uid = uuid4().hex
|
||||
self.logger = (custom_logger or logger).bind(
|
||||
processor=self.__class__.__name__)
|
||||
self.logger = (custom_logger or logger).bind(processor=self.__class__.__name__)
|
||||
|
||||
def set_pipeline(self, pipeline: "Pipeline"):
|
||||
self.logger = self.logger.bind(pipeline=pipeline.uid)
|
||||
@@ -107,6 +106,9 @@ class Processor:
|
||||
"""
|
||||
Register a callback to be called when data is emitted
|
||||
"""
|
||||
# ensure callback is asynchronous
|
||||
if not asyncio.iscoroutinefunction(callback):
|
||||
raise ValueError("Callback must be a coroutine function")
|
||||
self._callbacks.append(callback)
|
||||
|
||||
def off(self, callback):
|
||||
@@ -127,7 +129,10 @@ class Processor:
|
||||
The function returns the output of type `OUTPUT_TYPE`
|
||||
"""
|
||||
# logger.debug(f"{self.__class__.__name__} push")
|
||||
return await self._push(data)
|
||||
try:
|
||||
return await self._push(data)
|
||||
except Exception:
|
||||
self.logger.exception("Error in push")
|
||||
|
||||
async def flush(self):
|
||||
"""
|
||||
@@ -463,7 +468,6 @@ class Pipeline(Processor):
|
||||
logger.info("")
|
||||
|
||||
|
||||
|
||||
class FinalSummaryProcessor(Processor):
|
||||
"""
|
||||
Assemble all summary into a line-based json
|
||||
|
||||
121
server/reflector/views/rtc_offer.py
Normal file
121
server/reflector/views/rtc_offer.py
Normal file
@@ -0,0 +1,121 @@
|
||||
from fastapi import Request, APIRouter
|
||||
from pydantic import BaseModel
|
||||
from reflector.models import TranscriptionContext, TranscriptionOutput
|
||||
from reflector.logger import logger
|
||||
from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack
|
||||
from json import loads, dumps
|
||||
import av
|
||||
from reflector.processors import (
|
||||
Pipeline,
|
||||
AudioChunkerProcessor,
|
||||
AudioMergeProcessor,
|
||||
AudioAutoTranscriptProcessor,
|
||||
TranscriptLineProcessor,
|
||||
TitleSummaryProcessor,
|
||||
# FinalSummaryProcessor,
|
||||
Transcript,
|
||||
TitleSummary,
|
||||
)
|
||||
|
||||
|
||||
class AudioStreamTrack(MediaStreamTrack):
|
||||
"""
|
||||
An audio stream track.
|
||||
"""
|
||||
|
||||
kind = "audio"
|
||||
|
||||
def __init__(self, ctx: TranscriptionContext, track):
|
||||
super().__init__()
|
||||
self.ctx = ctx
|
||||
self.track = track
|
||||
|
||||
async def recv(self) -> av.audio.frame.AudioFrame:
|
||||
ctx = self.ctx
|
||||
frame = await self.track.recv()
|
||||
try:
|
||||
await ctx.pipeline.push(frame)
|
||||
except Exception as e:
|
||||
ctx.logger.error("Pipeline error", error=e)
|
||||
return frame
|
||||
|
||||
|
||||
class RtcOffer(BaseModel):
|
||||
sdp: str
|
||||
type: str
|
||||
|
||||
|
||||
sessions = []
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/offer")
|
||||
async def rtc_offer(params: RtcOffer, request: Request):
|
||||
# build an rtc session
|
||||
offer = RTCSessionDescription(sdp=params.sdp, type=params.type)
|
||||
|
||||
# client identification
|
||||
peername = request.client
|
||||
clientid = f"{peername[0]}:{peername[1]}"
|
||||
ctx = TranscriptionContext(logger=logger.bind(client=clientid))
|
||||
|
||||
# build pipeline callback
|
||||
async def on_transcript(transcript: Transcript):
|
||||
ctx.logger.info("Transcript", transcript=transcript)
|
||||
cmd = TranscriptionOutput(transcript.text)
|
||||
# FIXME: send the result to the client async way
|
||||
ctx.data_channel.send(dumps(cmd.get_result()))
|
||||
|
||||
async def on_summary(summary: TitleSummary):
|
||||
ctx.logger.info("Summary", summary=summary)
|
||||
|
||||
# create a context for the whole rtc transaction
|
||||
# add a customised logger to the context
|
||||
ctx.pipeline = Pipeline(
|
||||
AudioChunkerProcessor(),
|
||||
AudioMergeProcessor(),
|
||||
AudioAutoTranscriptProcessor.as_threaded(),
|
||||
TranscriptLineProcessor(callback=on_transcript),
|
||||
TitleSummaryProcessor.as_threaded(callback=on_summary),
|
||||
# FinalSummaryProcessor.as_threaded(
|
||||
# filename=result_fn, callback=on_final_summary
|
||||
# ),
|
||||
)
|
||||
|
||||
# handle RTC peer connection
|
||||
pc = RTCPeerConnection()
|
||||
|
||||
@pc.on("datachannel")
|
||||
def on_datachannel(channel):
|
||||
ctx.data_channel = channel
|
||||
ctx.logger = ctx.logger.bind(channel=channel.label)
|
||||
ctx.logger.info("Channel created by remote party")
|
||||
|
||||
@channel.on("message")
|
||||
def on_message(message: str):
|
||||
ctx.logger.info(f"Message: {message}")
|
||||
if loads(message)["cmd"] == "STOP":
|
||||
# FIXME: flush the pipeline
|
||||
pass
|
||||
|
||||
if isinstance(message, str) and message.startswith("ping"):
|
||||
channel.send("pong" + message[4:])
|
||||
|
||||
@pc.on("connectionstatechange")
|
||||
async def on_connectionstatechange():
|
||||
ctx.logger.info(f"Connection state: {pc.connectionState}")
|
||||
if pc.connectionState == "failed":
|
||||
await pc.close()
|
||||
|
||||
@pc.on("track")
|
||||
def on_track(track):
|
||||
ctx.logger.info(f"Track {track.kind} received")
|
||||
pc.addTrack(AudioStreamTrack(ctx, track))
|
||||
|
||||
await pc.setRemoteDescription(offer)
|
||||
|
||||
answer = await pc.createAnswer()
|
||||
await pc.setLocalDescription(answer)
|
||||
sessions.append(pc)
|
||||
|
||||
return RtcOffer(sdp=pc.localDescription.sdp, type=pc.localDescription.type)
|
||||
Reference in New Issue
Block a user