diff --git a/server/reflector/views/rtc_offer.py b/server/reflector/views/rtc_offer.py index 5662d989..48d804cc 100644 --- a/server/reflector/views/rtc_offer.py +++ b/server/reflector/views/rtc_offer.py @@ -2,6 +2,7 @@ import asyncio from enum import StrEnum from json import dumps, loads from pathlib import Path +from typing import Callable import av from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription @@ -38,7 +39,7 @@ m_rtc_sessions = Gauge("rtc_sessions", "Number of active RTC sessions") class TranscriptionContext(object): def __init__(self, logger): self.logger = logger - self.pipeline = None + self.pipeline_runner = None self.data_channel = None self.status = "idle" self.topics = [] @@ -60,7 +61,7 @@ class AudioStreamTrack(MediaStreamTrack): ctx = self.ctx frame = await self.track.recv() try: - await ctx.pipeline.push(frame) + await ctx.pipeline_runner.push(frame) except Exception as e: ctx.logger.error("Pipeline error", error=e) return frame @@ -84,6 +85,113 @@ class PipelineEvent(StrEnum): FINAL_TITLE = "FINAL_TITLE" +class PipelineOptions(BaseModel): + audio_filename: Path | None = None + source_language: str = "en" + target_language: str = "en" + + on_transcript: Callable | None = None + on_topic: Callable | None = None + on_final_title: Callable | None = None + on_final_short_summary: Callable | None = None + on_final_long_summary: Callable | None = None + + +class PipelineRunner(object): + """ + Pipeline runner designed to be executed in a asyncio task + """ + + def __init__(self, pipeline: Pipeline, status_callback: Callable | None = None): + self.pipeline = pipeline + self.q_cmd = asyncio.Queue() + self.ev_done = asyncio.Event() + self.status = "idle" + self.status_callback = status_callback + + async def update_status(self, status): + print("update_status", status) + self.status = status + if self.status_callback: + try: + await self.status_callback(status) + except Exception as e: + logger.error("PipelineRunner status_callback error", error=e) + + async def add_cmd(self, cmd: str, data): + await self.q_cmd.put([cmd, data]) + + async def push(self, data): + await self.add_cmd("PUSH", data) + + async def flush(self): + await self.add_cmd("FLUSH", None) + + async def run(self): + try: + await self.update_status("running") + while not self.ev_done.is_set(): + cmd, data = await self.q_cmd.get() + func = getattr(self, f"cmd_{cmd.lower()}") + if func: + await func(data) + else: + raise Exception(f"Unknown command {cmd}") + except Exception as e: + await self.update_status("error") + logger.error("PipelineRunner error", error=e) + + async def cmd_push(self, data): + if self.status == "idle": + await self.update_status("recording") + await self.pipeline.push(data) + + async def cmd_flush(self, data): + await self.update_status("processing") + await self.pipeline.flush() + await self.update_status("ended") + self.ev_done.set() + + def start(self): + print("start task") + asyncio.get_event_loop().create_task(self.run()) + + +async def pipeline_live_create(options: PipelineOptions): + # create a context for the whole rtc transaction + # add a customised logger to the context + processors = [] + if options.audio_filename is not None: + processors += [AudioFileWriterProcessor(path=options.audio_filename)] + processors += [ + AudioChunkerProcessor(), + AudioMergeProcessor(), + AudioTranscriptAutoProcessor.as_threaded(), + TranscriptLinerProcessor(), + TranscriptTranslatorProcessor.as_threaded(callback=options.on_transcript), + TranscriptTopicDetectorProcessor.as_threaded(callback=options.on_topic), + BroadcastProcessor( + processors=[ + TranscriptFinalTitleProcessor.as_threaded( + callback=options.on_final_title + ), + TranscriptFinalLongSummaryProcessor.as_threaded( + callback=options.on_final_long_summary + ), + TranscriptFinalShortSummaryProcessor.as_threaded( + callback=options.on_final_short_summary + ), + ] + ), + ] + pipeline = Pipeline(*processors) + pipeline.options = options + pipeline.set_pref("audio:source_language", options.source_language) + pipeline.set_pref("audio:target_language", options.target_language) + + return pipeline + + async def rtc_offer_base( params: RtcOffer, request: Request, @@ -211,37 +319,24 @@ async def rtc_offer_base( data=title, ) - # create a context for the whole rtc transaction - # add a customised logger to the context - processors = [] - if audio_filename is not None: - processors += [AudioFileWriterProcessor(path=audio_filename)] - processors += [ - AudioChunkerProcessor(), - AudioMergeProcessor(), - AudioTranscriptAutoProcessor.as_threaded(), - TranscriptLinerProcessor(), - TranscriptTranslatorProcessor.as_threaded(callback=on_transcript), - TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic), - BroadcastProcessor( - processors=[ - TranscriptFinalTitleProcessor.as_threaded(callback=on_final_title), - TranscriptFinalLongSummaryProcessor.as_threaded( - callback=on_final_long_summary - ), - TranscriptFinalShortSummaryProcessor.as_threaded( - callback=on_final_short_summary - ), - ] - ), - ] - ctx.pipeline = Pipeline(*processors) - ctx.pipeline.set_pref("audio:source_language", source_language) - ctx.pipeline.set_pref("audio:target_language", target_language) - # handle RTC peer connection pc = RTCPeerConnection() + # create pipeline + options = PipelineOptions( + audio_filename=audio_filename, + source_language=source_language, + target_language=target_language, + on_transcript=on_transcript, + on_topic=on_topic, + on_final_short_summary=on_final_short_summary, + on_final_long_summary=on_final_long_summary, + on_final_title=on_final_title, + ) + pipeline = await pipeline_live_create(options) + ctx.pipeline_runner = PipelineRunner(pipeline, update_status) + ctx.pipeline_runner.start() + async def flush_pipeline_and_quit(close=True): # may be called twice # 1. either the client ask to sotp the meeting @@ -249,12 +344,10 @@ async def rtc_offer_base( # - when we receive the close event, we do nothing. # 2. or the client close the connection # and there is nothing to do because it is already closed - await update_status("processing") - await ctx.pipeline.flush() + await ctx.pipeline_runner.flush() if close: ctx.logger.debug("Closing peer connection") await pc.close() - await update_status("ended") if pc in sessions: sessions.remove(pc) m_rtc_sessions.dec() @@ -287,7 +380,6 @@ async def rtc_offer_base( def on_track(track): ctx.logger.info(f"Track {track.kind} received") pc.addTrack(AudioStreamTrack(ctx, track)) - asyncio.get_event_loop().create_task(update_status("recording")) await pc.setRemoteDescription(offer) diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 9480461f..9f02eb6d 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -483,16 +483,16 @@ async def transcript_get_topics( ] -@router.get("/transcripts/{transcript_id}/events") -async def transcript_get_websocket_events(transcript_id: str): - pass - - # ============================================================== # Websocket # ============================================================== +@router.get("/transcripts/{transcript_id}/events") +async def transcript_get_websocket_events(transcript_id: str): + pass + + @router.websocket("/transcripts/{transcript_id}/events") async def transcript_events_websocket( transcript_id: str, @@ -512,6 +512,13 @@ async def transcript_events_websocket( try: # on first connection, send all events only to the current user for event in transcript.events: + # for now, do not send TRANSCRIPT or STATUS options - theses are live event + # not necessary to be sent to the client; but keep the rest + name = event.event + if name == PipelineEvent.TRANSCRIPT: + continue + if name == PipelineEvent.STATUS: + continue await websocket.send_json(event.model_dump(mode="json")) # XXX if transcript is final (locked=True and status=ended) diff --git a/server/reflector/ws_manager.py b/server/reflector/ws_manager.py index 43475c1d..1dfe9e3d 100644 --- a/server/reflector/ws_manager.py +++ b/server/reflector/ws_manager.py @@ -33,6 +33,8 @@ class RedisPubSubManager: ) async def connect(self) -> None: + if self.redis_connection is not None: + return self.redis_connection = await self.get_redis_connection() self.pubsub = self.redis_connection.pubsub() @@ -43,6 +45,8 @@ class RedisPubSubManager: self.redis_connection = None async def send_json(self, room_id: str, message: str) -> None: + if not self.redis_connection: + await self.connect() message = json.dumps(message) await self.redis_connection.publish(room_id, message) @@ -94,18 +98,6 @@ class WebsocketManager: await socket.send_json(data) -def get_pubsub_client() -> RedisPubSubManager: - """ - Returns the RedisPubSubManager instance for managing Redis pubsub. - """ - from reflector.settings import settings - - return RedisPubSubManager( - host=settings.REDIS_HOST, - port=settings.REDIS_PORT, - ) - - def get_ws_manager() -> WebsocketManager: """ Returns the WebsocketManager instance for managing websockets. @@ -122,6 +114,11 @@ def get_ws_manager() -> WebsocketManager: RedisConnectionError: If there is an error connecting to the Redis server. """ global ws_manager - pubsub_client = get_pubsub_client() + from reflector.settings import settings + + pubsub_client = RedisPubSubManager( + host=settings.REDIS_HOST, + port=settings.REDIS_PORT, + ) ws_manager = WebsocketManager(pubsub_client=pubsub_client) return ws_manager