mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-21 04:39:06 +00:00
server: fixes latest implementation details on rtc offer and fastapi
This commit is contained in:
@@ -1,9 +1,22 @@
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from reflector.views.rtc_offer import router as rtc_offer_router
|
from reflector.views.rtc_offer import router as rtc_offer_router
|
||||||
|
from reflector.events import subscribers_startup, subscribers_shutdown
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
|
||||||
|
# lifespan events
|
||||||
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Run the registered event subscribers around the application's lifetime.

    Every callable in ``subscribers_startup`` is awaited, in list order,
    before control is yielded; every callable in ``subscribers_shutdown`` is
    awaited, in list order, once execution resumes after the yield.
    """
    for on_startup in subscribers_startup:
        await on_startup()

    yield

    for on_shutdown in subscribers_shutdown:
        await on_shutdown()
|
||||||
|
|
||||||
|
|
||||||
# build app
|
# build app
|
||||||
app = FastAPI()
|
app = FastAPI(lifespan=lifespan)
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=["*"],
|
allow_origins=["*"],
|
||||||
|
|||||||
2
server/reflector/events.py
Normal file
2
server/reflector/events.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
subscribers_startup = []
|
||||||
|
subscribers_shutdown = []
|
||||||
@@ -1,9 +1,9 @@
|
|||||||
from .base import Processor, ThreadedProcessor, Pipeline # noqa: F401
|
from .base import Processor, ThreadedProcessor, Pipeline # noqa: F401
|
||||||
from .types import AudioFile, Transcript, Word, TitleSummary # noqa: F401
|
from .types import AudioFile, Transcript, Word, TitleSummary, FinalSummary # noqa: F401
|
||||||
from .audio_chunker import AudioChunkerProcessor # noqa: F401
|
from .audio_chunker import AudioChunkerProcessor # noqa: F401
|
||||||
from .audio_merge import AudioMergeProcessor # noqa: F401
|
from .audio_merge import AudioMergeProcessor # noqa: F401
|
||||||
from .audio_transcript import AudioTranscriptProcessor # noqa: F401
|
from .audio_transcript import AudioTranscriptProcessor # noqa: F401
|
||||||
from .audio_transcript_auto import AudioTranscriptAutoProcessor # noqa: F401
|
from .audio_transcript_auto import AudioTranscriptAutoProcessor # noqa: F401
|
||||||
from .transcript_liner import TranscriptLinerProcessor # noqa: F401
|
from .transcript_liner import TranscriptLinerProcessor # noqa: F401
|
||||||
from .transcript_summarizer import TranscriptSummarizerProcessor # noqa: F401
|
|
||||||
from .transcript_topic_detector import TranscriptTopicDetectorProcessor # noqa: F401
|
from .transcript_topic_detector import TranscriptTopicDetectorProcessor # noqa: F401
|
||||||
|
from .transcript_final_summary import TranscriptFinalSummaryProcessor # noqa: F401
|
||||||
|
|||||||
30
server/reflector/processors/transcript_final_summary.py
Normal file
30
server/reflector/processors/transcript_final_summary.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
from reflector.processors.base import Processor
|
||||||
|
from reflector.processors.types import TitleSummary, FinalSummary
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptFinalSummaryProcessor(Processor):
    """
    Collect every pushed TitleSummary and emit one FinalSummary on flush.

    The emitted summary is the newline-joined text of all collected chunk
    summaries; its duration is the last chunk's timestamp plus that chunk's
    duration.
    """

    INPUT_TYPE = TitleSummary
    OUTPUT_TYPE = FinalSummary

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # chunks received so far, in arrival order
        self.chunks: list[TitleSummary] = []

    async def _push(self, data: TitleSummary):
        """Store the chunk; nothing is emitted until the flush."""
        self.chunks.append(data)

    async def _flush(self):
        """Emit the assembled FinalSummary, or only warn if nothing was pushed."""
        if not self.chunks:
            self.logger.warning("No summary to output")
            return

        # FIXME improve final summary
        tail = self.chunks[-1]
        final = FinalSummary(
            summary="\n".join(chunk.summary for chunk in self.chunks),
            duration=tail.timestamp + tail.duration,
        )
        await self.emit(final)
|
||||||
@@ -1,30 +0,0 @@
|
|||||||
from reflector.processors.base import Processor
|
|
||||||
from reflector.processors.types import TitleSummary
|
|
||||||
from pathlib import Path
|
|
||||||
import json
|
|
||||||
|
|
||||||
|
|
||||||
class TranscriptSummarizerProcessor(Processor):
|
|
||||||
"""
|
|
||||||
Assemble all summary into a line-based json
|
|
||||||
"""
|
|
||||||
|
|
||||||
INPUT_TYPE = TitleSummary
|
|
||||||
OUTPUT_TYPE = Path
|
|
||||||
|
|
||||||
def __init__(self, filename: Path, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.filename = filename
|
|
||||||
self.chunkcount = 0
|
|
||||||
|
|
||||||
async def _push(self, data: TitleSummary):
|
|
||||||
with open(self.filename, "a", encoding="utf8") as fd:
|
|
||||||
fd.write(json.dumps(data))
|
|
||||||
self.chunkcount += 1
|
|
||||||
|
|
||||||
async def _flush(self):
|
|
||||||
if self.chunkcount == 0:
|
|
||||||
self.logger.warning("No summary to write")
|
|
||||||
return
|
|
||||||
self.logger.info(f"Writing to {self.filename}")
|
|
||||||
await self.emit(self.filename)
|
|
||||||
@@ -66,3 +66,9 @@ class TitleSummary:
|
|||||||
timestamp: float
|
timestamp: float
|
||||||
duration: float
|
duration: float
|
||||||
transcript: Transcript
|
transcript: Transcript
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class FinalSummary:
    """Result record holding a summary text and a duration."""

    # text of the summary
    summary: str
    # duration value attached to the summary
    duration: float
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
|
import asyncio
|
||||||
from fastapi import Request, APIRouter
|
from fastapi import Request, APIRouter
|
||||||
|
from reflector.events import subscribers_shutdown
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from reflector.models import (
|
from reflector.models import (
|
||||||
TranscriptionContext,
|
TranscriptionContext,
|
||||||
TranscriptionOutput,
|
TranscriptionOutput,
|
||||||
TitleSummaryOutput,
|
|
||||||
IncrementalResult,
|
|
||||||
)
|
)
|
||||||
from reflector.logger import logger
|
from reflector.logger import logger
|
||||||
from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack
|
from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack
|
||||||
@@ -17,10 +17,15 @@ from reflector.processors import (
|
|||||||
AudioTranscriptAutoProcessor,
|
AudioTranscriptAutoProcessor,
|
||||||
TranscriptLinerProcessor,
|
TranscriptLinerProcessor,
|
||||||
TranscriptTopicDetectorProcessor,
|
TranscriptTopicDetectorProcessor,
|
||||||
|
TranscriptFinalSummaryProcessor,
|
||||||
Transcript,
|
Transcript,
|
||||||
TitleSummary,
|
TitleSummary,
|
||||||
|
FinalSummary,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
sessions = []
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
class AudioStreamTrack(MediaStreamTrack):
|
class AudioStreamTrack(MediaStreamTrack):
|
||||||
"""
|
"""
|
||||||
@@ -49,10 +54,6 @@ class RtcOffer(BaseModel):
|
|||||||
type: str
|
type: str
|
||||||
|
|
||||||
|
|
||||||
sessions = []
|
|
||||||
router = APIRouter()
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/offer")
|
@router.post("/offer")
|
||||||
async def rtc_offer(params: RtcOffer, request: Request):
|
async def rtc_offer(params: RtcOffer, request: Request):
|
||||||
# build an rtc session
|
# build an rtc session
|
||||||
@@ -62,40 +63,38 @@ async def rtc_offer(params: RtcOffer, request: Request):
|
|||||||
peername = request.client
|
peername = request.client
|
||||||
clientid = f"{peername[0]}:{peername[1]}"
|
clientid = f"{peername[0]}:{peername[1]}"
|
||||||
ctx = TranscriptionContext(logger=logger.bind(client=clientid))
|
ctx = TranscriptionContext(logger=logger.bind(client=clientid))
|
||||||
|
ctx.topics = []
|
||||||
|
|
||||||
# build pipeline callback
|
# build pipeline callback
|
||||||
async def on_transcript(transcript: Transcript):
|
async def on_transcript(transcript: Transcript):
|
||||||
ctx.logger.info("Transcript", transcript=transcript)
|
ctx.logger.info("Transcript", transcript=transcript)
|
||||||
cmd = TranscriptionOutput(transcript.text)
|
|
||||||
# FIXME: send the result to the client async way
|
|
||||||
ctx.data_channel.send(dumps(cmd.get_result()))
|
|
||||||
|
|
||||||
async def on_summary(summary: TitleSummary):
|
|
||||||
ctx.logger.info("Summary", summary=summary)
|
|
||||||
# XXX doesnt work as expected, IncrementalResult is not serializable
|
|
||||||
# and previous implementation assume output of oobagooda
|
|
||||||
# result = TitleSummaryOutput(
|
|
||||||
# [
|
|
||||||
# IncrementalResult(
|
|
||||||
# title=summary.title,
|
|
||||||
# desc=summary.summary,
|
|
||||||
# transcript=summary.transcript.text,
|
|
||||||
# timestamp=summary.timestamp,
|
|
||||||
# )
|
|
||||||
# ]
|
|
||||||
# )
|
|
||||||
result = {
|
result = {
|
||||||
"cmd": "UPDATE_TOPICS",
|
"cmd": "SHOW_TRANSCRIPTION",
|
||||||
"topics": [
|
"text": transcript.text,
|
||||||
|
}
|
||||||
|
ctx.data_channel.send(dumps(result))
|
||||||
|
|
||||||
|
async def on_topic(summary: TitleSummary):
|
||||||
|
# FIXME: make it incremental with the frontend, not send everything
|
||||||
|
ctx.logger.info("Summary", summary=summary)
|
||||||
|
ctx.topics.append(
|
||||||
{
|
{
|
||||||
"title": summary.title,
|
"title": summary.title,
|
||||||
"timestamp": summary.timestamp,
|
"timestamp": summary.timestamp,
|
||||||
"transcript": summary.transcript.text,
|
"transcript": summary.transcript.text,
|
||||||
"desc": summary.summary,
|
"desc": summary.summary,
|
||||||
}
|
}
|
||||||
],
|
)
|
||||||
}
|
result = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics}
|
||||||
|
ctx.data_channel.send(dumps(result))
|
||||||
|
|
||||||
|
async def on_final_summary(summary: FinalSummary):
|
||||||
|
ctx.logger.info("FinalSummary", final_summary=summary)
|
||||||
|
result = {
|
||||||
|
"cmd": "DISPLAY_FINAL_SUMMARY",
|
||||||
|
"summary": summary.summary,
|
||||||
|
"duration": summary.duration,
|
||||||
|
}
|
||||||
ctx.data_channel.send(dumps(result))
|
ctx.data_channel.send(dumps(result))
|
||||||
|
|
||||||
# create a context for the whole rtc transaction
|
# create a context for the whole rtc transaction
|
||||||
@@ -105,15 +104,19 @@ async def rtc_offer(params: RtcOffer, request: Request):
|
|||||||
AudioMergeProcessor(),
|
AudioMergeProcessor(),
|
||||||
AudioTranscriptAutoProcessor.as_threaded(callback=on_transcript),
|
AudioTranscriptAutoProcessor.as_threaded(callback=on_transcript),
|
||||||
TranscriptLinerProcessor(),
|
TranscriptLinerProcessor(),
|
||||||
TranscriptTopicDetectorProcessor.as_threaded(callback=on_summary),
|
TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
|
||||||
# FinalSummaryProcessor.as_threaded(
|
TranscriptFinalSummaryProcessor.as_threaded(callback=on_final_summary),
|
||||||
# filename=result_fn, callback=on_final_summary
|
|
||||||
# ),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# handle RTC peer connection
|
# handle RTC peer connection
|
||||||
pc = RTCPeerConnection()
|
pc = RTCPeerConnection()
|
||||||
|
|
||||||
|
async def flush_pipeline_and_quit():
|
||||||
|
ctx.logger.info("Flushing pipeline")
|
||||||
|
await ctx.pipeline.flush()
|
||||||
|
ctx.logger.debug("Closing peer connection")
|
||||||
|
await pc.close()
|
||||||
|
|
||||||
@pc.on("datachannel")
|
@pc.on("datachannel")
|
||||||
def on_datachannel(channel):
|
def on_datachannel(channel):
|
||||||
ctx.data_channel = channel
|
ctx.data_channel = channel
|
||||||
@@ -124,8 +127,8 @@ async def rtc_offer(params: RtcOffer, request: Request):
|
|||||||
def on_message(message: str):
|
def on_message(message: str):
|
||||||
ctx.logger.info(f"Message: {message}")
|
ctx.logger.info(f"Message: {message}")
|
||||||
if loads(message)["cmd"] == "STOP":
|
if loads(message)["cmd"] == "STOP":
|
||||||
# FIXME: flush the pipeline
|
ctx.logger.debug("STOP command received")
|
||||||
pass
|
asyncio.get_event_loop().create_task(flush_pipeline_and_quit())
|
||||||
|
|
||||||
if isinstance(message, str) and message.startswith("ping"):
|
if isinstance(message, str) and message.startswith("ping"):
|
||||||
channel.send("pong" + message[4:])
|
channel.send("pong" + message[4:])
|
||||||
@@ -148,3 +151,12 @@ async def rtc_offer(params: RtcOffer, request: Request):
|
|||||||
sessions.append(pc)
|
sessions.append(pc)
|
||||||
|
|
||||||
return RtcOffer(sdp=pc.localDescription.sdp, type=pc.localDescription.type)
|
return RtcOffer(sdp=pc.localDescription.sdp, type=pc.localDescription.type)
|
||||||
|
|
||||||
|
|
||||||
|
async def rtc_clean_sessions():
    """
    Close every RTC peer connection tracked in ``sessions``, then empty it.

    Each connection is closed in list order with ``await pc.close()``.
    """
    logger.info("Closing all RTC sessions")
    for pc in sessions:
        logger.debug(f"Closing session {pc}")
        await pc.close()
    sessions.clear()


# Register explicitly instead of decorating with ``@subscribers_shutdown.append``:
# ``list.append`` returns None, so the decorator form would rebind the
# module-level name ``rtc_clean_sessions`` to None.
subscribers_shutdown.append(rtc_clean_sessions)
|
||||||
|
|||||||
Reference in New Issue
Block a user