diff --git a/server/docker-compose.yml b/server/docker-compose.yml
index 374130fa..4e5a21e8 100644
--- a/server/docker-compose.yml
+++ b/server/docker-compose.yml
@@ -1,15 +1,19 @@
 version: "3.9"
 services:
-  server:
-    build:
-      context: .
+  # server:
+  #   build:
+  #     context: .
+  #   ports:
+  #     - 1250:1250
+  #   environment:
+  #     LLM_URL: "${LLM_URL}"
+  #     MIN_TRANSCRIPT_LENGTH: "${MIN_TRANSCRIPT_LENGTH}"
+  #   volumes:
+  #     - model-cache:/root/.cache
+  redis:
+    image: redis:7.2
     ports:
-      - 1250:1250
-    environment:
-      LLM_URL: "${LLM_URL}"
-      MIN_TRANSCRIPT_LENGTH: "${MIN_TRANSCRIPT_LENGTH}"
-    volumes:
-      - model-cache:/root/.cache
+      - 6379:6379
 
 volumes:
   model-cache:
diff --git a/server/poetry.lock b/server/poetry.lock
index 0df46097..35d98382 100644
--- a/server/poetry.lock
+++ b/server/poetry.lock
@@ -2919,6 +2919,24 @@ files = [
 [package.extras]
 full = ["numpy"]
 
+[[package]]
+name = "redis"
+version = "5.0.1"
+description = "Python client for Redis database and key-value store"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "redis-5.0.1-py3-none-any.whl", hash = "sha256:ed4802971884ae19d640775ba3b03aa2e7bd5e8fb8dfaed2decce4d0fc48391f"},
+    {file = "redis-5.0.1.tar.gz", hash = "sha256:0dab495cd5753069d3bc650a0dde8a8f9edde16fc5691b689a566eda58100d0f"},
+]
+
+[package.dependencies]
+async-timeout = {version = ">=4.0.2", markers = "python_full_version <= \"3.11.2\""}
+
+[package.extras]
+hiredis = ["hiredis (>=1.0.0)"]
+ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"]
+
 [[package]]
 name = "regex"
 version = "2023.10.3"
@@ -4046,4 +4064,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.11"
-content-hash = "fda9f13784a64add559abb2266d60eeef8f28d2b5f369633630f4fed14daa99c"
+content-hash = "6d2e8a8e0d5d928481f9a33210d44863a1921e18147fa57dc6889d877697aa63"
diff --git a/server/pyproject.toml b/server/pyproject.toml
index 7b1b7936..ed231a4f 100644
--- a/server/pyproject.toml
+++ b/server/pyproject.toml
@@ -34,6 +34,7 @@ sentencepiece = "^0.1.99"
 protobuf = "^4.24.3"
 profanityfilter = "^2.0.6"
 celery = "^5.3.4"
+redis = "^5.0.1"
 
 
 [tool.poetry.group.dev.dependencies]
diff --git a/server/reflector/settings.py b/server/reflector/settings.py
index 1503948a..d7cc2c33 100644
--- a/server/reflector/settings.py
+++ b/server/reflector/settings.py
@@ -117,5 +117,9 @@ class Settings(BaseSettings):
     CELERY_BROKER_URL: str = "redis://localhost:6379/1"
     CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
 
+    # Redis
+    REDIS_HOST: str = "localhost"
+    REDIS_PORT: int = 6379
+
 
 settings = Settings()
diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py
index 1d9fd4bd..9480461f 100644
--- a/server/reflector/views/transcripts.py
+++ b/server/reflector/views/transcripts.py
@@ -21,12 +21,14 @@ from reflector.processors.types import Transcript as ProcessorTranscript
 from reflector.processors.types import Word as ProcessorWord
 from reflector.settings import settings
 from reflector.utils.audio_waveform import get_audio_waveform
+from reflector.ws_manager import get_ws_manager
 from starlette.concurrency import run_in_threadpool
 
 from ._range_requests_response import range_requests_response
 from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
 
 router = APIRouter()
+ws_manager = get_ws_manager()
 
 # ==============================================================
 # Models to move to a database, but required for the API to work
@@ -487,40 +489,10 @@ async def transcript_get_websocket_events(transcript_id: str):
 
 
 # ==============================================================
-# Websocket Manager
+# Websocket
 # ==============================================================
 
 
-class WebsocketManager:
-    def __init__(self):
-        self.active_connections = {}
-
-    async def connect(self, transcript_id: str, websocket: WebSocket):
-        await websocket.accept()
-        if transcript_id not in self.active_connections:
-            self.active_connections[transcript_id] = []
-        self.active_connections[transcript_id].append(websocket)
-
-    def disconnect(self, transcript_id: str, websocket: WebSocket):
-        if transcript_id not in self.active_connections:
-            return
-        self.active_connections[transcript_id].remove(websocket)
-        if not self.active_connections[transcript_id]:
-            del self.active_connections[transcript_id]
-
-    async def send_json(self, transcript_id: str, message):
-        if transcript_id not in self.active_connections:
-            return
-        for connection in self.active_connections[transcript_id][:]:
-            try:
-                await connection.send_json(message)
-            except Exception:
-                self.active_connections[transcript_id].remove(connection)
-
-
-ws_manager = WebsocketManager()
-
-
 @router.websocket("/transcripts/{transcript_id}/events")
 async def transcript_events_websocket(
     transcript_id: str,
@@ -532,21 +504,25 @@ async def transcript_events_websocket(
     if not transcript:
         raise HTTPException(status_code=404, detail="Transcript not found")
 
-    await ws_manager.connect(transcript_id, websocket)
+    # connect to websocket manager
+    # use ts:transcript_id as room id
+    room_id = f"ts:{transcript_id}"
+    await ws_manager.add_user_to_room(room_id, websocket)
 
-    # on first connection, send all events
-    for event in transcript.events:
-        await websocket.send_json(event.model_dump(mode="json"))
-
-    # XXX if transcript is final (locked=True and status=ended)
-    # XXX send a final event to the client and close the connection
-
-    # endless loop to wait for new events
     try:
+        # on first connection, send all events only to the current user
+        for event in transcript.events:
+            await websocket.send_json(event.model_dump(mode="json"))
+
+        # XXX if transcript is final (locked=True and status=ended)
+        # XXX send a final event to the client and close the connection
+
+        # endless loop to wait for new events
+        # we do not have command system now,
         while True:
             await websocket.receive()
     except (RuntimeError, WebSocketDisconnect):
-        ws_manager.disconnect(transcript_id, websocket)
+        await ws_manager.remove_user_from_room(room_id, websocket)
 
 
 # ==============================================================
@@ -658,7 +634,8 @@ async def handle_rtc_event_once(event: PipelineEvent, args, data):
         return
 
     # transmit to websocket clients
-    await ws_manager.send_json(transcript_id, resp.model_dump(mode="json"))
+    room_id = f"ts:{transcript_id}"
+    await ws_manager.send_json(room_id, resp.model_dump(mode="json"))
 
 
 @router.post("/transcripts/{transcript_id}/record/webrtc")
diff --git a/server/reflector/ws_manager.py b/server/reflector/ws_manager.py
new file mode 100644
index 00000000..43475c1d
--- /dev/null
+++ b/server/reflector/ws_manager.py
@@ -0,0 +1,166 @@
+"""
+Websocket manager
+=================
+
+This module contains the WebsocketManager class, which is responsible for
+managing websockets and handling websocket connections.
+
+It uses the RedisPubSubManager class to subscribe to Redis channels and
+broadcast messages to all connected websockets.
+"""
+
+import asyncio
+import json
+import logging
+
+import redis.asyncio as redis
+from fastapi import WebSocket
+
+logger = logging.getLogger(__name__)
+
+ws_manager = None
+
+
+class RedisPubSubManager:
+    """Owns one Redis client and the pubsub object created from it."""
+
+    def __init__(self, host="localhost", port=6379):
+        self.redis_host = host
+        self.redis_port = port
+        self.redis_connection = None
+        self.pubsub = None
+
+    async def get_redis_connection(self) -> redis.Redis:
+        return redis.Redis(
+            host=self.redis_host,
+            port=self.redis_port,
+            auto_close_connection_pool=False,
+        )
+
+    async def connect(self) -> None:
+        # Idempotent: replacing a live pubsub object would orphan the
+        # subscriptions (and the reader task) attached to the previous one.
+        if self.redis_connection is not None:
+            return
+        self.redis_connection = await self.get_redis_connection()
+        self.pubsub = self.redis_connection.pubsub()
+
+    async def disconnect(self) -> None:
+        if self.redis_connection is None:
+            return
+        await self.redis_connection.close()
+        self.redis_connection = None
+        # Drop the stale pubsub so the next connect() builds a fresh pair.
+        self.pubsub = None
+
+    async def send_json(self, room_id: str, message: dict) -> None:
+        # Events can be published before any websocket has joined a room,
+        # so make sure the client exists instead of calling publish on None.
+        await self.connect()
+        await self.redis_connection.publish(room_id, json.dumps(message))
+
+    async def subscribe(self, room_id: str) -> "redis.client.PubSub":
+        await self.connect()
+        await self.pubsub.subscribe(room_id)
+        return self.pubsub
+
+    async def unsubscribe(self, room_id: str) -> None:
+        if self.pubsub is None:
+            return
+        await self.pubsub.unsubscribe(room_id)
+
+
+class WebsocketManager:
+    """Tracks local websockets per room and relays Redis messages to them."""
+
+    def __init__(self, pubsub_client: RedisPubSubManager = None):
+        self.rooms: dict = {}
+        self.pubsub_client = pubsub_client
+        # Strong reference to the single reader task; asyncio only keeps weak
+        # references to tasks, so an unreferenced task may be collected.
+        self._reader_task = None
+
+    async def add_user_to_room(self, room_id: str, websocket: WebSocket) -> None:
+        await websocket.accept()
+
+        if room_id not in self.rooms:
+            # First local member: subscribe before registering, so a failed
+            # subscription leaves no half-initialised room behind.
+            pubsub_subscriber = await self.pubsub_client.subscribe(room_id)
+            # One reader serves every room on the shared pubsub object.
+            if self._reader_task is None or self._reader_task.done():
+                self._reader_task = asyncio.create_task(
+                    self._pubsub_data_reader(pubsub_subscriber)
+                )
+
+        self.rooms.setdefault(room_id, []).append(websocket)
+
+    async def send_json(self, room_id: str, message: dict) -> None:
+        await self.pubsub_client.send_json(room_id, message)
+
+    async def remove_user_from_room(self, room_id: str, websocket: WebSocket) -> None:
+        sockets = self.rooms.get(room_id)
+        if not sockets or websocket not in sockets:
+            # Already removed (e.g. dropped by the reader after a failed send).
+            return
+        sockets.remove(websocket)
+
+        if not sockets:
+            del self.rooms[room_id]
+            await self.pubsub_client.unsubscribe(room_id)
+
+    async def _pubsub_data_reader(self, pubsub_subscriber):
+        while True:
+            # A non-zero timeout makes the call wait for data instead of
+            # returning immediately, which would turn this loop into a spin.
+            message = await pubsub_subscriber.get_message(
+                ignore_subscribe_messages=True,
+                timeout=1.0,
+            )
+            if message is None:
+                continue
+
+            room_id = message["channel"].decode("utf-8")
+            data = json.loads(message["data"].decode("utf-8"))
+            # Iterate over a copy: sockets can join or leave during the awaits.
+            for socket in list(self.rooms.get(room_id, ())):
+                try:
+                    await socket.send_json(data)
+                except Exception:
+                    # A broken client must not stop delivery to the others.
+                    logger.warning(
+                        "Dropping websocket from room %s", room_id, exc_info=True
+                    )
+                    await self.remove_user_from_room(room_id, socket)
+
+
+def get_pubsub_client() -> RedisPubSubManager:
+    """
+    Returns the RedisPubSubManager instance for managing Redis pubsub.
+    """
+    from reflector.settings import settings
+
+    return RedisPubSubManager(
+        host=settings.REDIS_HOST,
+        port=settings.REDIS_PORT,
+    )
+
+
+def get_ws_manager() -> WebsocketManager:
+    """
+    Returns the WebsocketManager instance for managing websockets.
+
+    The instance is created on first call and cached in the module-level
+    ``ws_manager``, so every caller shares the same rooms. No Redis
+    connection is opened here; that happens on first subscribe or publish.
+
+    Returns:
+        WebsocketManager: The shared WebsocketManager instance.
+
+    Raises:
+        ImportError: If the 'reflector.settings' module cannot be imported.
+    """
+    global ws_manager
+    if ws_manager is None:
+        ws_manager = WebsocketManager(pubsub_client=get_pubsub_client())
+    return ws_manager