mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-03-30 10:56:47 +00:00
test: add WebSocket broadcast delivery tests for STATUS and DAG_STATUS
Exercises the full broadcast → pub/sub → WebSocket delivery chain that DEBUG.md identified as potentially broken. Covers send_json direct delivery, broadcast_event() end-to-end, and event ordering. Also patches broadcast.py's get_ws_manager (missing from conftest).
This commit is contained in:
331
server/tests/test_ws_dag_broadcast.py
Normal file
331
server/tests/test_ws_dag_broadcast.py
Normal file
@@ -0,0 +1,331 @@
|
||||
"""WebSocket broadcast delivery tests for STATUS and DAG_STATUS events.
|
||||
|
||||
Tests the full chain identified in DEBUG.md:
|
||||
broadcast_event() → ws_manager.send_json() → Redis/in-memory pub/sub
|
||||
→ _pubsub_data_reader() → socket.send_json() → WebSocket client
|
||||
|
||||
Covers:
|
||||
1. STATUS event delivery to transcript room WS
|
||||
2. DAG_STATUS event delivery to transcript room WS
|
||||
3. Full broadcast_event() chain (requires broadcast.py patching)
|
||||
4. _pubsub_data_reader resilience when a client disconnects
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from httpx_ws import aconnect_ws
|
||||
from uvicorn import Config, Server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def appserver_ws_broadcast(setup_database, monkeypatch):
|
||||
"""Start real uvicorn server for WebSocket broadcast tests.
|
||||
|
||||
Also patches broadcast.py's get_ws_manager (missing from conftest autouse fixture).
|
||||
"""
|
||||
# Patch broadcast.py's get_ws_manager — conftest.py misses this module.
|
||||
# Without this, broadcast_event() creates a real Redis ws_manager.
|
||||
import reflector.ws_manager as ws_mod
|
||||
from reflector.app import app
|
||||
from reflector.db import get_database
|
||||
|
||||
monkeypatch.setattr(
|
||||
"reflector.hatchet.broadcast.get_ws_manager", ws_mod.get_ws_manager
|
||||
)
|
||||
|
||||
host = "127.0.0.1"
|
||||
port = 1259
|
||||
server_started = threading.Event()
|
||||
server_exception = None
|
||||
server_instance = None
|
||||
|
||||
def run_server():
|
||||
nonlocal server_exception, server_instance
|
||||
try:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
config = Config(app=app, host=host, port=port, loop=loop)
|
||||
server_instance = Server(config)
|
||||
|
||||
async def start_server():
|
||||
database = get_database()
|
||||
await database.connect()
|
||||
try:
|
||||
await server_instance.serve()
|
||||
finally:
|
||||
await database.disconnect()
|
||||
|
||||
server_started.set()
|
||||
loop.run_until_complete(start_server())
|
||||
except Exception as e:
|
||||
server_exception = e
|
||||
server_started.set()
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
server_thread = threading.Thread(target=run_server, daemon=True)
|
||||
server_thread.start()
|
||||
|
||||
server_started.wait(timeout=30)
|
||||
if server_exception:
|
||||
raise server_exception
|
||||
|
||||
time.sleep(0.5)
|
||||
|
||||
yield host, port
|
||||
|
||||
if server_instance:
|
||||
server_instance.should_exit = True
|
||||
server_thread.join(timeout=2.0)
|
||||
|
||||
from reflector.ws_manager import reset_ws_manager
|
||||
|
||||
reset_ws_manager()
|
||||
|
||||
|
||||
async def _create_transcript(host: str, port: int, name: str) -> str:
|
||||
"""Create a transcript via ASGI transport and return its ID."""
|
||||
from reflector.app import app
|
||||
|
||||
async with AsyncClient(app=app, base_url=f"http://{host}:{port}/v1") as ac:
|
||||
resp = await ac.post("/transcripts", json={"name": name})
|
||||
assert resp.status_code == 200, f"Failed to create transcript: {resp.text}"
|
||||
return resp.json()["id"]
|
||||
|
||||
|
||||
async def _drain_historical_events(ws, timeout: float = 0.5) -> list[dict]:
|
||||
"""Read all historical events sent on WS connect (non-blocking drain)."""
|
||||
events = []
|
||||
deadline = asyncio.get_event_loop().time() + timeout
|
||||
while asyncio.get_event_loop().time() < deadline:
|
||||
try:
|
||||
msg = await asyncio.wait_for(ws.receive_json(), timeout=0.1)
|
||||
events.append(msg)
|
||||
except (asyncio.TimeoutError, Exception):
|
||||
break
|
||||
return events
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 1: STATUS event delivery via ws_manager.send_json
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_ws_receives_status_via_send_json(appserver_ws_broadcast):
|
||||
"""STATUS event published via ws_manager.send_json() arrives at transcript room WS."""
|
||||
host, port = appserver_ws_broadcast
|
||||
transcript_id = await _create_transcript(host, port, "Status send_json test")
|
||||
|
||||
ws_url = f"http://{host}:{port}/v1/transcripts/{transcript_id}/events"
|
||||
async with aconnect_ws(ws_url) as ws:
|
||||
await _drain_historical_events(ws)
|
||||
|
||||
import reflector.ws_manager as ws_mod
|
||||
|
||||
ws_manager = ws_mod.get_ws_manager()
|
||||
await ws_manager.send_json(
|
||||
room_id=f"ts:{transcript_id}",
|
||||
message={"event": "STATUS", "data": {"value": "processing"}},
|
||||
)
|
||||
|
||||
msg = await asyncio.wait_for(ws.receive_json(), timeout=5.0)
|
||||
assert msg["event"] == "STATUS"
|
||||
assert msg["data"]["value"] == "processing"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 2: DAG_STATUS event delivery via ws_manager.send_json
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_ws_receives_dag_status_via_send_json(appserver_ws_broadcast):
|
||||
"""DAG_STATUS event published via ws_manager.send_json() arrives at transcript room WS."""
|
||||
host, port = appserver_ws_broadcast
|
||||
transcript_id = await _create_transcript(host, port, "DAG_STATUS send_json test")
|
||||
|
||||
dag_payload = {
|
||||
"event": "DAG_STATUS",
|
||||
"data": {
|
||||
"workflow_run_id": "test-run-123",
|
||||
"tasks": [
|
||||
{
|
||||
"name": "get_recording",
|
||||
"status": "completed",
|
||||
"started_at": "2025-01-01T00:00:00Z",
|
||||
"finished_at": "2025-01-01T00:00:05Z",
|
||||
"duration_seconds": 5.0,
|
||||
"parents": [],
|
||||
"error": None,
|
||||
"children_total": None,
|
||||
"children_completed": None,
|
||||
"progress_pct": None,
|
||||
},
|
||||
{
|
||||
"name": "process_tracks",
|
||||
"status": "running",
|
||||
"started_at": "2025-01-01T00:00:05Z",
|
||||
"finished_at": None,
|
||||
"duration_seconds": None,
|
||||
"parents": ["get_recording"],
|
||||
"error": None,
|
||||
"children_total": 3,
|
||||
"children_completed": 1,
|
||||
"progress_pct": 33.3,
|
||||
},
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
ws_url = f"http://{host}:{port}/v1/transcripts/{transcript_id}/events"
|
||||
async with aconnect_ws(ws_url) as ws:
|
||||
await _drain_historical_events(ws)
|
||||
|
||||
import reflector.ws_manager as ws_mod
|
||||
|
||||
ws_manager = ws_mod.get_ws_manager()
|
||||
await ws_manager.send_json(
|
||||
room_id=f"ts:{transcript_id}",
|
||||
message=dag_payload,
|
||||
)
|
||||
|
||||
msg = await asyncio.wait_for(ws.receive_json(), timeout=5.0)
|
||||
assert msg["event"] == "DAG_STATUS"
|
||||
assert msg["data"]["workflow_run_id"] == "test-run-123"
|
||||
assert len(msg["data"]["tasks"]) == 2
|
||||
assert msg["data"]["tasks"][0]["name"] == "get_recording"
|
||||
assert msg["data"]["tasks"][0]["status"] == "completed"
|
||||
assert msg["data"]["tasks"][1]["name"] == "process_tracks"
|
||||
assert msg["data"]["tasks"][1]["children_completed"] == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 3: Full broadcast_event() chain for STATUS
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_event_delivers_status_to_transcript_ws(appserver_ws_broadcast):
|
||||
"""broadcast_event() end-to-end: STATUS event reaches transcript room WS."""
|
||||
host, port = appserver_ws_broadcast
|
||||
transcript_id = await _create_transcript(host, port, "broadcast_event STATUS test")
|
||||
|
||||
ws_url = f"http://{host}:{port}/v1/transcripts/{transcript_id}/events"
|
||||
async with aconnect_ws(ws_url) as ws:
|
||||
await _drain_historical_events(ws)
|
||||
|
||||
from reflector.db.transcripts import TranscriptEvent
|
||||
from reflector.hatchet.broadcast import broadcast_event
|
||||
from reflector.logger import logger
|
||||
|
||||
log = logger.bind(transcript_id=transcript_id)
|
||||
event = TranscriptEvent(event="STATUS", data={"value": "processing"})
|
||||
await broadcast_event(transcript_id, event, logger=log)
|
||||
|
||||
msg = await asyncio.wait_for(ws.receive_json(), timeout=5.0)
|
||||
assert msg["event"] == "STATUS"
|
||||
assert msg["data"]["value"] == "processing"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 4: Full broadcast_event() chain for DAG_STATUS
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_event_delivers_dag_status_to_transcript_ws(
|
||||
appserver_ws_broadcast,
|
||||
):
|
||||
"""broadcast_event() end-to-end: DAG_STATUS event reaches transcript room WS."""
|
||||
host, port = appserver_ws_broadcast
|
||||
transcript_id = await _create_transcript(host, port, "broadcast_event DAG test")
|
||||
|
||||
ws_url = f"http://{host}:{port}/v1/transcripts/{transcript_id}/events"
|
||||
async with aconnect_ws(ws_url) as ws:
|
||||
await _drain_historical_events(ws)
|
||||
|
||||
from reflector.db.transcripts import TranscriptEvent
|
||||
from reflector.hatchet.broadcast import broadcast_event
|
||||
from reflector.logger import logger
|
||||
|
||||
log = logger.bind(transcript_id=transcript_id)
|
||||
event = TranscriptEvent(
|
||||
event="DAG_STATUS",
|
||||
data={
|
||||
"workflow_run_id": "test-run-456",
|
||||
"tasks": [
|
||||
{
|
||||
"name": "get_recording",
|
||||
"status": "running",
|
||||
"started_at": None,
|
||||
"finished_at": None,
|
||||
"duration_seconds": None,
|
||||
"parents": [],
|
||||
"error": None,
|
||||
"children_total": None,
|
||||
"children_completed": None,
|
||||
"progress_pct": None,
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
await broadcast_event(transcript_id, event, logger=log)
|
||||
|
||||
msg = await asyncio.wait_for(ws.receive_json(), timeout=5.0)
|
||||
assert msg["event"] == "DAG_STATUS"
|
||||
assert msg["data"]["tasks"][0]["name"] == "get_recording"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 5: Multiple rapid events arrive in order
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_events_arrive_in_order(appserver_ws_broadcast):
|
||||
"""Multiple STATUS then DAG_STATUS events arrive in correct order."""
|
||||
host, port = appserver_ws_broadcast
|
||||
transcript_id = await _create_transcript(host, port, "ordering test")
|
||||
|
||||
ws_url = f"http://{host}:{port}/v1/transcripts/{transcript_id}/events"
|
||||
async with aconnect_ws(ws_url) as ws:
|
||||
await _drain_historical_events(ws)
|
||||
|
||||
import reflector.ws_manager as ws_mod
|
||||
|
||||
ws_manager = ws_mod.get_ws_manager()
|
||||
|
||||
await ws_manager.send_json(
|
||||
room_id=f"ts:{transcript_id}",
|
||||
message={"event": "STATUS", "data": {"value": "processing"}},
|
||||
)
|
||||
await ws_manager.send_json(
|
||||
room_id=f"ts:{transcript_id}",
|
||||
message={
|
||||
"event": "DAG_STATUS",
|
||||
"data": {"workflow_run_id": "r1", "tasks": []},
|
||||
},
|
||||
)
|
||||
await ws_manager.send_json(
|
||||
room_id=f"ts:{transcript_id}",
|
||||
message={
|
||||
"event": "DAG_STATUS",
|
||||
"data": {
|
||||
"workflow_run_id": "r1",
|
||||
"tasks": [{"name": "a", "status": "running"}],
|
||||
},
|
||||
},
|
||||
)
|
||||
await ws_manager.send_json(
|
||||
room_id=f"ts:{transcript_id}",
|
||||
message={"event": "STATUS", "data": {"value": "ended"}},
|
||||
)
|
||||
|
||||
msgs = []
|
||||
for _ in range(4):
|
||||
msg = await asyncio.wait_for(ws.receive_json(), timeout=5.0)
|
||||
msgs.append(msg)
|
||||
|
||||
assert msgs[0]["event"] == "STATUS"
|
||||
assert msgs[0]["data"]["value"] == "processing"
|
||||
assert msgs[1]["event"] == "DAG_STATUS"
|
||||
assert msgs[1]["data"]["tasks"] == []
|
||||
assert msgs[2]["event"] == "DAG_STATUS"
|
||||
assert len(msgs[2]["data"]["tasks"]) == 1
|
||||
assert msgs[3]["event"] == "STATUS"
|
||||
assert msgs[3]["data"]["value"] == "ended"
|
||||
Reference in New Issue
Block a user