mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
server: use redis pubsub for interprocess websocket communication
This commit is contained in:
@@ -1,15 +1,19 @@
|
|||||||
version: "3.9"
|
version: "3.9"
|
||||||
services:
|
services:
|
||||||
server:
|
# server:
|
||||||
build:
|
# build:
|
||||||
context: .
|
# context: .
|
||||||
|
# ports:
|
||||||
|
# - 1250:1250
|
||||||
|
# environment:
|
||||||
|
# LLM_URL: "${LLM_URL}"
|
||||||
|
# MIN_TRANSCRIPT_LENGTH: "${MIN_TRANSCRIPT_LENGTH}"
|
||||||
|
# volumes:
|
||||||
|
# - model-cache:/root/.cache
|
||||||
|
redis:
|
||||||
|
image: redis:7.2
|
||||||
ports:
|
ports:
|
||||||
- 1250:1250
|
- 6379:6379
|
||||||
environment:
|
|
||||||
LLM_URL: "${LLM_URL}"
|
|
||||||
MIN_TRANSCRIPT_LENGTH: "${MIN_TRANSCRIPT_LENGTH}"
|
|
||||||
volumes:
|
|
||||||
- model-cache:/root/.cache
|
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
model-cache:
|
model-cache:
|
||||||
|
|||||||
20
server/poetry.lock
generated
20
server/poetry.lock
generated
@@ -2919,6 +2919,24 @@ files = [
|
|||||||
[package.extras]
|
[package.extras]
|
||||||
full = ["numpy"]
|
full = ["numpy"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "redis"
|
||||||
|
version = "5.0.1"
|
||||||
|
description = "Python client for Redis database and key-value store"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
files = [
|
||||||
|
{file = "redis-5.0.1-py3-none-any.whl", hash = "sha256:ed4802971884ae19d640775ba3b03aa2e7bd5e8fb8dfaed2decce4d0fc48391f"},
|
||||||
|
{file = "redis-5.0.1.tar.gz", hash = "sha256:0dab495cd5753069d3bc650a0dde8a8f9edde16fc5691b689a566eda58100d0f"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
async-timeout = {version = ">=4.0.2", markers = "python_full_version <= \"3.11.2\""}
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
hiredis = ["hiredis (>=1.0.0)"]
|
||||||
|
ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "regex"
|
name = "regex"
|
||||||
version = "2023.10.3"
|
version = "2023.10.3"
|
||||||
@@ -4046,4 +4064,4 @@ multidict = ">=4.0"
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.11"
|
python-versions = "^3.11"
|
||||||
content-hash = "fda9f13784a64add559abb2266d60eeef8f28d2b5f369633630f4fed14daa99c"
|
content-hash = "6d2e8a8e0d5d928481f9a33210d44863a1921e18147fa57dc6889d877697aa63"
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ sentencepiece = "^0.1.99"
|
|||||||
protobuf = "^4.24.3"
|
protobuf = "^4.24.3"
|
||||||
profanityfilter = "^2.0.6"
|
profanityfilter = "^2.0.6"
|
||||||
celery = "^5.3.4"
|
celery = "^5.3.4"
|
||||||
|
redis = "^5.0.1"
|
||||||
|
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
|
|||||||
@@ -117,5 +117,9 @@ class Settings(BaseSettings):
|
|||||||
CELERY_BROKER_URL: str = "redis://localhost:6379/1"
|
CELERY_BROKER_URL: str = "redis://localhost:6379/1"
|
||||||
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
|
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
|
||||||
|
|
||||||
|
# Redis
|
||||||
|
REDIS_HOST: str = "localhost"
|
||||||
|
REDIS_PORT: int = 6379
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|||||||
@@ -21,12 +21,14 @@ from reflector.processors.types import Transcript as ProcessorTranscript
|
|||||||
from reflector.processors.types import Word as ProcessorWord
|
from reflector.processors.types import Word as ProcessorWord
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
from reflector.utils.audio_waveform import get_audio_waveform
|
from reflector.utils.audio_waveform import get_audio_waveform
|
||||||
|
from reflector.ws_manager import get_ws_manager
|
||||||
from starlette.concurrency import run_in_threadpool
|
from starlette.concurrency import run_in_threadpool
|
||||||
|
|
||||||
from ._range_requests_response import range_requests_response
|
from ._range_requests_response import range_requests_response
|
||||||
from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
|
from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
ws_manager = get_ws_manager()
|
||||||
|
|
||||||
# ==============================================================
|
# ==============================================================
|
||||||
# Models to move to a database, but required for the API to work
|
# Models to move to a database, but required for the API to work
|
||||||
@@ -487,40 +489,10 @@ async def transcript_get_websocket_events(transcript_id: str):
|
|||||||
|
|
||||||
|
|
||||||
# ==============================================================
|
# ==============================================================
|
||||||
# Websocket Manager
|
# Websocket
|
||||||
# ==============================================================
|
# ==============================================================
|
||||||
|
|
||||||
|
|
||||||
class WebsocketManager:
|
|
||||||
def __init__(self):
|
|
||||||
self.active_connections = {}
|
|
||||||
|
|
||||||
async def connect(self, transcript_id: str, websocket: WebSocket):
|
|
||||||
await websocket.accept()
|
|
||||||
if transcript_id not in self.active_connections:
|
|
||||||
self.active_connections[transcript_id] = []
|
|
||||||
self.active_connections[transcript_id].append(websocket)
|
|
||||||
|
|
||||||
def disconnect(self, transcript_id: str, websocket: WebSocket):
|
|
||||||
if transcript_id not in self.active_connections:
|
|
||||||
return
|
|
||||||
self.active_connections[transcript_id].remove(websocket)
|
|
||||||
if not self.active_connections[transcript_id]:
|
|
||||||
del self.active_connections[transcript_id]
|
|
||||||
|
|
||||||
async def send_json(self, transcript_id: str, message):
|
|
||||||
if transcript_id not in self.active_connections:
|
|
||||||
return
|
|
||||||
for connection in self.active_connections[transcript_id][:]:
|
|
||||||
try:
|
|
||||||
await connection.send_json(message)
|
|
||||||
except Exception:
|
|
||||||
self.active_connections[transcript_id].remove(connection)
|
|
||||||
|
|
||||||
|
|
||||||
ws_manager = WebsocketManager()
|
|
||||||
|
|
||||||
|
|
||||||
@router.websocket("/transcripts/{transcript_id}/events")
|
@router.websocket("/transcripts/{transcript_id}/events")
|
||||||
async def transcript_events_websocket(
|
async def transcript_events_websocket(
|
||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
@@ -532,9 +504,13 @@ async def transcript_events_websocket(
|
|||||||
if not transcript:
|
if not transcript:
|
||||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||||
|
|
||||||
await ws_manager.connect(transcript_id, websocket)
|
# connect to websocket manager
|
||||||
|
# use ts:transcript_id as room id
|
||||||
|
room_id = f"ts:{transcript_id}"
|
||||||
|
await ws_manager.add_user_to_room(room_id, websocket)
|
||||||
|
|
||||||
# on first connection, send all events
|
try:
|
||||||
|
# on first connection, send all events only to the current user
|
||||||
for event in transcript.events:
|
for event in transcript.events:
|
||||||
await websocket.send_json(event.model_dump(mode="json"))
|
await websocket.send_json(event.model_dump(mode="json"))
|
||||||
|
|
||||||
@@ -542,11 +518,11 @@ async def transcript_events_websocket(
|
|||||||
# XXX send a final event to the client and close the connection
|
# XXX send a final event to the client and close the connection
|
||||||
|
|
||||||
# endless loop to wait for new events
|
# endless loop to wait for new events
|
||||||
try:
|
# we do not have command system now,
|
||||||
while True:
|
while True:
|
||||||
await websocket.receive()
|
await websocket.receive()
|
||||||
except (RuntimeError, WebSocketDisconnect):
|
except (RuntimeError, WebSocketDisconnect):
|
||||||
ws_manager.disconnect(transcript_id, websocket)
|
await ws_manager.remove_user_from_room(room_id, websocket)
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================
|
# ==============================================================
|
||||||
@@ -658,7 +634,8 @@ async def handle_rtc_event_once(event: PipelineEvent, args, data):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# transmit to websocket clients
|
# transmit to websocket clients
|
||||||
await ws_manager.send_json(transcript_id, resp.model_dump(mode="json"))
|
room_id = f"ts:{transcript_id}"
|
||||||
|
await ws_manager.send_json(room_id, resp.model_dump(mode="json"))
|
||||||
|
|
||||||
|
|
||||||
@router.post("/transcripts/{transcript_id}/record/webrtc")
|
@router.post("/transcripts/{transcript_id}/record/webrtc")
|
||||||
|
|||||||
127
server/reflector/ws_manager.py
Normal file
127
server/reflector/ws_manager.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
"""
|
||||||
|
Websocket manager
|
||||||
|
=================
|
||||||
|
|
||||||
|
This module contains the WebsocketManager class, which is responsible for
|
||||||
|
managing websockets and handling websocket connections.
|
||||||
|
|
||||||
|
It uses the RedisPubSubManager class to subscribe to Redis channels and
|
||||||
|
broadcast messages to all connected websockets.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
|
||||||
|
import redis.asyncio as redis
|
||||||
|
from fastapi import WebSocket
|
||||||
|
|
||||||
|
ws_manager = None
|
||||||
|
|
||||||
|
|
||||||
|
class RedisPubSubManager:
|
||||||
|
def __init__(self, host="localhost", port=6379):
|
||||||
|
self.redis_host = host
|
||||||
|
self.redis_port = port
|
||||||
|
self.redis_connection = None
|
||||||
|
self.pubsub = None
|
||||||
|
|
||||||
|
async def get_redis_connection(self) -> redis.Redis:
|
||||||
|
return redis.Redis(
|
||||||
|
host=self.redis_host,
|
||||||
|
port=self.redis_port,
|
||||||
|
auto_close_connection_pool=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def connect(self) -> None:
|
||||||
|
self.redis_connection = await self.get_redis_connection()
|
||||||
|
self.pubsub = self.redis_connection.pubsub()
|
||||||
|
|
||||||
|
async def disconnect(self) -> None:
|
||||||
|
if self.redis_connection is None:
|
||||||
|
return
|
||||||
|
await self.redis_connection.close()
|
||||||
|
self.redis_connection = None
|
||||||
|
|
||||||
|
async def send_json(self, room_id: str, message: str) -> None:
|
||||||
|
message = json.dumps(message)
|
||||||
|
await self.redis_connection.publish(room_id, message)
|
||||||
|
|
||||||
|
async def subscribe(self, room_id: str) -> redis.Redis:
|
||||||
|
await self.pubsub.subscribe(room_id)
|
||||||
|
return self.pubsub
|
||||||
|
|
||||||
|
async def unsubscribe(self, room_id: str) -> None:
|
||||||
|
await self.pubsub.unsubscribe(room_id)
|
||||||
|
|
||||||
|
|
||||||
|
class WebsocketManager:
|
||||||
|
def __init__(self, pubsub_client: RedisPubSubManager = None):
|
||||||
|
self.rooms: dict = {}
|
||||||
|
self.pubsub_client = pubsub_client
|
||||||
|
|
||||||
|
async def add_user_to_room(self, room_id: str, websocket: WebSocket) -> None:
|
||||||
|
await websocket.accept()
|
||||||
|
|
||||||
|
if room_id in self.rooms:
|
||||||
|
self.rooms[room_id].append(websocket)
|
||||||
|
else:
|
||||||
|
self.rooms[room_id] = [websocket]
|
||||||
|
|
||||||
|
await self.pubsub_client.connect()
|
||||||
|
pubsub_subscriber = await self.pubsub_client.subscribe(room_id)
|
||||||
|
asyncio.create_task(self._pubsub_data_reader(pubsub_subscriber))
|
||||||
|
|
||||||
|
async def send_json(self, room_id: str, message: dict) -> None:
|
||||||
|
await self.pubsub_client.send_json(room_id, message)
|
||||||
|
|
||||||
|
async def remove_user_from_room(self, room_id: str, websocket: WebSocket) -> None:
|
||||||
|
self.rooms[room_id].remove(websocket)
|
||||||
|
|
||||||
|
if len(self.rooms[room_id]) == 0:
|
||||||
|
del self.rooms[room_id]
|
||||||
|
await self.pubsub_client.unsubscribe(room_id)
|
||||||
|
|
||||||
|
async def _pubsub_data_reader(self, pubsub_subscriber):
|
||||||
|
while True:
|
||||||
|
message = await pubsub_subscriber.get_message(
|
||||||
|
ignore_subscribe_messages=True
|
||||||
|
)
|
||||||
|
if message is not None:
|
||||||
|
room_id = message["channel"].decode("utf-8")
|
||||||
|
all_sockets = self.rooms[room_id]
|
||||||
|
for socket in all_sockets:
|
||||||
|
data = json.loads(message["data"].decode("utf-8"))
|
||||||
|
await socket.send_json(data)
|
||||||
|
|
||||||
|
|
||||||
|
def get_pubsub_client() -> RedisPubSubManager:
|
||||||
|
"""
|
||||||
|
Returns the RedisPubSubManager instance for managing Redis pubsub.
|
||||||
|
"""
|
||||||
|
from reflector.settings import settings
|
||||||
|
|
||||||
|
return RedisPubSubManager(
|
||||||
|
host=settings.REDIS_HOST,
|
||||||
|
port=settings.REDIS_PORT,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_ws_manager() -> WebsocketManager:
|
||||||
|
"""
|
||||||
|
Returns the WebsocketManager instance for managing websockets.
|
||||||
|
|
||||||
|
This function initializes and returns the WebsocketManager instance,
|
||||||
|
which is responsible for managing websockets and handling websocket
|
||||||
|
connections.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
WebsocketManager: The initialized WebsocketManager instance.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImportError: If the 'reflector.settings' module cannot be imported.
|
||||||
|
RedisConnectionError: If there is an error connecting to the Redis server.
|
||||||
|
"""
|
||||||
|
global ws_manager
|
||||||
|
pubsub_client = get_pubsub_client()
|
||||||
|
ws_manager = WebsocketManager(pubsub_client=pubsub_client)
|
||||||
|
return ws_manager
|
||||||
Reference in New Issue
Block a user