mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-22 13:19:05 +00:00
server: refactor with clearer pipeline instanciation and linked to model
This commit is contained in:
@@ -1,8 +1,6 @@
|
||||
import asyncio
|
||||
from enum import StrEnum
|
||||
from json import dumps, loads
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
from json import loads
|
||||
|
||||
import av
|
||||
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
|
||||
@@ -11,25 +9,7 @@ from prometheus_client import Gauge
|
||||
from pydantic import BaseModel
|
||||
from reflector.events import subscribers_shutdown
|
||||
from reflector.logger import logger
|
||||
from reflector.processors import (
|
||||
AudioChunkerProcessor,
|
||||
AudioFileWriterProcessor,
|
||||
AudioMergeProcessor,
|
||||
AudioTranscriptAutoProcessor,
|
||||
FinalLongSummary,
|
||||
FinalShortSummary,
|
||||
Pipeline,
|
||||
TitleSummary,
|
||||
Transcript,
|
||||
TranscriptFinalLongSummaryProcessor,
|
||||
TranscriptFinalShortSummaryProcessor,
|
||||
TranscriptFinalTitleProcessor,
|
||||
TranscriptLinerProcessor,
|
||||
TranscriptTopicDetectorProcessor,
|
||||
TranscriptTranslatorProcessor,
|
||||
)
|
||||
from reflector.processors.base import BroadcastProcessor
|
||||
from reflector.processors.types import FinalTitle
|
||||
from reflector.pipelines.runner import PipelineRunner
|
||||
|
||||
sessions = []
|
||||
router = APIRouter()
|
||||
@@ -85,121 +65,10 @@ class PipelineEvent(StrEnum):
|
||||
FINAL_TITLE = "FINAL_TITLE"
|
||||
|
||||
|
||||
class PipelineOptions(BaseModel):
|
||||
audio_filename: Path | None = None
|
||||
source_language: str = "en"
|
||||
target_language: str = "en"
|
||||
|
||||
on_transcript: Callable | None = None
|
||||
on_topic: Callable | None = None
|
||||
on_final_title: Callable | None = None
|
||||
on_final_short_summary: Callable | None = None
|
||||
on_final_long_summary: Callable | None = None
|
||||
|
||||
|
||||
class PipelineRunner(object):
|
||||
"""
|
||||
Pipeline runner designed to be executed in a asyncio task
|
||||
"""
|
||||
|
||||
def __init__(self, pipeline: Pipeline, status_callback: Callable | None = None):
|
||||
self.pipeline = pipeline
|
||||
self.q_cmd = asyncio.Queue()
|
||||
self.ev_done = asyncio.Event()
|
||||
self.status = "idle"
|
||||
self.status_callback = status_callback
|
||||
|
||||
async def update_status(self, status):
|
||||
print("update_status", status)
|
||||
self.status = status
|
||||
if self.status_callback:
|
||||
try:
|
||||
await self.status_callback(status)
|
||||
except Exception as e:
|
||||
logger.error("PipelineRunner status_callback error", error=e)
|
||||
|
||||
async def add_cmd(self, cmd: str, data):
|
||||
await self.q_cmd.put([cmd, data])
|
||||
|
||||
async def push(self, data):
|
||||
await self.add_cmd("PUSH", data)
|
||||
|
||||
async def flush(self):
|
||||
await self.add_cmd("FLUSH", None)
|
||||
|
||||
async def run(self):
|
||||
try:
|
||||
await self.update_status("running")
|
||||
while not self.ev_done.is_set():
|
||||
cmd, data = await self.q_cmd.get()
|
||||
func = getattr(self, f"cmd_{cmd.lower()}")
|
||||
if func:
|
||||
await func(data)
|
||||
else:
|
||||
raise Exception(f"Unknown command {cmd}")
|
||||
except Exception as e:
|
||||
await self.update_status("error")
|
||||
logger.error("PipelineRunner error", error=e)
|
||||
|
||||
async def cmd_push(self, data):
|
||||
if self.status == "idle":
|
||||
await self.update_status("recording")
|
||||
await self.pipeline.push(data)
|
||||
|
||||
async def cmd_flush(self, data):
|
||||
await self.update_status("processing")
|
||||
await self.pipeline.flush()
|
||||
await self.update_status("ended")
|
||||
self.ev_done.set()
|
||||
|
||||
def start(self):
|
||||
print("start task")
|
||||
asyncio.get_event_loop().create_task(self.run())
|
||||
|
||||
|
||||
async def pipeline_live_create(options: PipelineOptions):
|
||||
# create a context for the whole rtc transaction
|
||||
# add a customised logger to the context
|
||||
processors = []
|
||||
if options.audio_filename is not None:
|
||||
processors += [AudioFileWriterProcessor(path=options.audio_filename)]
|
||||
processors += [
|
||||
AudioChunkerProcessor(),
|
||||
AudioMergeProcessor(),
|
||||
AudioTranscriptAutoProcessor.as_threaded(),
|
||||
TranscriptLinerProcessor(),
|
||||
TranscriptTranslatorProcessor.as_threaded(callback=options.on_transcript),
|
||||
TranscriptTopicDetectorProcessor.as_threaded(callback=options.on_topic),
|
||||
BroadcastProcessor(
|
||||
processors=[
|
||||
TranscriptFinalTitleProcessor.as_threaded(
|
||||
callback=options.on_final_title
|
||||
),
|
||||
TranscriptFinalLongSummaryProcessor.as_threaded(
|
||||
callback=options.on_final_long_summary
|
||||
),
|
||||
TranscriptFinalShortSummaryProcessor.as_threaded(
|
||||
callback=options.on_final_short_summary
|
||||
),
|
||||
]
|
||||
),
|
||||
]
|
||||
pipeline = Pipeline(*processors)
|
||||
pipeline.options = options
|
||||
pipeline.set_pref("audio:source_language", options.source_language)
|
||||
pipeline.set_pref("audio:target_language", options.target_language)
|
||||
|
||||
return pipeline
|
||||
|
||||
|
||||
async def rtc_offer_base(
|
||||
params: RtcOffer,
|
||||
request: Request,
|
||||
event_callback=None,
|
||||
event_callback_args=None,
|
||||
audio_filename: Path | None = None,
|
||||
source_language: str = "en",
|
||||
target_language: str = "en",
|
||||
pipeline_runner: PipelineRunner,
|
||||
):
|
||||
# build an rtc session
|
||||
offer = RTCSessionDescription(sdp=params.sdp, type=params.type)
|
||||
@@ -209,132 +78,9 @@ async def rtc_offer_base(
|
||||
clientid = f"{peername[0]}:{peername[1]}"
|
||||
ctx = TranscriptionContext(logger=logger.bind(client=clientid))
|
||||
|
||||
async def update_status(status: str):
|
||||
changed = ctx.status != status
|
||||
if changed:
|
||||
ctx.status = status
|
||||
if event_callback:
|
||||
await event_callback(
|
||||
event=PipelineEvent.STATUS,
|
||||
args=event_callback_args,
|
||||
data=StrValue(value=status),
|
||||
)
|
||||
|
||||
# build pipeline callback
|
||||
async def on_transcript(transcript: Transcript):
|
||||
ctx.logger.info("Transcript", transcript=transcript)
|
||||
|
||||
# send to RTC
|
||||
if ctx.data_channel.readyState == "open":
|
||||
result = {
|
||||
"cmd": "SHOW_TRANSCRIPTION",
|
||||
"text": transcript.text,
|
||||
}
|
||||
ctx.data_channel.send(dumps(result))
|
||||
|
||||
# send to callback (eg. websocket)
|
||||
if event_callback:
|
||||
await event_callback(
|
||||
event=PipelineEvent.TRANSCRIPT,
|
||||
args=event_callback_args,
|
||||
data=transcript,
|
||||
)
|
||||
|
||||
async def on_topic(topic: TitleSummary):
|
||||
# FIXME: make it incremental with the frontend, not send everything
|
||||
ctx.logger.info("Topic", topic=topic)
|
||||
ctx.topics.append(
|
||||
{
|
||||
"title": topic.title,
|
||||
"timestamp": topic.timestamp,
|
||||
"transcript": topic.transcript.text,
|
||||
"desc": topic.summary,
|
||||
}
|
||||
)
|
||||
|
||||
# send to RTC
|
||||
if ctx.data_channel.readyState == "open":
|
||||
result = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics}
|
||||
ctx.data_channel.send(dumps(result))
|
||||
|
||||
# send to callback (eg. websocket)
|
||||
if event_callback:
|
||||
await event_callback(
|
||||
event=PipelineEvent.TOPIC, args=event_callback_args, data=topic
|
||||
)
|
||||
|
||||
async def on_final_short_summary(summary: FinalShortSummary):
|
||||
ctx.logger.info("FinalShortSummary", final_short_summary=summary)
|
||||
|
||||
# send to RTC
|
||||
if ctx.data_channel.readyState == "open":
|
||||
result = {
|
||||
"cmd": "DISPLAY_FINAL_SHORT_SUMMARY",
|
||||
"summary": summary.short_summary,
|
||||
"duration": summary.duration,
|
||||
}
|
||||
ctx.data_channel.send(dumps(result))
|
||||
|
||||
# send to callback (eg. websocket)
|
||||
if event_callback:
|
||||
await event_callback(
|
||||
event=PipelineEvent.FINAL_SHORT_SUMMARY,
|
||||
args=event_callback_args,
|
||||
data=summary,
|
||||
)
|
||||
|
||||
async def on_final_long_summary(summary: FinalLongSummary):
|
||||
ctx.logger.info("FinalLongSummary", final_summary=summary)
|
||||
|
||||
# send to RTC
|
||||
if ctx.data_channel.readyState == "open":
|
||||
result = {
|
||||
"cmd": "DISPLAY_FINAL_LONG_SUMMARY",
|
||||
"summary": summary.long_summary,
|
||||
"duration": summary.duration,
|
||||
}
|
||||
ctx.data_channel.send(dumps(result))
|
||||
|
||||
# send to callback (eg. websocket)
|
||||
if event_callback:
|
||||
await event_callback(
|
||||
event=PipelineEvent.FINAL_LONG_SUMMARY,
|
||||
args=event_callback_args,
|
||||
data=summary,
|
||||
)
|
||||
|
||||
async def on_final_title(title: FinalTitle):
|
||||
ctx.logger.info("FinalTitle", final_title=title)
|
||||
|
||||
# send to RTC
|
||||
if ctx.data_channel.readyState == "open":
|
||||
result = {"cmd": "DISPLAY_FINAL_TITLE", "title": title.title}
|
||||
ctx.data_channel.send(dumps(result))
|
||||
|
||||
# send to callback (eg. websocket)
|
||||
if event_callback:
|
||||
await event_callback(
|
||||
event=PipelineEvent.FINAL_TITLE,
|
||||
args=event_callback_args,
|
||||
data=title,
|
||||
)
|
||||
|
||||
# handle RTC peer connection
|
||||
pc = RTCPeerConnection()
|
||||
|
||||
# create pipeline
|
||||
options = PipelineOptions(
|
||||
audio_filename=audio_filename,
|
||||
source_language=source_language,
|
||||
target_language=target_language,
|
||||
on_transcript=on_transcript,
|
||||
on_topic=on_topic,
|
||||
on_final_short_summary=on_final_short_summary,
|
||||
on_final_long_summary=on_final_long_summary,
|
||||
on_final_title=on_final_title,
|
||||
)
|
||||
pipeline = await pipeline_live_create(options)
|
||||
ctx.pipeline_runner = PipelineRunner(pipeline, update_status)
|
||||
ctx.pipeline_runner = pipeline_runner
|
||||
ctx.pipeline_runner.start()
|
||||
|
||||
async def flush_pipeline_and_quit(close=True):
|
||||
@@ -400,8 +146,3 @@ async def rtc_clean_sessions(_):
|
||||
logger.debug(f"Closing session {pc}")
|
||||
await pc.close()
|
||||
sessions.clear()
|
||||
|
||||
|
||||
@router.post("/offer")
|
||||
async def rtc_offer(params: RtcOffer, request: Request):
|
||||
return await rtc_offer_base(params, request)
|
||||
|
||||
Reference in New Issue
Block a user