mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
server: use redis pubsub for interprocess websocket communication
This commit is contained in:
@@ -1,15 +1,19 @@
|
||||
version: "3.9"
|
||||
services:
|
||||
server:
|
||||
build:
|
||||
context: .
|
||||
# server:
|
||||
# build:
|
||||
# context: .
|
||||
# ports:
|
||||
# - 1250:1250
|
||||
# environment:
|
||||
# LLM_URL: "${LLM_URL}"
|
||||
# MIN_TRANSCRIPT_LENGTH: "${MIN_TRANSCRIPT_LENGTH}"
|
||||
# volumes:
|
||||
# - model-cache:/root/.cache
|
||||
redis:
|
||||
image: redis:7.2
|
||||
ports:
|
||||
- 1250:1250
|
||||
environment:
|
||||
LLM_URL: "${LLM_URL}"
|
||||
MIN_TRANSCRIPT_LENGTH: "${MIN_TRANSCRIPT_LENGTH}"
|
||||
volumes:
|
||||
- model-cache:/root/.cache
|
||||
- 6379:6379
|
||||
|
||||
volumes:
|
||||
model-cache:
|
||||
|
||||
20
server/poetry.lock
generated
20
server/poetry.lock
generated
@@ -2919,6 +2919,24 @@ files = [
|
||||
[package.extras]
|
||||
full = ["numpy"]
|
||||
|
||||
[[package]]
|
||||
name = "redis"
|
||||
version = "5.0.1"
|
||||
description = "Python client for Redis database and key-value store"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "redis-5.0.1-py3-none-any.whl", hash = "sha256:ed4802971884ae19d640775ba3b03aa2e7bd5e8fb8dfaed2decce4d0fc48391f"},
|
||||
{file = "redis-5.0.1.tar.gz", hash = "sha256:0dab495cd5753069d3bc650a0dde8a8f9edde16fc5691b689a566eda58100d0f"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
async-timeout = {version = ">=4.0.2", markers = "python_full_version <= \"3.11.2\""}
|
||||
|
||||
[package.extras]
|
||||
hiredis = ["hiredis (>=1.0.0)"]
|
||||
ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "2023.10.3"
|
||||
@@ -4046,4 +4064,4 @@ multidict = ">=4.0"
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.11"
|
||||
content-hash = "fda9f13784a64add559abb2266d60eeef8f28d2b5f369633630f4fed14daa99c"
|
||||
content-hash = "6d2e8a8e0d5d928481f9a33210d44863a1921e18147fa57dc6889d877697aa63"
|
||||
|
||||
@@ -34,6 +34,7 @@ sentencepiece = "^0.1.99"
|
||||
protobuf = "^4.24.3"
|
||||
profanityfilter = "^2.0.6"
|
||||
celery = "^5.3.4"
|
||||
redis = "^5.0.1"
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
|
||||
@@ -117,5 +117,9 @@ class Settings(BaseSettings):
|
||||
CELERY_BROKER_URL: str = "redis://localhost:6379/1"
|
||||
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
|
||||
|
||||
# Redis
|
||||
REDIS_HOST: str = "localhost"
|
||||
REDIS_PORT: int = 6379
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
@@ -21,12 +21,14 @@ from reflector.processors.types import Transcript as ProcessorTranscript
|
||||
from reflector.processors.types import Word as ProcessorWord
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.audio_waveform import get_audio_waveform
|
||||
from reflector.ws_manager import get_ws_manager
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
|
||||
from ._range_requests_response import range_requests_response
|
||||
from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
|
||||
|
||||
router = APIRouter()
|
||||
ws_manager = get_ws_manager()
|
||||
|
||||
# ==============================================================
|
||||
# Models to move to a database, but required for the API to work
|
||||
@@ -487,40 +489,10 @@ async def transcript_get_websocket_events(transcript_id: str):
|
||||
|
||||
|
||||
# ==============================================================
|
||||
# Websocket Manager
|
||||
# Websocket
|
||||
# ==============================================================
|
||||
|
||||
|
||||
class WebsocketManager:
|
||||
def __init__(self):
|
||||
self.active_connections = {}
|
||||
|
||||
async def connect(self, transcript_id: str, websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
if transcript_id not in self.active_connections:
|
||||
self.active_connections[transcript_id] = []
|
||||
self.active_connections[transcript_id].append(websocket)
|
||||
|
||||
def disconnect(self, transcript_id: str, websocket: WebSocket):
|
||||
if transcript_id not in self.active_connections:
|
||||
return
|
||||
self.active_connections[transcript_id].remove(websocket)
|
||||
if not self.active_connections[transcript_id]:
|
||||
del self.active_connections[transcript_id]
|
||||
|
||||
async def send_json(self, transcript_id: str, message):
|
||||
if transcript_id not in self.active_connections:
|
||||
return
|
||||
for connection in self.active_connections[transcript_id][:]:
|
||||
try:
|
||||
await connection.send_json(message)
|
||||
except Exception:
|
||||
self.active_connections[transcript_id].remove(connection)
|
||||
|
||||
|
||||
ws_manager = WebsocketManager()
|
||||
|
||||
|
||||
@router.websocket("/transcripts/{transcript_id}/events")
|
||||
async def transcript_events_websocket(
|
||||
transcript_id: str,
|
||||
@@ -532,21 +504,25 @@ async def transcript_events_websocket(
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
|
||||
await ws_manager.connect(transcript_id, websocket)
|
||||
# connect to websocket manager
|
||||
# use ts:transcript_id as room id
|
||||
room_id = f"ts:{transcript_id}"
|
||||
await ws_manager.add_user_to_room(room_id, websocket)
|
||||
|
||||
# on first connection, send all events
|
||||
for event in transcript.events:
|
||||
await websocket.send_json(event.model_dump(mode="json"))
|
||||
|
||||
# XXX if transcript is final (locked=True and status=ended)
|
||||
# XXX send a final event to the client and close the connection
|
||||
|
||||
# endless loop to wait for new events
|
||||
try:
|
||||
# on first connection, send all events only to the current user
|
||||
for event in transcript.events:
|
||||
await websocket.send_json(event.model_dump(mode="json"))
|
||||
|
||||
# XXX if transcript is final (locked=True and status=ended)
|
||||
# XXX send a final event to the client and close the connection
|
||||
|
||||
# endless loop to wait for new events
|
||||
# we do not have command system now,
|
||||
while True:
|
||||
await websocket.receive()
|
||||
except (RuntimeError, WebSocketDisconnect):
|
||||
ws_manager.disconnect(transcript_id, websocket)
|
||||
await ws_manager.remove_user_from_room(room_id, websocket)
|
||||
|
||||
|
||||
# ==============================================================
|
||||
@@ -658,7 +634,8 @@ async def handle_rtc_event_once(event: PipelineEvent, args, data):
|
||||
return
|
||||
|
||||
# transmit to websocket clients
|
||||
await ws_manager.send_json(transcript_id, resp.model_dump(mode="json"))
|
||||
room_id = f"ts:{transcript_id}"
|
||||
await ws_manager.send_json(room_id, resp.model_dump(mode="json"))
|
||||
|
||||
|
||||
@router.post("/transcripts/{transcript_id}/record/webrtc")
|
||||
|
||||
127
server/reflector/ws_manager.py
Normal file
127
server/reflector/ws_manager.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
Websocket manager
|
||||
=================
|
||||
|
||||
This module contains the WebsocketManager class, which is responsible for
|
||||
managing websockets and handling websocket connections.
|
||||
|
||||
It uses the RedisPubSubManager class to subscribe to Redis channels and
|
||||
broadcast messages to all connected websockets.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import redis.asyncio as redis
|
||||
from fastapi import WebSocket
|
||||
|
||||
ws_manager = None
|
||||
|
||||
|
||||
class RedisPubSubManager:
|
||||
def __init__(self, host="localhost", port=6379):
|
||||
self.redis_host = host
|
||||
self.redis_port = port
|
||||
self.redis_connection = None
|
||||
self.pubsub = None
|
||||
|
||||
async def get_redis_connection(self) -> redis.Redis:
|
||||
return redis.Redis(
|
||||
host=self.redis_host,
|
||||
port=self.redis_port,
|
||||
auto_close_connection_pool=False,
|
||||
)
|
||||
|
||||
async def connect(self) -> None:
|
||||
self.redis_connection = await self.get_redis_connection()
|
||||
self.pubsub = self.redis_connection.pubsub()
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
if self.redis_connection is None:
|
||||
return
|
||||
await self.redis_connection.close()
|
||||
self.redis_connection = None
|
||||
|
||||
async def send_json(self, room_id: str, message: str) -> None:
|
||||
message = json.dumps(message)
|
||||
await self.redis_connection.publish(room_id, message)
|
||||
|
||||
async def subscribe(self, room_id: str) -> redis.Redis:
|
||||
await self.pubsub.subscribe(room_id)
|
||||
return self.pubsub
|
||||
|
||||
async def unsubscribe(self, room_id: str) -> None:
|
||||
await self.pubsub.unsubscribe(room_id)
|
||||
|
||||
|
||||
class WebsocketManager:
|
||||
def __init__(self, pubsub_client: RedisPubSubManager = None):
|
||||
self.rooms: dict = {}
|
||||
self.pubsub_client = pubsub_client
|
||||
|
||||
async def add_user_to_room(self, room_id: str, websocket: WebSocket) -> None:
|
||||
await websocket.accept()
|
||||
|
||||
if room_id in self.rooms:
|
||||
self.rooms[room_id].append(websocket)
|
||||
else:
|
||||
self.rooms[room_id] = [websocket]
|
||||
|
||||
await self.pubsub_client.connect()
|
||||
pubsub_subscriber = await self.pubsub_client.subscribe(room_id)
|
||||
asyncio.create_task(self._pubsub_data_reader(pubsub_subscriber))
|
||||
|
||||
async def send_json(self, room_id: str, message: dict) -> None:
|
||||
await self.pubsub_client.send_json(room_id, message)
|
||||
|
||||
async def remove_user_from_room(self, room_id: str, websocket: WebSocket) -> None:
|
||||
self.rooms[room_id].remove(websocket)
|
||||
|
||||
if len(self.rooms[room_id]) == 0:
|
||||
del self.rooms[room_id]
|
||||
await self.pubsub_client.unsubscribe(room_id)
|
||||
|
||||
async def _pubsub_data_reader(self, pubsub_subscriber):
|
||||
while True:
|
||||
message = await pubsub_subscriber.get_message(
|
||||
ignore_subscribe_messages=True
|
||||
)
|
||||
if message is not None:
|
||||
room_id = message["channel"].decode("utf-8")
|
||||
all_sockets = self.rooms[room_id]
|
||||
for socket in all_sockets:
|
||||
data = json.loads(message["data"].decode("utf-8"))
|
||||
await socket.send_json(data)
|
||||
|
||||
|
||||
def get_pubsub_client() -> RedisPubSubManager:
|
||||
"""
|
||||
Returns the RedisPubSubManager instance for managing Redis pubsub.
|
||||
"""
|
||||
from reflector.settings import settings
|
||||
|
||||
return RedisPubSubManager(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
)
|
||||
|
||||
|
||||
def get_ws_manager() -> WebsocketManager:
|
||||
"""
|
||||
Returns the WebsocketManager instance for managing websockets.
|
||||
|
||||
This function initializes and returns the WebsocketManager instance,
|
||||
which is responsible for managing websockets and handling websocket
|
||||
connections.
|
||||
|
||||
Returns:
|
||||
WebsocketManager: The initialized WebsocketManager instance.
|
||||
|
||||
Raises:
|
||||
ImportError: If the 'reflector.settings' module cannot be imported.
|
||||
RedisConnectionError: If there is an error connecting to the Redis server.
|
||||
"""
|
||||
global ws_manager
|
||||
pubsub_client = get_pubsub_client()
|
||||
ws_manager = WebsocketManager(pubsub_client=pubsub_client)
|
||||
return ws_manager
|
||||
Reference in New Issue
Block a user