server: use redis pubsub for interprocess websocket communication

This commit is contained in:
2023-10-25 16:56:43 +02:00
committed by Mathieu Virbel
parent 8bebb2a769
commit 00c06b7971
6 changed files with 183 additions and 52 deletions

View File

@@ -1,15 +1,19 @@
version: "3.9"
services:
server:
build:
context: .
# server:
# build:
# context: .
# ports:
# - 1250:1250
# environment:
# LLM_URL: "${LLM_URL}"
# MIN_TRANSCRIPT_LENGTH: "${MIN_TRANSCRIPT_LENGTH}"
# volumes:
# - model-cache:/root/.cache
redis:
image: redis:7.2
ports:
- 1250:1250
environment:
LLM_URL: "${LLM_URL}"
MIN_TRANSCRIPT_LENGTH: "${MIN_TRANSCRIPT_LENGTH}"
volumes:
- model-cache:/root/.cache
- 6379:6379
volumes:
model-cache:

20
server/poetry.lock generated
View File

@@ -2919,6 +2919,24 @@ files = [
[package.extras]
full = ["numpy"]
[[package]]
name = "redis"
version = "5.0.1"
description = "Python client for Redis database and key-value store"
optional = false
python-versions = ">=3.7"
files = [
{file = "redis-5.0.1-py3-none-any.whl", hash = "sha256:ed4802971884ae19d640775ba3b03aa2e7bd5e8fb8dfaed2decce4d0fc48391f"},
{file = "redis-5.0.1.tar.gz", hash = "sha256:0dab495cd5753069d3bc650a0dde8a8f9edde16fc5691b689a566eda58100d0f"},
]
[package.dependencies]
async-timeout = {version = ">=4.0.2", markers = "python_full_version <= \"3.11.2\""}
[package.extras]
hiredis = ["hiredis (>=1.0.0)"]
ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"]
[[package]]
name = "regex"
version = "2023.10.3"
@@ -4046,4 +4064,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
content-hash = "fda9f13784a64add559abb2266d60eeef8f28d2b5f369633630f4fed14daa99c"
content-hash = "6d2e8a8e0d5d928481f9a33210d44863a1921e18147fa57dc6889d877697aa63"

View File

@@ -34,6 +34,7 @@ sentencepiece = "^0.1.99"
protobuf = "^4.24.3"
profanityfilter = "^2.0.6"
celery = "^5.3.4"
redis = "^5.0.1"
[tool.poetry.group.dev.dependencies]

View File

@@ -117,5 +117,9 @@ class Settings(BaseSettings):
CELERY_BROKER_URL: str = "redis://localhost:6379/1"
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
# Redis
REDIS_HOST: str = "localhost"
REDIS_PORT: int = 6379
settings = Settings()

View File

@@ -21,12 +21,14 @@ from reflector.processors.types import Transcript as ProcessorTranscript
from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings
from reflector.utils.audio_waveform import get_audio_waveform
from reflector.ws_manager import get_ws_manager
from starlette.concurrency import run_in_threadpool
from ._range_requests_response import range_requests_response
from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
router = APIRouter()
ws_manager = get_ws_manager()
# ==============================================================
# Models to move to a database, but required for the API to work
@@ -487,40 +489,10 @@ async def transcript_get_websocket_events(transcript_id: str):
# ==============================================================
# Websocket Manager
# Websocket
# ==============================================================
class WebsocketManager:
def __init__(self):
self.active_connections = {}
async def connect(self, transcript_id: str, websocket: WebSocket):
await websocket.accept()
if transcript_id not in self.active_connections:
self.active_connections[transcript_id] = []
self.active_connections[transcript_id].append(websocket)
def disconnect(self, transcript_id: str, websocket: WebSocket):
if transcript_id not in self.active_connections:
return
self.active_connections[transcript_id].remove(websocket)
if not self.active_connections[transcript_id]:
del self.active_connections[transcript_id]
async def send_json(self, transcript_id: str, message):
if transcript_id not in self.active_connections:
return
for connection in self.active_connections[transcript_id][:]:
try:
await connection.send_json(message)
except Exception:
self.active_connections[transcript_id].remove(connection)
ws_manager = WebsocketManager()
@router.websocket("/transcripts/{transcript_id}/events")
async def transcript_events_websocket(
transcript_id: str,
@@ -532,21 +504,25 @@ async def transcript_events_websocket(
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
await ws_manager.connect(transcript_id, websocket)
# connect to websocket manager
# use ts:transcript_id as room id
room_id = f"ts:{transcript_id}"
await ws_manager.add_user_to_room(room_id, websocket)
# on first connection, send all events
for event in transcript.events:
await websocket.send_json(event.model_dump(mode="json"))
# XXX if transcript is final (locked=True and status=ended)
# XXX send a final event to the client and close the connection
# endless loop to wait for new events
try:
# on first connection, send all events only to the current user
for event in transcript.events:
await websocket.send_json(event.model_dump(mode="json"))
# XXX if transcript is final (locked=True and status=ended)
# XXX send a final event to the client and close the connection
# endless loop to wait for new events
# we do not have command system now,
while True:
await websocket.receive()
except (RuntimeError, WebSocketDisconnect):
ws_manager.disconnect(transcript_id, websocket)
await ws_manager.remove_user_from_room(room_id, websocket)
# ==============================================================
@@ -658,7 +634,8 @@ async def handle_rtc_event_once(event: PipelineEvent, args, data):
return
# transmit to websocket clients
await ws_manager.send_json(transcript_id, resp.model_dump(mode="json"))
room_id = f"ts:{transcript_id}"
await ws_manager.send_json(room_id, resp.model_dump(mode="json"))
@router.post("/transcripts/{transcript_id}/record/webrtc")

View File

@@ -0,0 +1,127 @@
"""
Websocket manager
=================
This module contains the WebsocketManager class, which is responsible for
managing websockets and handling websocket connections.
It uses the RedisPubSubManager class to subscribe to Redis channels and
broadcast messages to all connected websockets.
"""
import asyncio
import json
import redis.asyncio as redis
from fastapi import WebSocket
ws_manager = None
class RedisPubSubManager:
def __init__(self, host="localhost", port=6379):
self.redis_host = host
self.redis_port = port
self.redis_connection = None
self.pubsub = None
async def get_redis_connection(self) -> redis.Redis:
return redis.Redis(
host=self.redis_host,
port=self.redis_port,
auto_close_connection_pool=False,
)
async def connect(self) -> None:
self.redis_connection = await self.get_redis_connection()
self.pubsub = self.redis_connection.pubsub()
async def disconnect(self) -> None:
if self.redis_connection is None:
return
await self.redis_connection.close()
self.redis_connection = None
async def send_json(self, room_id: str, message: str) -> None:
message = json.dumps(message)
await self.redis_connection.publish(room_id, message)
async def subscribe(self, room_id: str) -> redis.Redis:
await self.pubsub.subscribe(room_id)
return self.pubsub
async def unsubscribe(self, room_id: str) -> None:
await self.pubsub.unsubscribe(room_id)
class WebsocketManager:
def __init__(self, pubsub_client: RedisPubSubManager = None):
self.rooms: dict = {}
self.pubsub_client = pubsub_client
async def add_user_to_room(self, room_id: str, websocket: WebSocket) -> None:
await websocket.accept()
if room_id in self.rooms:
self.rooms[room_id].append(websocket)
else:
self.rooms[room_id] = [websocket]
await self.pubsub_client.connect()
pubsub_subscriber = await self.pubsub_client.subscribe(room_id)
asyncio.create_task(self._pubsub_data_reader(pubsub_subscriber))
async def send_json(self, room_id: str, message: dict) -> None:
await self.pubsub_client.send_json(room_id, message)
async def remove_user_from_room(self, room_id: str, websocket: WebSocket) -> None:
self.rooms[room_id].remove(websocket)
if len(self.rooms[room_id]) == 0:
del self.rooms[room_id]
await self.pubsub_client.unsubscribe(room_id)
async def _pubsub_data_reader(self, pubsub_subscriber):
while True:
message = await pubsub_subscriber.get_message(
ignore_subscribe_messages=True
)
if message is not None:
room_id = message["channel"].decode("utf-8")
all_sockets = self.rooms[room_id]
for socket in all_sockets:
data = json.loads(message["data"].decode("utf-8"))
await socket.send_json(data)
def get_pubsub_client() -> RedisPubSubManager:
"""
Returns the RedisPubSubManager instance for managing Redis pubsub.
"""
from reflector.settings import settings
return RedisPubSubManager(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
)
def get_ws_manager() -> WebsocketManager:
"""
Returns the WebsocketManager instance for managing websockets.
This function initializes and returns the WebsocketManager instance,
which is responsible for managing websockets and handling websocket
connections.
Returns:
WebsocketManager: The initialized WebsocketManager instance.
Raises:
ImportError: If the 'reflector.settings' module cannot be imported.
RedisConnectionError: If there is an error connecting to the Redis server.
"""
global ws_manager
pubsub_client = get_pubsub_client()
ws_manager = WebsocketManager(pubsub_client=pubsub_client)
return ws_manager