Mirror of https://github.com/Monadical-SAS/reflector.git, synced 2025-12-20 20:29:06 +00:00
server: fixes latest implementation details on rtc offer and fastapi
@@ -1,9 +1,22 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from reflector.views.rtc_offer import router as rtc_offer_router
from reflector.events import subscribers_startup, subscribers_shutdown
from contextlib import asynccontextmanager


# lifespan events
@asynccontextmanager
async def lifespan(app: FastAPI):
    for func in subscribers_startup:
        await func()
    yield
    for func in subscribers_shutdown:
        await func()


# build app
app = FastAPI()
app = FastAPI(lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
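The lifespan hook above is what runs the new subscriber lists from server/reflector/events.py: everything before the yield executes once at startup, everything after it at shutdown. A minimal sketch of exercising that wiring with FastAPI's TestClient; the reflector.app import path is an assumption (this hunk's file header was lost in the mirror view) and the hooks are illustrative:

from fastapi.testclient import TestClient

from reflector.app import app  # assumed module path for the file edited above
from reflector.events import subscribers_startup, subscribers_shutdown

calls = []


async def on_start():  # illustrative hook; subscribers are zero-argument async callables
    calls.append("start")


async def on_stop():
    calls.append("stop")


subscribers_startup.append(on_start)
subscribers_shutdown.append(on_stop)

# entering and leaving the TestClient context drives the lifespan
with TestClient(app):
    assert calls == ["start"]
assert calls == ["start", "stop"]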
server/reflector/events.py (new file, 2 additions)
@@ -0,0 +1,2 @@
subscribers_startup = []
subscribers_shutdown = []
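Both names are plain lists of zero-argument async callables, so registration is just an append; the decorator form used later in this commit relies on list.append returning None (the module-level name ends up bound to None, but the coroutine function is kept in the list and awaited by the lifespan at shutdown). A sketch with a hypothetical hook:

from reflector.events import subscribers_shutdown


@subscribers_shutdown.append  # stores the coroutine function; the name below becomes None
async def close_resources():  # hypothetical hook name
    ...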
@@ -1,9 +1,9 @@
from .base import Processor, ThreadedProcessor, Pipeline # noqa: F401
from .types import AudioFile, Transcript, Word, TitleSummary # noqa: F401
from .types import AudioFile, Transcript, Word, TitleSummary, FinalSummary # noqa: F401
from .audio_chunker import AudioChunkerProcessor # noqa: F401
from .audio_merge import AudioMergeProcessor # noqa: F401
from .audio_transcript import AudioTranscriptProcessor # noqa: F401
from .audio_transcript_auto import AudioTranscriptAutoProcessor # noqa: F401
from .transcript_liner import TranscriptLinerProcessor # noqa: F401
from .transcript_summarizer import TranscriptSummarizerProcessor # noqa: F401
from .transcript_topic_detector import TranscriptTopicDetectorProcessor # noqa: F401
from .transcript_final_summary import TranscriptFinalSummaryProcessor # noqa: F401
server/reflector/processors/transcript_final_summary.py (new file, 30 additions)
@@ -0,0 +1,30 @@
from reflector.processors.base import Processor
from reflector.processors.types import TitleSummary, FinalSummary


class TranscriptFinalSummaryProcessor(Processor):
    """
    Assemble all summary into a line-based json
    """

    INPUT_TYPE = TitleSummary
    OUTPUT_TYPE = FinalSummary

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.chunks: list[TitleSummary] = []

    async def _push(self, data: TitleSummary):
        self.chunks.append(data)

    async def _flush(self):
        if not self.chunks:
            self.logger.warning("No summary to output")
            return

        # FIXME improve final summary
        result = "\n".join([chunk.summary for chunk in self.chunks])
        last_chunk = self.chunks[-1]
        duration = last_chunk.timestamp + last_chunk.duration

        await self.emit(FinalSummary(summary=result, duration=duration))
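What _flush emits, worked out in isolation: the chunk summaries are newline-joined and the reported duration is the end time of the last chunk. The values below are made up for illustration:

# chunk.summary values from two hypothetical TitleSummary chunks
summaries = ["Greetings and agenda.", "Discussion of the roadmap."]
# the corresponding (chunk.timestamp, chunk.duration) pairs
spans = [(0.0, 30.0), (30.0, 45.0)]

final_text = "\n".join(summaries)
last_timestamp, last_duration = spans[-1]
final_duration = last_timestamp + last_duration  # 75.0, i.e. where the last chunk ends

# FinalSummary(summary=final_text, duration=final_duration) is then emitted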
@@ -1,30 +0,0 @@
from reflector.processors.base import Processor
from reflector.processors.types import TitleSummary
from pathlib import Path
import json


class TranscriptSummarizerProcessor(Processor):
    """
    Assemble all summary into a line-based json
    """

    INPUT_TYPE = TitleSummary
    OUTPUT_TYPE = Path

    def __init__(self, filename: Path, **kwargs):
        super().__init__(**kwargs)
        self.filename = filename
        self.chunkcount = 0

    async def _push(self, data: TitleSummary):
        with open(self.filename, "a", encoding="utf8") as fd:
            fd.write(json.dumps(data))
        self.chunkcount += 1

    async def _flush(self):
        if self.chunkcount == 0:
            self.logger.warning("No summary to write")
            return
        self.logger.info(f"Writing to {self.filename}")
        await self.emit(self.filename)
@@ -66,3 +66,9 @@ class TitleSummary:
    timestamp: float
    duration: float
    transcript: Transcript


@dataclass
class FinalSummary:
    summary: str
    duration: float
@@ -1,10 +1,10 @@
import asyncio
from fastapi import Request, APIRouter
from reflector.events import subscribers_shutdown
from pydantic import BaseModel
from reflector.models import (
    TranscriptionContext,
    TranscriptionOutput,
    TitleSummaryOutput,
    IncrementalResult,
)
from reflector.logger import logger
from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack
@@ -17,10 +17,15 @@ from reflector.processors import (
    AudioTranscriptAutoProcessor,
    TranscriptLinerProcessor,
    TranscriptTopicDetectorProcessor,
    TranscriptFinalSummaryProcessor,
    Transcript,
    TitleSummary,
    FinalSummary,
)

sessions = []
router = APIRouter()


class AudioStreamTrack(MediaStreamTrack):
    """
@@ -49,10 +54,6 @@ class RtcOffer(BaseModel):
    type: str


sessions = []
router = APIRouter()


@router.post("/offer")
async def rtc_offer(params: RtcOffer, request: Request):
    # build an rtc session
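For reference, the offer exchange against this route is a plain JSON POST: the client sends its local SDP with type "offer" and gets the server's local description back in the same RtcOffer shape. A hypothetical client call, assuming the router is mounted without a prefix and the server listens on localhost:8000; a real SDP body has to come from an actual RTCPeerConnection:

import httpx

# illustrative only: "<local SDP>" stands in for a real session description
offer = {"sdp": "<local SDP from the browser or aiortc client>", "type": "offer"}
resp = httpx.post("http://localhost:8000/offer", json=offer)
answer = resp.json()  # RtcOffer-shaped reply, e.g. {"sdp": "<server SDP>", "type": "answer"}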
@@ -62,40 +63,38 @@ async def rtc_offer(params: RtcOffer, request: Request):
    peername = request.client
    clientid = f"{peername[0]}:{peername[1]}"
    ctx = TranscriptionContext(logger=logger.bind(client=clientid))
    ctx.topics = []

    # build pipeline callback
    async def on_transcript(transcript: Transcript):
        ctx.logger.info("Transcript", transcript=transcript)
        cmd = TranscriptionOutput(transcript.text)
        # FIXME: send the result to the client async way
        ctx.data_channel.send(dumps(cmd.get_result()))

    async def on_summary(summary: TitleSummary):
        ctx.logger.info("Summary", summary=summary)
        # XXX doesnt work as expected, IncrementalResult is not serializable
        # and previous implementation assume output of oobagooda
        # result = TitleSummaryOutput(
        #     [
        #         IncrementalResult(
        #             title=summary.title,
        #             desc=summary.summary,
        #             transcript=summary.transcript.text,
        #             timestamp=summary.timestamp,
        #         )
        #     ]
        # )
        result = {
            "cmd": "UPDATE_TOPICS",
            "topics": [
                {
                    "title": summary.title,
                    "timestamp": summary.timestamp,
                    "transcript": summary.transcript.text,
                    "desc": summary.summary,
                }
            ],
            "cmd": "SHOW_TRANSCRIPTION",
            "text": transcript.text,
        }
        ctx.data_channel.send(dumps(result))

    async def on_topic(summary: TitleSummary):
        # FIXME: make it incremental with the frontend, not send everything
        ctx.logger.info("Summary", summary=summary)
        ctx.topics.append(
            {
                "title": summary.title,
                "timestamp": summary.timestamp,
                "transcript": summary.transcript.text,
                "desc": summary.summary,
            }
        )
        result = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics}
        ctx.data_channel.send(dumps(result))

    async def on_final_summary(summary: FinalSummary):
        ctx.logger.info("FinalSummary", final_summary=summary)
        result = {
            "cmd": "DISPLAY_FINAL_SUMMARY",
            "summary": summary.summary,
            "duration": summary.duration,
        }
        ctx.data_channel.send(dumps(result))

    # create a context for the whole rtc transaction
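The two new callbacks define the data-channel protocol the frontend now consumes: UPDATE_TOPICS resends the full topic list each time a topic is detected, and DISPLAY_FINAL_SUMMARY carries the flushed summary. Illustrative payloads (field values are made up; the keys match the dicts above):

import json

update_topics = {
    "cmd": "UPDATE_TOPICS",
    "topics": [
        {
            "title": "Roadmap review",
            "timestamp": 42.0,
            "transcript": "So for the next quarter we plan to...",
            "desc": "The team walks through the upcoming roadmap.",
        }
    ],
}

display_final_summary = {
    "cmd": "DISPLAY_FINAL_SUMMARY",
    "summary": "Greetings and agenda.\nDiscussion of the roadmap.",
    "duration": 75.0,
}

# both are serialized with dumps() before being pushed over ctx.data_channel
wire_messages = [json.dumps(update_topics), json.dumps(display_final_summary)]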
@@ -105,15 +104,19 @@ async def rtc_offer(params: RtcOffer, request: Request):
        AudioMergeProcessor(),
        AudioTranscriptAutoProcessor.as_threaded(callback=on_transcript),
        TranscriptLinerProcessor(),
        TranscriptTopicDetectorProcessor.as_threaded(callback=on_summary),
        # FinalSummaryProcessor.as_threaded(
        #     filename=result_fn, callback=on_final_summary
        # ),
        TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
        TranscriptFinalSummaryProcessor.as_threaded(callback=on_final_summary),
    )

    # handle RTC peer connection
    pc = RTCPeerConnection()

    async def flush_pipeline_and_quit():
        ctx.logger.info("Flushing pipeline")
        await ctx.pipeline.flush()
        ctx.logger.debug("Closing peer connection")
        await pc.close()

    @pc.on("datachannel")
    def on_datachannel(channel):
        ctx.data_channel = channel
@@ -124,8 +127,8 @@ async def rtc_offer(params: RtcOffer, request: Request):
        def on_message(message: str):
            ctx.logger.info(f"Message: {message}")
            if loads(message)["cmd"] == "STOP":
                # FIXME: flush the pipeline
                pass
                ctx.logger.debug("STOP command received")
                asyncio.get_event_loop().create_task(flush_pipeline_and_quit())

            if isinstance(message, str) and message.startswith("ping"):
                channel.send("pong" + message[4:])
@@ -148,3 +151,12 @@ async def rtc_offer(params: RtcOffer, request: Request):
    sessions.append(pc)

    return RtcOffer(sdp=pc.localDescription.sdp, type=pc.localDescription.type)


@subscribers_shutdown.append
async def rtc_clean_sessions():
    logger.info("Closing all RTC sessions")
    for pc in sessions:
        logger.debug(f"Closing session {pc}")
        await pc.close()
    sessions.clear()