server: refactor with clearer pipeline instantiation and linked to model

This commit is contained in:
2023-10-26 19:00:56 +02:00
committed by Mathieu Virbel
parent 433c0500cc
commit 1c42473da0
8 changed files with 658 additions and 616 deletions

View File

@@ -1,8 +1,6 @@
import asyncio
from enum import StrEnum
from json import dumps, loads
from pathlib import Path
from typing import Callable
from json import loads
import av
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
@@ -11,25 +9,7 @@ from prometheus_client import Gauge
from pydantic import BaseModel
from reflector.events import subscribers_shutdown
from reflector.logger import logger
from reflector.processors import (
AudioChunkerProcessor,
AudioFileWriterProcessor,
AudioMergeProcessor,
AudioTranscriptAutoProcessor,
FinalLongSummary,
FinalShortSummary,
Pipeline,
TitleSummary,
Transcript,
TranscriptFinalLongSummaryProcessor,
TranscriptFinalShortSummaryProcessor,
TranscriptFinalTitleProcessor,
TranscriptLinerProcessor,
TranscriptTopicDetectorProcessor,
TranscriptTranslatorProcessor,
)
from reflector.processors.base import BroadcastProcessor
from reflector.processors.types import FinalTitle
from reflector.pipelines.runner import PipelineRunner
sessions = []
router = APIRouter()
@@ -85,121 +65,10 @@ class PipelineEvent(StrEnum):
FINAL_TITLE = "FINAL_TITLE"
class PipelineOptions(BaseModel):
    """
    Settings read by `pipeline_live_create` when it assembles a live pipeline.

    Every field is optional and falls back to the default shown below.
    """

    # destination of the audio file writer; None leaves the writer out
    audio_filename: Path | None = None

    # stored on the pipeline as the "audio:source_language" and
    # "audio:target_language" preferences
    source_language: str = "en"
    target_language: str = "en"

    # callbacks handed to the matching processors of the pipeline
    on_transcript: Callable | None = None
    on_topic: Callable | None = None
    on_final_title: Callable | None = None
    on_final_short_summary: Callable | None = None
    on_final_long_summary: Callable | None = None
class PipelineRunner(object):
"""
Pipeline runner designed to be executed in a asyncio task
"""
def __init__(self, pipeline: Pipeline, status_callback: Callable | None = None):
self.pipeline = pipeline
self.q_cmd = asyncio.Queue()
self.ev_done = asyncio.Event()
self.status = "idle"
self.status_callback = status_callback
async def update_status(self, status):
print("update_status", status)
self.status = status
if self.status_callback:
try:
await self.status_callback(status)
except Exception as e:
logger.error("PipelineRunner status_callback error", error=e)
async def add_cmd(self, cmd: str, data):
await self.q_cmd.put([cmd, data])
async def push(self, data):
await self.add_cmd("PUSH", data)
async def flush(self):
await self.add_cmd("FLUSH", None)
async def run(self):
try:
await self.update_status("running")
while not self.ev_done.is_set():
cmd, data = await self.q_cmd.get()
func = getattr(self, f"cmd_{cmd.lower()}")
if func:
await func(data)
else:
raise Exception(f"Unknown command {cmd}")
except Exception as e:
await self.update_status("error")
logger.error("PipelineRunner error", error=e)
async def cmd_push(self, data):
if self.status == "idle":
await self.update_status("recording")
await self.pipeline.push(data)
async def cmd_flush(self, data):
await self.update_status("processing")
await self.pipeline.flush()
await self.update_status("ended")
self.ev_done.set()
def start(self):
print("start task")
asyncio.get_event_loop().create_task(self.run())
async def pipeline_live_create(options: PipelineOptions):
    """
    Assemble the live pipeline described by `options` and return it.

    Processors are handed to `Pipeline` in this order: optional audio file
    writer, chunker, merger, transcription, liner, translation, topic
    detection, and finally a broadcast to the title / long summary / short
    summary processors. The callbacks found in `options` are attached to the
    matching processors.
    """
    # the file writer only exists when a destination file was requested
    if options.audio_filename is None:
        stages = []
    else:
        stages = [AudioFileWriterProcessor(path=options.audio_filename)]

    stages.extend(
        [
            AudioChunkerProcessor(),
            AudioMergeProcessor(),
            AudioTranscriptAutoProcessor.as_threaded(),
            TranscriptLinerProcessor(),
            TranscriptTranslatorProcessor.as_threaded(callback=options.on_transcript),
            TranscriptTopicDetectorProcessor.as_threaded(callback=options.on_topic),
            BroadcastProcessor(
                processors=[
                    TranscriptFinalTitleProcessor.as_threaded(
                        callback=options.on_final_title
                    ),
                    TranscriptFinalLongSummaryProcessor.as_threaded(
                        callback=options.on_final_long_summary
                    ),
                    TranscriptFinalShortSummaryProcessor.as_threaded(
                        callback=options.on_final_short_summary
                    ),
                ]
            ),
        ]
    )

    live_pipeline = Pipeline(*stages)
    live_pipeline.options = options

    # language settings travel with the pipeline as preferences
    live_pipeline.set_pref("audio:source_language", options.source_language)
    live_pipeline.set_pref("audio:target_language", options.target_language)
    return live_pipeline
async def rtc_offer_base(
params: RtcOffer,
request: Request,
event_callback=None,
event_callback_args=None,
audio_filename: Path | None = None,
source_language: str = "en",
target_language: str = "en",
pipeline_runner: PipelineRunner,
):
# build an rtc session
offer = RTCSessionDescription(sdp=params.sdp, type=params.type)
@@ -209,132 +78,9 @@ async def rtc_offer_base(
clientid = f"{peername[0]}:{peername[1]}"
ctx = TranscriptionContext(logger=logger.bind(client=clientid))
async def update_status(status: str):
    """
    Store `status` on the context and emit a STATUS event when it changed.
    """
    if ctx.status == status:
        return

    ctx.status = status

    # NOTE(review): the event callback is treated as part of the "status
    # changed" branch - confirm against the original nesting
    if event_callback:
        await event_callback(
            event=PipelineEvent.STATUS,
            args=event_callback_args,
            data=StrValue(value=status),
        )
# build pipeline callback
async def on_transcript(transcript: Transcript):
    """
    Pipeline callback: relay a transcript to the RTC data channel (when it is
    open) and to the optional event callback.
    """
    ctx.logger.info("Transcript", transcript=transcript)

    # RTC side
    if ctx.data_channel.readyState == "open":
        payload = dumps(
            {
                "cmd": "SHOW_TRANSCRIPTION",
                "text": transcript.text,
            }
        )
        ctx.data_channel.send(payload)

    # callback side (eg. websocket)
    if not event_callback:
        return
    await event_callback(
        event=PipelineEvent.TRANSCRIPT,
        args=event_callback_args,
        data=transcript,
    )
async def on_topic(topic: TitleSummary):
    """
    Pipeline callback: record a topic on the context, send the full topic
    list to the RTC data channel (when it is open) and forward the topic to
    the optional event callback.
    """
    # FIXME: make it incremental with the frontend, not send everything
    ctx.logger.info("Topic", topic=topic)

    entry = {
        "title": topic.title,
        "timestamp": topic.timestamp,
        "transcript": topic.transcript.text,
        "desc": topic.summary,
    }
    ctx.topics.append(entry)

    # RTC side: the whole list is sent each time, not only the new entry
    if ctx.data_channel.readyState == "open":
        payload = dumps({"cmd": "UPDATE_TOPICS", "topics": ctx.topics})
        ctx.data_channel.send(payload)

    # callback side (eg. websocket)
    if not event_callback:
        return
    await event_callback(
        event=PipelineEvent.TOPIC, args=event_callback_args, data=topic
    )
async def on_final_short_summary(summary: FinalShortSummary):
    """
    Pipeline callback: relay the final short summary to the RTC data channel
    (when it is open) and to the optional event callback.
    """
    ctx.logger.info("FinalShortSummary", final_short_summary=summary)

    # RTC side
    if ctx.data_channel.readyState == "open":
        payload = dumps(
            {
                "cmd": "DISPLAY_FINAL_SHORT_SUMMARY",
                "summary": summary.short_summary,
                "duration": summary.duration,
            }
        )
        ctx.data_channel.send(payload)

    # callback side (eg. websocket)
    if not event_callback:
        return
    await event_callback(
        event=PipelineEvent.FINAL_SHORT_SUMMARY,
        args=event_callback_args,
        data=summary,
    )
async def on_final_long_summary(summary: FinalLongSummary):
    """
    Pipeline callback: relay the final long summary to the RTC data channel
    (when it is open) and to the optional event callback.
    """
    ctx.logger.info("FinalLongSummary", final_summary=summary)

    # RTC side
    if ctx.data_channel.readyState == "open":
        payload = dumps(
            {
                "cmd": "DISPLAY_FINAL_LONG_SUMMARY",
                "summary": summary.long_summary,
                "duration": summary.duration,
            }
        )
        ctx.data_channel.send(payload)

    # callback side (eg. websocket)
    if not event_callback:
        return
    await event_callback(
        event=PipelineEvent.FINAL_LONG_SUMMARY,
        args=event_callback_args,
        data=summary,
    )
async def on_final_title(title: FinalTitle):
    """
    Pipeline callback: relay the final title to the RTC data channel (when it
    is open) and to the optional event callback.
    """
    ctx.logger.info("FinalTitle", final_title=title)

    # RTC side
    if ctx.data_channel.readyState == "open":
        payload = dumps({"cmd": "DISPLAY_FINAL_TITLE", "title": title.title})
        ctx.data_channel.send(payload)

    # callback side (eg. websocket)
    if not event_callback:
        return
    await event_callback(
        event=PipelineEvent.FINAL_TITLE,
        args=event_callback_args,
        data=title,
    )
# handle RTC peer connection
pc = RTCPeerConnection()
# create pipeline
options = PipelineOptions(
audio_filename=audio_filename,
source_language=source_language,
target_language=target_language,
on_transcript=on_transcript,
on_topic=on_topic,
on_final_short_summary=on_final_short_summary,
on_final_long_summary=on_final_long_summary,
on_final_title=on_final_title,
)
pipeline = await pipeline_live_create(options)
ctx.pipeline_runner = PipelineRunner(pipeline, update_status)
ctx.pipeline_runner = pipeline_runner
ctx.pipeline_runner.start()
async def flush_pipeline_and_quit(close=True):
@@ -400,8 +146,3 @@ async def rtc_clean_sessions(_):
logger.debug(f"Closing session {pc}")
await pc.close()
sessions.clear()
@router.post("/offer")
async def rtc_offer(params: RtcOffer, request: Request):
    """
    POST /offer: hand the offer and the request over to `rtc_offer_base` and
    return whatever it produces.
    """
    answer = await rtc_offer_base(params, request)
    return answer