mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 12:19:06 +00:00
* Update transcript list on reprocess * Fix transcript create * Fix multiple sockets issue * Pass token in sec websocket protocol * userEvent parse example * transcript list invalidation non-abstraction * Emit only relevant events to the user room * Add ws close code const * Refactor user websocket endpoint * Refactor user events provider --------- Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
157 lines
4.9 KiB
Python
157 lines
4.9 KiB
Python
import asyncio
|
|
import threading
|
|
import time
|
|
|
|
import pytest
|
|
from httpx_ws import aconnect_ws
|
|
from uvicorn import Config, Server
|
|
|
|
|
|
@pytest.fixture
def appserver_ws_user(setup_database):
    """Run the real app under uvicorn in a background thread for WebSocket tests.

    The server listens on 127.0.0.1:1257 and runs on its own event loop in a
    daemon thread. The database connection is opened on that loop before
    serving and closed after the server stops.

    Yields:
        ``(host, port)`` of the running server.

    Raises:
        Exception: whatever the server thread raised while starting up.
        RuntimeError: if the server thread exits before uvicorn finished startup.
        TimeoutError: if uvicorn does not finish startup within the timeout.
    """
    from reflector.app import app
    from reflector.db import get_database

    host = "127.0.0.1"
    port = 1257
    startup_timeout = 30.0

    server_exception = None
    server_instance = None

    def run_server():
        nonlocal server_exception, server_instance
        # Create the loop before ``try`` so ``finally`` never sees an unbound name.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            # ``Server.serve()`` runs on the current loop; uvicorn's ``loop``
            # option (a loop-implementation name) only matters for ``Server.run()``.
            config = Config(app=app, host=host, port=port)
            server_instance = Server(config)

            async def start_server():
                database = get_database()
                await database.connect()
                try:
                    await server_instance.serve()
                finally:
                    await database.disconnect()

            loop.run_until_complete(start_server())
        except Exception as e:
            server_exception = e
        finally:
            loop.close()

    server_thread = threading.Thread(target=run_server, daemon=True)
    server_thread.start()

    # Wait until uvicorn reports it has finished startup (``Server.started``)
    # instead of sleeping a fixed amount and hoping the socket is bound.
    deadline = time.monotonic() + startup_timeout
    while server_instance is None or not server_instance.started:
        if not server_thread.is_alive():
            # The thread stores any Exception before exiting, so this read is final.
            if server_exception is not None:
                raise server_exception
            # NOTE(review): e.g. a SystemExit from uvicorn on bind failure,
            # which ``except Exception`` above does not capture.
            raise RuntimeError("test server thread exited before startup completed")
        if time.monotonic() > deadline:
            raise TimeoutError(f"test server did not start within {startup_timeout}s")
        time.sleep(0.05)

    yield host, port

    # Ask uvicorn to shut down gracefully, then wait for the thread to finish.
    server_instance.should_exit = True
    server_thread.join(timeout=30)
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def patch_jwt_verification(monkeypatch):
    """Patch JWT verification to accept HS256 tokens signed with SECRET_KEY for tests.

    Replaces ``reflector.auth.auth_jwt.JWTAuth.verify_token`` via ``monkeypatch``
    (undone automatically after each test). ``autouse=True`` applies it to every
    test that can see this fixture. ``raising=True`` makes the patch fail loudly
    if the target attribute no longer exists.
    """
    from jose import jwt

    from reflector.settings import settings

    def _verify_token(self, token: str):
        # Audience is deliberately not validated in tests. Omitting ``audience``
        # alone is not enough: python-jose still rejects tokens that carry an
        # ``aud`` claim, so disable the audience check explicitly.
        return jwt.decode(
            token,
            settings.SECRET_KEY,  # type: ignore[arg-type]
            algorithms=["HS256"],
            options={"verify_aud": False},
        )

    monkeypatch.setattr(
        "reflector.auth.auth_jwt.JWTAuth.verify_token", _verify_token, raising=True
    )
|
|
|
|
|
|
def _make_dummy_jwt(sub: str = "user123") -> str:
    """Return a short-lived HS256 JWT for ``sub`` signed with the app's SECRET_KEY.

    Claims: ``sub``, ``email`` (``<sub>@example.com``) and ``exp`` set five
    minutes from now (UTC). Production verifies RS256 tokens against a public
    key; tests sign with SECRET_KEY instead so an HS256 verifier keyed on the
    same secret accepts them.
    """
    from datetime import datetime, timedelta, timezone

    from jose import jwt

    from reflector.settings import settings

    expires_at = datetime.now(timezone.utc) + timedelta(minutes=5)
    claims = dict(sub=sub, email=f"{sub}@example.com", exp=expires_at)
    return jwt.encode(claims, settings.SECRET_KEY, algorithm="HS256")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_user_ws_rejects_missing_subprotocol(appserver_ws_user):
    """Connecting to /v1/events without offering a token must fail."""
    host, port = appserver_ws_user
    events_url = f"http://{host}:{port}/v1/events"

    # No token is offered, neither as a subprotocol nor as a header.
    with pytest.raises(Exception):
        async with aconnect_ws(events_url) as ws:  # type: ignore
            # The rejection is expected during the handshake; if the
            # connection was accepted anyway, close it explicitly.
            await ws.close()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_user_ws_rejects_invalid_token(appserver_ws_user):
    """Offering a malformed bearer token as subprotocol must fail to connect."""
    host, port = appserver_ws_user
    events_url = f"http://{host}:{port}/v1/events"

    # The token travels as the second WebSocket subprotocol after "bearer".
    bad_subprotocols = ["bearer", "totally-invalid-token"]

    with pytest.raises(Exception):
        async with aconnect_ws(events_url, subprotocols=bad_subprotocols) as ws:  # type: ignore
            await ws.close()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_user_ws_accepts_valid_token_and_receives_events(appserver_ws_user):
    """A valid bearer token is accepted and the user receives their own events.

    Opens the user events WebSocket with ``["bearer", <jwt>]`` as subprotocols,
    creates a transcript over HTTP as the same user, then expects a
    ``TRANSCRIPT_CREATED`` message whose ``data`` contains an ``id``.
    """
    from httpx import ASGITransport, AsyncClient

    from reflector.app import app
    from reflector.auth import current_user, current_user_optional

    host, port = appserver_ws_user
    base_ws = f"http://{host}:{port}/v1/events"

    token = _make_dummy_jwt("user-abc")
    subprotocols = ["bearer", token]

    def _as_test_user():
        return {
            "sub": "user-abc",
            "email": "user-abc@example.com",
        }

    # Connect and then trigger an event via HTTP create
    async with aconnect_ws(base_ws, subprotocols=subprotocols) as ws:
        # Override auth dependencies so the HTTP request is performed as the
        # same user that owns the WebSocket connection.
        app.dependency_overrides[current_user] = _as_test_user
        app.dependency_overrides[current_user_optional] = _as_test_user
        try:
            # ASGITransport drives the app in-process; the ``app=`` shortcut on
            # AsyncClient is deprecated and removed in newer httpx releases.
            transport = ASGITransport(app=app)
            async with AsyncClient(
                transport=transport, base_url=f"http://{host}:{port}/v1"
            ) as ac:
                # Create a transcript as this user so that the server publishes
                # TRANSCRIPT_CREATED to the user room.
                resp = await ac.post("/transcripts", json={"name": "WS Test"})
                assert resp.status_code == 200

            # Bound the wait so a missing event fails the test instead of hanging.
            msg = await asyncio.wait_for(ws.receive_json(), timeout=10)
            assert msg["event"] == "TRANSCRIPT_CREATED"
            assert "id" in msg["data"]
        finally:
            # Always drop the overrides, even on failure, so the fake user
            # cannot leak into other tests sharing this app object.
            app.dependency_overrides.pop(current_user, None)
            app.dependency_overrides.pop(current_user_optional, None)
|