server: refactor to separate websocket management + start pipeline runner

2023-10-25 19:50:27 +02:00
committed by Mathieu Virbel
parent a45b30ee70
commit 433c0500cc
3 changed files with 148 additions and 52 deletions


@@ -2,6 +2,7 @@ import asyncio
 from enum import StrEnum
 from json import dumps, loads
 from pathlib import Path
+from typing import Callable

 import av
 from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
@@ -38,7 +39,7 @@ m_rtc_sessions = Gauge("rtc_sessions", "Number of active RTC sessions")
 class TranscriptionContext(object):
     def __init__(self, logger):
         self.logger = logger
-        self.pipeline = None
+        self.pipeline_runner = None
         self.data_channel = None
         self.status = "idle"
         self.topics = []
@@ -60,7 +61,7 @@ class AudioStreamTrack(MediaStreamTrack):
         ctx = self.ctx
         frame = await self.track.recv()
         try:
-            await ctx.pipeline.push(frame)
+            await ctx.pipeline_runner.push(frame)
         except Exception as e:
             ctx.logger.error("Pipeline error", error=e)
         return frame
@@ -84,6 +85,113 @@ class PipelineEvent(StrEnum):
     FINAL_TITLE = "FINAL_TITLE"


+class PipelineOptions(BaseModel):
+    audio_filename: Path | None = None
+    source_language: str = "en"
+    target_language: str = "en"
+    on_transcript: Callable | None = None
+    on_topic: Callable | None = None
+    on_final_title: Callable | None = None
+    on_final_short_summary: Callable | None = None
+    on_final_long_summary: Callable | None = None
+
+
+class PipelineRunner(object):
+    """
+    Pipeline runner designed to be executed in an asyncio task
+    """
+
+    def __init__(self, pipeline: Pipeline, status_callback: Callable | None = None):
+        self.pipeline = pipeline
+        self.q_cmd = asyncio.Queue()
+        self.ev_done = asyncio.Event()
+        self.status = "idle"
+        self.status_callback = status_callback
+
+    async def update_status(self, status):
+        print("update_status", status)
+        self.status = status
+        if self.status_callback:
+            try:
+                await self.status_callback(status)
+            except Exception as e:
+                logger.error("PipelineRunner status_callback error", error=e)
+
+    async def add_cmd(self, cmd: str, data):
+        await self.q_cmd.put([cmd, data])
+
+    async def push(self, data):
+        await self.add_cmd("PUSH", data)
+
+    async def flush(self):
+        await self.add_cmd("FLUSH", None)
+
+    async def run(self):
+        try:
+            await self.update_status("running")
+            while not self.ev_done.is_set():
+                cmd, data = await self.q_cmd.get()
+                func = getattr(self, f"cmd_{cmd.lower()}", None)
+                if func:
+                    await func(data)
+                else:
+                    raise Exception(f"Unknown command {cmd}")
+        except Exception as e:
+            await self.update_status("error")
+            logger.error("PipelineRunner error", error=e)
+
+    async def cmd_push(self, data):
+        if self.status == "idle":
+            await self.update_status("recording")
+        await self.pipeline.push(data)
+
+    async def cmd_flush(self, data):
+        await self.update_status("processing")
+        await self.pipeline.flush()
+        await self.update_status("ended")
+        self.ev_done.set()
+
+    def start(self):
+        print("start task")
+        asyncio.get_event_loop().create_task(self.run())
+
+
+async def pipeline_live_create(options: PipelineOptions):
+    # create a context for the whole rtc transaction
+    # add a customised logger to the context
+    processors = []
+    if options.audio_filename is not None:
+        processors += [AudioFileWriterProcessor(path=options.audio_filename)]
+    processors += [
+        AudioChunkerProcessor(),
+        AudioMergeProcessor(),
+        AudioTranscriptAutoProcessor.as_threaded(),
+        TranscriptLinerProcessor(),
+        TranscriptTranslatorProcessor.as_threaded(callback=options.on_transcript),
+        TranscriptTopicDetectorProcessor.as_threaded(callback=options.on_topic),
+        BroadcastProcessor(
+            processors=[
+                TranscriptFinalTitleProcessor.as_threaded(
+                    callback=options.on_final_title
+                ),
+                TranscriptFinalLongSummaryProcessor.as_threaded(
+                    callback=options.on_final_long_summary
+                ),
+                TranscriptFinalShortSummaryProcessor.as_threaded(
+                    callback=options.on_final_short_summary
+                ),
+            ]
+        ),
+    ]
+
+    pipeline = Pipeline(*processors)
+    pipeline.options = options
+    pipeline.set_pref("audio:source_language", options.source_language)
+    pipeline.set_pref("audio:target_language", options.target_language)
+    return pipeline
+
+
 async def rtc_offer_base(
     params: RtcOffer,
     request: Request,
@@ -211,37 +319,24 @@ async def rtc_offer_base(
             data=title,
         )

-    # create a context for the whole rtc transaction
-    # add a customised logger to the context
-    processors = []
-    if audio_filename is not None:
-        processors += [AudioFileWriterProcessor(path=audio_filename)]
-    processors += [
-        AudioChunkerProcessor(),
-        AudioMergeProcessor(),
-        AudioTranscriptAutoProcessor.as_threaded(),
-        TranscriptLinerProcessor(),
-        TranscriptTranslatorProcessor.as_threaded(callback=on_transcript),
-        TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
-        BroadcastProcessor(
-            processors=[
-                TranscriptFinalTitleProcessor.as_threaded(callback=on_final_title),
-                TranscriptFinalLongSummaryProcessor.as_threaded(
-                    callback=on_final_long_summary
-                ),
-                TranscriptFinalShortSummaryProcessor.as_threaded(
-                    callback=on_final_short_summary
-                ),
-            ]
-        ),
-    ]
-    ctx.pipeline = Pipeline(*processors)
-    ctx.pipeline.set_pref("audio:source_language", source_language)
-    ctx.pipeline.set_pref("audio:target_language", target_language)
-
     # handle RTC peer connection
     pc = RTCPeerConnection()

+    # create pipeline
+    options = PipelineOptions(
+        audio_filename=audio_filename,
+        source_language=source_language,
+        target_language=target_language,
+        on_transcript=on_transcript,
+        on_topic=on_topic,
+        on_final_short_summary=on_final_short_summary,
+        on_final_long_summary=on_final_long_summary,
+        on_final_title=on_final_title,
+    )
+    pipeline = await pipeline_live_create(options)
+    ctx.pipeline_runner = PipelineRunner(pipeline, update_status)
+    ctx.pipeline_runner.start()
+
     async def flush_pipeline_and_quit(close=True):
         # may be called twice
         # 1. either the client asks to stop the meeting
@@ -249,12 +344,10 @@ async def rtc_offer_base(
         #    - when we receive the close event, we do nothing.
         # 2. or the client closes the connection
         #    and there is nothing to do because it is already closed
-        await update_status("processing")
-        await ctx.pipeline.flush()
+        await ctx.pipeline_runner.flush()
         if close:
             ctx.logger.debug("Closing peer connection")
             await pc.close()
-        await update_status("ended")
         if pc in sessions:
             sessions.remove(pc)
             m_rtc_sessions.dec()
@@ -287,7 +380,6 @@ async def rtc_offer_base(
     def on_track(track):
         ctx.logger.info(f"Track {track.kind} received")
         pc.addTrack(AudioStreamTrack(ctx, track))
-        asyncio.get_event_loop().create_task(update_status("recording"))

     await pc.setRemoteDescription(offer)
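
Taken together, the new pieces give rtc_offer_base a simple contract: build PipelineOptions, create the pipeline with pipeline_live_create, hand it to a PipelineRunner, and drive it with push()/flush(). For reference, a minimal, hypothetical driver outside the RTC handler could look like the sketch below; the file paths, callbacks, and the PyAV decode loop are illustrative and not part of this commit.

import asyncio
from pathlib import Path

import av

# PipelineOptions, pipeline_live_create and PipelineRunner would be imported
# from the server module patched above (module path not shown in this diff).


async def main():
    async def on_transcript(transcript):
        print("transcript:", transcript)

    async def on_status(status):
        print("status:", status)

    options = PipelineOptions(
        audio_filename=Path("/tmp/copy-of-input.wav"),  # illustrative path
        source_language="en",
        target_language="fr",
        on_transcript=on_transcript,
    )
    pipeline = await pipeline_live_create(options)
    runner = PipelineRunner(pipeline, status_callback=on_status)
    runner.start()  # schedules runner.run() as a task on the running loop

    # feed audio frames the same way AudioStreamTrack does, but decoded
    # from a local file instead of an aiortc track
    container = av.open("input.wav")  # illustrative input file
    for frame in container.decode(audio=0):
        await runner.push(frame)  # queued as a PUSH command

    await runner.flush()         # queued FLUSH: processing -> ended
    await runner.ev_done.wait()  # set by cmd_flush once the pipeline is drained


asyncio.run(main())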


@@ -483,16 +483,16 @@ async def transcript_get_topics(
     ]


-@router.get("/transcripts/{transcript_id}/events")
-async def transcript_get_websocket_events(transcript_id: str):
-    pass
-
-
 # ==============================================================
 # Websocket
 # ==============================================================


+@router.get("/transcripts/{transcript_id}/events")
+async def transcript_get_websocket_events(transcript_id: str):
+    pass
+
+
 @router.websocket("/transcripts/{transcript_id}/events")
 async def transcript_events_websocket(
     transcript_id: str,
@@ -512,6 +512,13 @@ async def transcript_events_websocket(
     try:
         # on first connection, send all events only to the current user
        for event in transcript.events:
+            # for now, do not send TRANSCRIPT or STATUS events - these are live
+            # events that do not need to be replayed to the client; keep the rest
+            name = event.event
+            if name == PipelineEvent.TRANSCRIPT:
+                continue
+            if name == PipelineEvent.STATUS:
+                continue
             await websocket.send_json(event.model_dump(mode="json"))

     # XXX if transcript is final (locked=True and status=ended)
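
The replay filter above means a client connecting to the events websocket receives the durable events (topics, titles, summaries) on connect, while TRANSCRIPT and STATUS only arrive live as they happen. A minimal, hypothetical client using the third-party websockets package; the host, port, and event field names are assumptions based on this diff, not confirmed API.

import asyncio
import json

import websockets  # third-party client library, not part of this commit


async def dump_events(transcript_id: str):
    # illustrative host/port; adjust to the actual server address
    uri = f"ws://localhost:8000/transcripts/{transcript_id}/events"
    async with websockets.connect(uri) as ws:
        async for message in ws:
            event = json.loads(message)
            # TRANSCRIPT and STATUS are filtered out server-side on replay,
            # so only topic/title/summary style events arrive on connect
            print(event.get("event"), event.get("data"))


asyncio.run(dump_events("some-transcript-id"))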


@@ -33,6 +33,8 @@ class RedisPubSubManager:
         )

     async def connect(self) -> None:
+        if self.redis_connection is not None:
+            return
         self.redis_connection = await self.get_redis_connection()
         self.pubsub = self.redis_connection.pubsub()
@@ -43,6 +45,8 @@ class RedisPubSubManager:
         self.redis_connection = None

     async def send_json(self, room_id: str, message: str) -> None:
+        if not self.redis_connection:
+            await self.connect()
         message = json.dumps(message)
         await self.redis_connection.publish(room_id, message)
@@ -94,18 +98,6 @@ class WebsocketManager:
                 await socket.send_json(data)


-def get_pubsub_client() -> RedisPubSubManager:
-    """
-    Returns the RedisPubSubManager instance for managing Redis pubsub.
-    """
-    from reflector.settings import settings
-
-    return RedisPubSubManager(
-        host=settings.REDIS_HOST,
-        port=settings.REDIS_PORT,
-    )
-
-
 def get_ws_manager() -> WebsocketManager:
     """
     Returns the WebsocketManager instance for managing websockets.
@@ -122,6 +114,11 @@ def get_ws_manager() -> WebsocketManager:
         RedisConnectionError: If there is an error connecting to the Redis server.
     """
     global ws_manager
-    pubsub_client = get_pubsub_client()
+    from reflector.settings import settings
+
+    pubsub_client = RedisPubSubManager(
+        host=settings.REDIS_HOST,
+        port=settings.REDIS_PORT,
+    )
     ws_manager = WebsocketManager(pubsub_client=pubsub_client)
     return ws_manager