mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
server: refactor to separate websocket management + start pipeline runner
This commit is contained in:
@@ -2,6 +2,7 @@ import asyncio
|
|||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from json import dumps, loads
|
from json import dumps, loads
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
import av
|
import av
|
||||||
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
|
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
|
||||||
@@ -38,7 +39,7 @@ m_rtc_sessions = Gauge("rtc_sessions", "Number of active RTC sessions")
|
|||||||
class TranscriptionContext(object):
|
class TranscriptionContext(object):
|
||||||
def __init__(self, logger):
|
def __init__(self, logger):
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.pipeline = None
|
self.pipeline_runner = None
|
||||||
self.data_channel = None
|
self.data_channel = None
|
||||||
self.status = "idle"
|
self.status = "idle"
|
||||||
self.topics = []
|
self.topics = []
|
||||||
@@ -60,7 +61,7 @@ class AudioStreamTrack(MediaStreamTrack):
|
|||||||
ctx = self.ctx
|
ctx = self.ctx
|
||||||
frame = await self.track.recv()
|
frame = await self.track.recv()
|
||||||
try:
|
try:
|
||||||
await ctx.pipeline.push(frame)
|
await ctx.pipeline_runner.push(frame)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
ctx.logger.error("Pipeline error", error=e)
|
ctx.logger.error("Pipeline error", error=e)
|
||||||
return frame
|
return frame
|
||||||
@@ -84,6 +85,113 @@ class PipelineEvent(StrEnum):
|
|||||||
FINAL_TITLE = "FINAL_TITLE"
|
FINAL_TITLE = "FINAL_TITLE"
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineOptions(BaseModel):
    """
    Settings consumed by pipeline_live_create() to build a live pipeline.

    Every field has a default, so an empty PipelineOptions() is valid.
    """

    # When not None, pipeline_live_create() prepends an
    # AudioFileWriterProcessor writing to this path.
    audio_filename: Path | None = None

    # Stored on the pipeline as the "audio:source_language" and
    # "audio:target_language" preferences.
    source_language: str = "en"
    target_language: str = "en"

    # Optional callbacks, each handed to the matching processor as its
    # `callback` argument when the pipeline is built.
    on_transcript: Callable | None = None
    on_topic: Callable | None = None
    on_final_title: Callable | None = None
    on_final_short_summary: Callable | None = None
    on_final_long_summary: Callable | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineRunner(object):
|
||||||
|
"""
|
||||||
|
Pipeline runner designed to be executed in a asyncio task
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, pipeline: Pipeline, status_callback: Callable | None = None):
|
||||||
|
self.pipeline = pipeline
|
||||||
|
self.q_cmd = asyncio.Queue()
|
||||||
|
self.ev_done = asyncio.Event()
|
||||||
|
self.status = "idle"
|
||||||
|
self.status_callback = status_callback
|
||||||
|
|
||||||
|
async def update_status(self, status):
|
||||||
|
print("update_status", status)
|
||||||
|
self.status = status
|
||||||
|
if self.status_callback:
|
||||||
|
try:
|
||||||
|
await self.status_callback(status)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("PipelineRunner status_callback error", error=e)
|
||||||
|
|
||||||
|
async def add_cmd(self, cmd: str, data):
|
||||||
|
await self.q_cmd.put([cmd, data])
|
||||||
|
|
||||||
|
async def push(self, data):
|
||||||
|
await self.add_cmd("PUSH", data)
|
||||||
|
|
||||||
|
async def flush(self):
|
||||||
|
await self.add_cmd("FLUSH", None)
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
try:
|
||||||
|
await self.update_status("running")
|
||||||
|
while not self.ev_done.is_set():
|
||||||
|
cmd, data = await self.q_cmd.get()
|
||||||
|
func = getattr(self, f"cmd_{cmd.lower()}")
|
||||||
|
if func:
|
||||||
|
await func(data)
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unknown command {cmd}")
|
||||||
|
except Exception as e:
|
||||||
|
await self.update_status("error")
|
||||||
|
logger.error("PipelineRunner error", error=e)
|
||||||
|
|
||||||
|
async def cmd_push(self, data):
|
||||||
|
if self.status == "idle":
|
||||||
|
await self.update_status("recording")
|
||||||
|
await self.pipeline.push(data)
|
||||||
|
|
||||||
|
async def cmd_flush(self, data):
|
||||||
|
await self.update_status("processing")
|
||||||
|
await self.pipeline.flush()
|
||||||
|
await self.update_status("ended")
|
||||||
|
self.ev_done.set()
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
print("start task")
|
||||||
|
asyncio.get_event_loop().create_task(self.run())
|
||||||
|
|
||||||
|
|
||||||
|
async def pipeline_live_create(options: PipelineOptions):
    """
    Assemble the live processing pipeline described by ``options``.

    The chain is: optional audio file writer, chunker, merger, transcription,
    liner, translator, topic detector, then a broadcast to the final title,
    long summary and short summary processors. The callbacks carried by
    ``options`` are handed to the corresponding processors.

    Returns the Pipeline, with ``options`` attached as ``pipeline.options``
    and the source/target languages stored as preferences.
    """
    stages = []

    # Raw audio is only written out when a destination was supplied.
    if options.audio_filename is not None:
        stages.append(AudioFileWriterProcessor(path=options.audio_filename))

    stages.extend(
        [
            AudioChunkerProcessor(),
            AudioMergeProcessor(),
            AudioTranscriptAutoProcessor.as_threaded(),
            TranscriptLinerProcessor(),
            TranscriptTranslatorProcessor.as_threaded(callback=options.on_transcript),
            TranscriptTopicDetectorProcessor.as_threaded(callback=options.on_topic),
            BroadcastProcessor(
                processors=[
                    TranscriptFinalTitleProcessor.as_threaded(
                        callback=options.on_final_title
                    ),
                    TranscriptFinalLongSummaryProcessor.as_threaded(
                        callback=options.on_final_long_summary
                    ),
                    TranscriptFinalShortSummaryProcessor.as_threaded(
                        callback=options.on_final_short_summary
                    ),
                ]
            ),
        ]
    )

    live_pipeline = Pipeline(*stages)
    live_pipeline.options = options
    live_pipeline.set_pref("audio:source_language", options.source_language)
    live_pipeline.set_pref("audio:target_language", options.target_language)
    return live_pipeline
|
||||||
|
|
||||||
|
|
||||||
async def rtc_offer_base(
|
async def rtc_offer_base(
|
||||||
params: RtcOffer,
|
params: RtcOffer,
|
||||||
request: Request,
|
request: Request,
|
||||||
@@ -211,37 +319,24 @@ async def rtc_offer_base(
|
|||||||
data=title,
|
data=title,
|
||||||
)
|
)
|
||||||
|
|
||||||
# create a context for the whole rtc transaction
|
|
||||||
# add a customised logger to the context
|
|
||||||
processors = []
|
|
||||||
if audio_filename is not None:
|
|
||||||
processors += [AudioFileWriterProcessor(path=audio_filename)]
|
|
||||||
processors += [
|
|
||||||
AudioChunkerProcessor(),
|
|
||||||
AudioMergeProcessor(),
|
|
||||||
AudioTranscriptAutoProcessor.as_threaded(),
|
|
||||||
TranscriptLinerProcessor(),
|
|
||||||
TranscriptTranslatorProcessor.as_threaded(callback=on_transcript),
|
|
||||||
TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
|
|
||||||
BroadcastProcessor(
|
|
||||||
processors=[
|
|
||||||
TranscriptFinalTitleProcessor.as_threaded(callback=on_final_title),
|
|
||||||
TranscriptFinalLongSummaryProcessor.as_threaded(
|
|
||||||
callback=on_final_long_summary
|
|
||||||
),
|
|
||||||
TranscriptFinalShortSummaryProcessor.as_threaded(
|
|
||||||
callback=on_final_short_summary
|
|
||||||
),
|
|
||||||
]
|
|
||||||
),
|
|
||||||
]
|
|
||||||
ctx.pipeline = Pipeline(*processors)
|
|
||||||
ctx.pipeline.set_pref("audio:source_language", source_language)
|
|
||||||
ctx.pipeline.set_pref("audio:target_language", target_language)
|
|
||||||
|
|
||||||
# handle RTC peer connection
|
# handle RTC peer connection
|
||||||
pc = RTCPeerConnection()
|
pc = RTCPeerConnection()
|
||||||
|
|
||||||
|
# create pipeline
|
||||||
|
options = PipelineOptions(
|
||||||
|
audio_filename=audio_filename,
|
||||||
|
source_language=source_language,
|
||||||
|
target_language=target_language,
|
||||||
|
on_transcript=on_transcript,
|
||||||
|
on_topic=on_topic,
|
||||||
|
on_final_short_summary=on_final_short_summary,
|
||||||
|
on_final_long_summary=on_final_long_summary,
|
||||||
|
on_final_title=on_final_title,
|
||||||
|
)
|
||||||
|
pipeline = await pipeline_live_create(options)
|
||||||
|
ctx.pipeline_runner = PipelineRunner(pipeline, update_status)
|
||||||
|
ctx.pipeline_runner.start()
|
||||||
|
|
||||||
async def flush_pipeline_and_quit(close=True):
|
async def flush_pipeline_and_quit(close=True):
|
||||||
# may be called twice
|
# may be called twice
|
||||||
# 1. either the client asks to stop the meeting
|
# 1. either the client asks to stop the meeting
|
||||||
@@ -249,12 +344,10 @@ async def rtc_offer_base(
|
|||||||
# - when we receive the close event, we do nothing.
|
# - when we receive the close event, we do nothing.
|
||||||
# 2. or the client close the connection
|
# 2. or the client close the connection
|
||||||
# and there is nothing to do because it is already closed
|
# and there is nothing to do because it is already closed
|
||||||
await update_status("processing")
|
await ctx.pipeline_runner.flush()
|
||||||
await ctx.pipeline.flush()
|
|
||||||
if close:
|
if close:
|
||||||
ctx.logger.debug("Closing peer connection")
|
ctx.logger.debug("Closing peer connection")
|
||||||
await pc.close()
|
await pc.close()
|
||||||
await update_status("ended")
|
|
||||||
if pc in sessions:
|
if pc in sessions:
|
||||||
sessions.remove(pc)
|
sessions.remove(pc)
|
||||||
m_rtc_sessions.dec()
|
m_rtc_sessions.dec()
|
||||||
@@ -287,7 +380,6 @@ async def rtc_offer_base(
|
|||||||
def on_track(track):
|
def on_track(track):
|
||||||
ctx.logger.info(f"Track {track.kind} received")
|
ctx.logger.info(f"Track {track.kind} received")
|
||||||
pc.addTrack(AudioStreamTrack(ctx, track))
|
pc.addTrack(AudioStreamTrack(ctx, track))
|
||||||
asyncio.get_event_loop().create_task(update_status("recording"))
|
|
||||||
|
|
||||||
await pc.setRemoteDescription(offer)
|
await pc.setRemoteDescription(offer)
|
||||||
|
|
||||||
|
|||||||
@@ -483,16 +483,16 @@ async def transcript_get_topics(
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@router.get("/transcripts/{transcript_id}/events")
|
|
||||||
async def transcript_get_websocket_events(transcript_id: str):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================
|
# ==============================================================
|
||||||
# Websocket
|
# Websocket
|
||||||
# ==============================================================
|
# ==============================================================
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/transcripts/{transcript_id}/events")
async def transcript_get_websocket_events(transcript_id: str):
    # Placeholder GET handler: does no work and returns None.
    # NOTE(review): presumably registered only so this path appears in the
    # generated HTTP API schema next to the websocket route - confirm.
    return None
|
||||||
|
|
||||||
|
|
||||||
@router.websocket("/transcripts/{transcript_id}/events")
|
@router.websocket("/transcripts/{transcript_id}/events")
|
||||||
async def transcript_events_websocket(
|
async def transcript_events_websocket(
|
||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
@@ -512,6 +512,13 @@ async def transcript_events_websocket(
|
|||||||
try:
|
try:
|
||||||
# on first connection, send all events only to the current user
|
# on first connection, send all events only to the current user
|
||||||
for event in transcript.events:
|
for event in transcript.events:
|
||||||
|
# for now, do not send TRANSCRIPT or STATUS events - these are live events
|
||||||
|
# not necessary to be sent to the client; but keep the rest
|
||||||
|
name = event.event
|
||||||
|
if name == PipelineEvent.TRANSCRIPT:
|
||||||
|
continue
|
||||||
|
if name == PipelineEvent.STATUS:
|
||||||
|
continue
|
||||||
await websocket.send_json(event.model_dump(mode="json"))
|
await websocket.send_json(event.model_dump(mode="json"))
|
||||||
|
|
||||||
# XXX if transcript is final (locked=True and status=ended)
|
# XXX if transcript is final (locked=True and status=ended)
|
||||||
|
|||||||
@@ -33,6 +33,8 @@ class RedisPubSubManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def connect(self) -> None:
|
async def connect(self) -> None:
|
||||||
|
if self.redis_connection is not None:
|
||||||
|
return
|
||||||
self.redis_connection = await self.get_redis_connection()
|
self.redis_connection = await self.get_redis_connection()
|
||||||
self.pubsub = self.redis_connection.pubsub()
|
self.pubsub = self.redis_connection.pubsub()
|
||||||
|
|
||||||
@@ -43,6 +45,8 @@ class RedisPubSubManager:
|
|||||||
self.redis_connection = None
|
self.redis_connection = None
|
||||||
|
|
||||||
async def send_json(self, room_id: str, message: str) -> None:
|
async def send_json(self, room_id: str, message: str) -> None:
|
||||||
|
if not self.redis_connection:
|
||||||
|
await self.connect()
|
||||||
message = json.dumps(message)
|
message = json.dumps(message)
|
||||||
await self.redis_connection.publish(room_id, message)
|
await self.redis_connection.publish(room_id, message)
|
||||||
|
|
||||||
@@ -94,18 +98,6 @@ class WebsocketManager:
|
|||||||
await socket.send_json(data)
|
await socket.send_json(data)
|
||||||
|
|
||||||
|
|
||||||
def get_pubsub_client() -> RedisPubSubManager:
|
|
||||||
"""
|
|
||||||
Returns the RedisPubSubManager instance for managing Redis pubsub.
|
|
||||||
"""
|
|
||||||
from reflector.settings import settings
|
|
||||||
|
|
||||||
return RedisPubSubManager(
|
|
||||||
host=settings.REDIS_HOST,
|
|
||||||
port=settings.REDIS_PORT,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_ws_manager() -> WebsocketManager:
|
def get_ws_manager() -> WebsocketManager:
|
||||||
"""
|
"""
|
||||||
Returns the WebsocketManager instance for managing websockets.
|
Returns the WebsocketManager instance for managing websockets.
|
||||||
@@ -122,6 +114,11 @@ def get_ws_manager() -> WebsocketManager:
|
|||||||
RedisConnectionError: If there is an error connecting to the Redis server.
|
RedisConnectionError: If there is an error connecting to the Redis server.
|
||||||
"""
|
"""
|
||||||
global ws_manager
|
global ws_manager
|
||||||
pubsub_client = get_pubsub_client()
|
from reflector.settings import settings
|
||||||
|
|
||||||
|
pubsub_client = RedisPubSubManager(
|
||||||
|
host=settings.REDIS_HOST,
|
||||||
|
port=settings.REDIS_PORT,
|
||||||
|
)
|
||||||
ws_manager = WebsocketManager(pubsub_client=pubsub_client)
|
ws_manager = WebsocketManager(pubsub_client=pubsub_client)
|
||||||
return ws_manager
|
return ws_manager
|
||||||
|
|||||||
Reference in New Issue
Block a user