mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
fix: update transcript list on reprocess (#676)
* Update transcript list on reprocess * Fix transcript create * Fix multiple sockets issue * Pass token in sec websocket protocol * userEvent parse example * transcript list invalidation non-abstraction * Emit only relevant events to the user room * Add ws close code const * Refactor user websocket endpoint * Refactor user events provider --------- Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
This commit is contained in:
156
server/tests/test_user_websocket_auth.py
Normal file
156
server/tests/test_user_websocket_auth.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from httpx_ws import aconnect_ws
|
||||
from uvicorn import Config, Server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def appserver_ws_user(setup_database):
|
||||
from reflector.app import app
|
||||
from reflector.db import get_database
|
||||
|
||||
host = "127.0.0.1"
|
||||
port = 1257
|
||||
server_started = threading.Event()
|
||||
server_exception = None
|
||||
server_instance = None
|
||||
|
||||
def run_server():
|
||||
nonlocal server_exception, server_instance
|
||||
try:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
config = Config(app=app, host=host, port=port, loop=loop)
|
||||
server_instance = Server(config)
|
||||
|
||||
async def start_server():
|
||||
database = get_database()
|
||||
await database.connect()
|
||||
try:
|
||||
await server_instance.serve()
|
||||
finally:
|
||||
await database.disconnect()
|
||||
|
||||
server_started.set()
|
||||
loop.run_until_complete(start_server())
|
||||
except Exception as e:
|
||||
server_exception = e
|
||||
server_started.set()
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
server_thread = threading.Thread(target=run_server, daemon=True)
|
||||
server_thread.start()
|
||||
|
||||
server_started.wait(timeout=30)
|
||||
if server_exception:
|
||||
raise server_exception
|
||||
|
||||
time.sleep(0.5)
|
||||
|
||||
yield host, port
|
||||
|
||||
if server_instance:
|
||||
server_instance.should_exit = True
|
||||
server_thread.join(timeout=30)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_jwt_verification(monkeypatch):
|
||||
"""Patch JWT verification to accept HS256 tokens signed with SECRET_KEY for tests."""
|
||||
from jose import jwt
|
||||
|
||||
from reflector.settings import settings
|
||||
|
||||
def _verify_token(self, token: str):
|
||||
# Do not validate audience in tests
|
||||
return jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"]) # type: ignore[arg-type]
|
||||
|
||||
monkeypatch.setattr(
|
||||
"reflector.auth.auth_jwt.JWTAuth.verify_token", _verify_token, raising=True
|
||||
)
|
||||
|
||||
|
||||
def _make_dummy_jwt(sub: str = "user123") -> str:
|
||||
# Create a short HS256 JWT using the app secret to pass verification in tests
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from jose import jwt
|
||||
|
||||
from reflector.settings import settings
|
||||
|
||||
payload = {
|
||||
"sub": sub,
|
||||
"email": f"{sub}@example.com",
|
||||
"exp": datetime.now(timezone.utc) + timedelta(minutes=5),
|
||||
}
|
||||
# Note: production uses RS256 public key verification; tests can sign with SECRET_KEY
|
||||
return jwt.encode(payload, settings.SECRET_KEY, algorithm="HS256")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_ws_rejects_missing_subprotocol(appserver_ws_user):
|
||||
host, port = appserver_ws_user
|
||||
base_ws = f"http://{host}:{port}/v1/events"
|
||||
# No subprotocol/header with token
|
||||
with pytest.raises(Exception):
|
||||
async with aconnect_ws(base_ws) as ws: # type: ignore
|
||||
# Should close during handshake; if not, close explicitly
|
||||
await ws.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_ws_rejects_invalid_token(appserver_ws_user):
|
||||
host, port = appserver_ws_user
|
||||
base_ws = f"http://{host}:{port}/v1/events"
|
||||
|
||||
# Send wrong token via WebSocket subprotocols
|
||||
protocols = ["bearer", "totally-invalid-token"]
|
||||
with pytest.raises(Exception):
|
||||
async with aconnect_ws(base_ws, subprotocols=protocols) as ws: # type: ignore
|
||||
await ws.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_ws_accepts_valid_token_and_receives_events(appserver_ws_user):
|
||||
host, port = appserver_ws_user
|
||||
base_ws = f"http://{host}:{port}/v1/events"
|
||||
|
||||
token = _make_dummy_jwt("user-abc")
|
||||
subprotocols = ["bearer", token]
|
||||
|
||||
# Connect and then trigger an event via HTTP create
|
||||
async with aconnect_ws(base_ws, subprotocols=subprotocols) as ws:
|
||||
# Emit an event to the user's room via a standard HTTP action
|
||||
from httpx import AsyncClient
|
||||
|
||||
from reflector.app import app
|
||||
from reflector.auth import current_user, current_user_optional
|
||||
|
||||
# Override auth dependencies so HTTP request is performed as the same user
|
||||
app.dependency_overrides[current_user] = lambda: {
|
||||
"sub": "user-abc",
|
||||
"email": "user-abc@example.com",
|
||||
}
|
||||
app.dependency_overrides[current_user_optional] = lambda: {
|
||||
"sub": "user-abc",
|
||||
"email": "user-abc@example.com",
|
||||
}
|
||||
|
||||
async with AsyncClient(app=app, base_url=f"http://{host}:{port}/v1") as ac:
|
||||
# Create a transcript as this user so that the server publishes TRANSCRIPT_CREATED to user room
|
||||
resp = await ac.post("/transcripts", json={"name": "WS Test"})
|
||||
assert resp.status_code == 200
|
||||
|
||||
# Receive the published event
|
||||
msg = await ws.receive_json()
|
||||
assert msg["event"] == "TRANSCRIPT_CREATED"
|
||||
assert "id" in msg["data"]
|
||||
|
||||
# Clean overrides
|
||||
del app.dependency_overrides[current_user]
|
||||
del app.dependency_overrides[current_user_optional]
|
||||
Reference in New Issue
Block a user