diff --git a/server/reflector/app.py b/server/reflector/app.py index 609474a2..8c8724a6 100644 --- a/server/reflector/app.py +++ b/server/reflector/app.py @@ -26,6 +26,7 @@ from reflector.views.transcripts_upload import router as transcripts_upload_rout from reflector.views.transcripts_webrtc import router as transcripts_webrtc_router from reflector.views.transcripts_websocket import router as transcripts_websocket_router from reflector.views.user import router as user_router +from reflector.views.user_websocket import router as user_ws_router from reflector.views.whereby import router as whereby_router from reflector.views.zulip import router as zulip_router @@ -90,6 +91,7 @@ app.include_router(transcripts_websocket_router, prefix="/v1") app.include_router(transcripts_webrtc_router, prefix="/v1") app.include_router(transcripts_process_router, prefix="/v1") app.include_router(user_router, prefix="/v1") +app.include_router(user_ws_router, prefix="/v1") app.include_router(zulip_router, prefix="/v1") app.include_router(whereby_router, prefix="/v1") add_pagination(app) diff --git a/server/reflector/pipelines/main_file_pipeline.py b/server/reflector/pipelines/main_file_pipeline.py index ce9d000e..bbf23e7b 100644 --- a/server/reflector/pipelines/main_file_pipeline.py +++ b/server/reflector/pipelines/main_file_pipeline.py @@ -131,7 +131,7 @@ class PipelineMainFile(PipelineMainBase): self.logger.info("File pipeline complete") - await transcripts_controller.set_status(transcript.id, "ended") + await self.set_status(transcript.id, "ended") async def extract_and_write_audio( self, file_path: Path, transcript: Transcript diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py index 64904952..f6fe6a83 100644 --- a/server/reflector/pipelines/main_live_pipeline.py +++ b/server/reflector/pipelines/main_live_pipeline.py @@ -85,6 +85,20 @@ def broadcast_to_sockets(func): message=resp.model_dump(mode="json"), ) + transcript = await transcripts_controller.get_by_id(self.transcript_id) + if transcript and transcript.user_id: + # Emit only relevant events to the user room to avoid noisy updates. + # Allowed: STATUS, FINAL_TITLE, DURATION. All are prefixed with TRANSCRIPT_ + allowed_user_events = {"STATUS", "FINAL_TITLE", "DURATION"} + if resp.event in allowed_user_events: + await self.ws_manager.send_json( + room_id=f"user:{transcript.user_id}", + message={ + "event": f"TRANSCRIPT_{resp.event}", + "data": {"id": self.transcript_id, **resp.data}, + }, + ) + return wrapper diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 839c6cdb..04d27e1a 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -32,6 +32,7 @@ from reflector.db.transcripts import ( from reflector.processors.types import Transcript as ProcessorTranscript from reflector.processors.types import Word from reflector.settings import settings +from reflector.ws_manager import get_ws_manager from reflector.zulip import ( InvalidMessageError, get_zulip_message, @@ -211,7 +212,7 @@ async def transcripts_create( user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], ): user_id = user["sub"] if user else None - return await transcripts_controller.add( + transcript = await transcripts_controller.add( info.name, source_kind=info.source_kind or SourceKind.LIVE, source_language=info.source_language, @@ -219,6 +220,14 @@ async def transcripts_create( user_id=user_id, ) + if user_id: + await get_ws_manager().send_json( + room_id=f"user:{user_id}", + message={"event": "TRANSCRIPT_CREATED", "data": {"id": transcript.id}}, + ) + + return transcript + # ============================================================== # Single transcript @@ -368,6 +377,10 @@ async def transcript_delete( raise HTTPException(status_code=403, detail="Not authorized") await transcripts_controller.remove_by_id(transcript.id, user_id=user_id) + await get_ws_manager().send_json( + room_id=f"user:{user_id}", + message={"event": "TRANSCRIPT_DELETED", "data": {"id": transcript.id}}, + ) return DeletionStatus(status="ok") diff --git a/server/reflector/views/user_websocket.py b/server/reflector/views/user_websocket.py new file mode 100644 index 00000000..26d3c8ac --- /dev/null +++ b/server/reflector/views/user_websocket.py @@ -0,0 +1,53 @@ +from typing import Optional + +from fastapi import APIRouter, WebSocket + +from reflector.auth.auth_jwt import JWTAuth # type: ignore +from reflector.ws_manager import get_ws_manager + +router = APIRouter() + +# Close code for unauthorized WebSocket connections +UNAUTHORISED = 4401 + + +@router.websocket("/events") +async def user_events_websocket(websocket: WebSocket): + # Browser can't send Authorization header for WS; use subprotocol: ["bearer", token] + raw_subprotocol = websocket.headers.get("sec-websocket-protocol") or "" + parts = [p.strip() for p in raw_subprotocol.split(",") if p.strip()] + token: Optional[str] = None + negotiated_subprotocol: Optional[str] = None + if len(parts) >= 2 and parts[0].lower() == "bearer": + negotiated_subprotocol = "bearer" + token = parts[1] + + user_id: Optional[str] = None + if not token: + await websocket.close(code=UNAUTHORISED) + return + + try: + payload = JWTAuth().verify_token(token) + user_id = payload.get("sub") + except Exception: + await websocket.close(code=UNAUTHORISED) + return + + if not user_id: + await websocket.close(code=UNAUTHORISED) + return + + room_id = f"user:{user_id}" + ws_manager = get_ws_manager() + + await ws_manager.add_user_to_room( + room_id, websocket, subprotocol=negotiated_subprotocol + ) + + try: + while True: + await websocket.receive() + finally: + if room_id: + await ws_manager.remove_user_from_room(room_id, websocket) diff --git a/server/reflector/ws_manager.py b/server/reflector/ws_manager.py index 07790e09..a1f620c4 100644 --- a/server/reflector/ws_manager.py +++ b/server/reflector/ws_manager.py @@ -65,8 +65,13 @@ class WebsocketManager: self.tasks: dict = {} self.pubsub_client = pubsub_client - async def add_user_to_room(self, room_id: str, websocket: WebSocket) -> None: - await websocket.accept() + async def add_user_to_room( + self, room_id: str, websocket: WebSocket, subprotocol: str | None = None + ) -> None: + if subprotocol: + await websocket.accept(subprotocol=subprotocol) + else: + await websocket.accept() if room_id in self.rooms: self.rooms[room_id].append(websocket) diff --git a/server/tests/conftest.py b/server/tests/conftest.py index 8271d1ad..a70604ae 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -398,6 +398,10 @@ async def ws_manager_in_memory(monkeypatch): monkeypatch.setattr( "reflector.views.transcripts_websocket.get_ws_manager", _get_ws_manager ) + monkeypatch.setattr( + "reflector.views.user_websocket.get_ws_manager", _get_ws_manager + ) + monkeypatch.setattr("reflector.views.transcripts.get_ws_manager", _get_ws_manager) # Websocket auth: avoid OAuth2 on websocket dependencies; allow anonymous import reflector.auth as auth diff --git a/server/tests/test_user_websocket_auth.py b/server/tests/test_user_websocket_auth.py new file mode 100644 index 00000000..be1a2816 --- /dev/null +++ b/server/tests/test_user_websocket_auth.py @@ -0,0 +1,156 @@ +import asyncio +import threading +import time + +import pytest +from httpx_ws import aconnect_ws +from uvicorn import Config, Server + + +@pytest.fixture +def appserver_ws_user(setup_database): + from reflector.app import app + from reflector.db import get_database + + host = "127.0.0.1" + port = 1257 + server_started = threading.Event() + server_exception = None + server_instance = None + + def run_server(): + nonlocal server_exception, server_instance + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + config = Config(app=app, host=host, port=port, loop=loop) + server_instance = Server(config) + + async def start_server(): + database = get_database() + await database.connect() + try: + await server_instance.serve() + finally: + await database.disconnect() + + server_started.set() + loop.run_until_complete(start_server()) + except Exception as e: + server_exception = e + server_started.set() + finally: + loop.close() + + server_thread = threading.Thread(target=run_server, daemon=True) + server_thread.start() + + server_started.wait(timeout=30) + if server_exception: + raise server_exception + + time.sleep(0.5) + + yield host, port + + if server_instance: + server_instance.should_exit = True + server_thread.join(timeout=30) + + +@pytest.fixture(autouse=True) +def patch_jwt_verification(monkeypatch): + """Patch JWT verification to accept HS256 tokens signed with SECRET_KEY for tests.""" + from jose import jwt + + from reflector.settings import settings + + def _verify_token(self, token: str): + # Do not validate audience in tests + return jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"]) # type: ignore[arg-type] + + monkeypatch.setattr( + "reflector.auth.auth_jwt.JWTAuth.verify_token", _verify_token, raising=True + ) + + +def _make_dummy_jwt(sub: str = "user123") -> str: + # Create a short HS256 JWT using the app secret to pass verification in tests + from datetime import datetime, timedelta, timezone + + from jose import jwt + + from reflector.settings import settings + + payload = { + "sub": sub, + "email": f"{sub}@example.com", + "exp": datetime.now(timezone.utc) + timedelta(minutes=5), + } + # Note: production uses RS256 public key verification; tests can sign with SECRET_KEY + return jwt.encode(payload, settings.SECRET_KEY, algorithm="HS256") + + +@pytest.mark.asyncio +async def test_user_ws_rejects_missing_subprotocol(appserver_ws_user): + host, port = appserver_ws_user + base_ws = f"http://{host}:{port}/v1/events" + # No subprotocol/header with token + with pytest.raises(Exception): + async with aconnect_ws(base_ws) as ws: # type: ignore + # Should close during handshake; if not, close explicitly + await ws.close() + + +@pytest.mark.asyncio +async def test_user_ws_rejects_invalid_token(appserver_ws_user): + host, port = appserver_ws_user + base_ws = f"http://{host}:{port}/v1/events" + + # Send wrong token via WebSocket subprotocols + protocols = ["bearer", "totally-invalid-token"] + with pytest.raises(Exception): + async with aconnect_ws(base_ws, subprotocols=protocols) as ws: # type: ignore + await ws.close() + + +@pytest.mark.asyncio +async def test_user_ws_accepts_valid_token_and_receives_events(appserver_ws_user): + host, port = appserver_ws_user + base_ws = f"http://{host}:{port}/v1/events" + + token = _make_dummy_jwt("user-abc") + subprotocols = ["bearer", token] + + # Connect and then trigger an event via HTTP create + async with aconnect_ws(base_ws, subprotocols=subprotocols) as ws: + # Emit an event to the user's room via a standard HTTP action + from httpx import AsyncClient + + from reflector.app import app + from reflector.auth import current_user, current_user_optional + + # Override auth dependencies so HTTP request is performed as the same user + app.dependency_overrides[current_user] = lambda: { + "sub": "user-abc", + "email": "user-abc@example.com", + } + app.dependency_overrides[current_user_optional] = lambda: { + "sub": "user-abc", + "email": "user-abc@example.com", + } + + async with AsyncClient(app=app, base_url=f"http://{host}:{port}/v1") as ac: + # Create a transcript as this user so that the server publishes TRANSCRIPT_CREATED to user room + resp = await ac.post("/transcripts", json={"name": "WS Test"}) + assert resp.status_code == 200 + + # Receive the published event + msg = await ws.receive_json() + assert msg["event"] == "TRANSCRIPT_CREATED" + assert "id" in msg["data"] + + # Clean overrides + del app.dependency_overrides[current_user] + del app.dependency_overrides[current_user_optional] diff --git a/www/app/lib/UserEventsProvider.tsx b/www/app/lib/UserEventsProvider.tsx new file mode 100644 index 00000000..89ec5a11 --- /dev/null +++ b/www/app/lib/UserEventsProvider.tsx @@ -0,0 +1,180 @@ +"use client"; + +import React, { useEffect, useRef } from "react"; +import { useQueryClient } from "@tanstack/react-query"; +import { WEBSOCKET_URL } from "./apiClient"; +import { useAuth } from "./AuthProvider"; +import { z } from "zod"; +import { invalidateTranscriptLists, TRANSCRIPT_SEARCH_URL } from "./apiHooks"; + +const UserEvent = z.object({ + event: z.string(), +}); + +type UserEvent = z.TypeOf; + +class UserEventsStore { + private socket: WebSocket | null = null; + private listeners: Set<(event: MessageEvent) => void> = new Set(); + private closeTimeoutId: number | null = null; + private isConnecting = false; + + ensureConnection(url: string, subprotocols?: string[]) { + if (typeof window === "undefined") return; + if (this.closeTimeoutId !== null) { + clearTimeout(this.closeTimeoutId); + this.closeTimeoutId = null; + } + if (this.isConnecting) return; + if ( + this.socket && + (this.socket.readyState === WebSocket.OPEN || + this.socket.readyState === WebSocket.CONNECTING) + ) { + return; + } + this.isConnecting = true; + const ws = new WebSocket(url, subprotocols || []); + this.socket = ws; + ws.onmessage = (event: MessageEvent) => { + this.listeners.forEach((listener) => { + try { + listener(event); + } catch (err) { + console.error("UserEvents listener error", err); + } + }); + }; + ws.onopen = () => { + if (this.socket === ws) this.isConnecting = false; + }; + ws.onclose = () => { + if (this.socket === ws) { + this.socket = null; + this.isConnecting = false; + } + }; + ws.onerror = () => { + if (this.socket === ws) this.isConnecting = false; + }; + } + + subscribe(listener: (event: MessageEvent) => void): () => void { + this.listeners.add(listener); + if (this.closeTimeoutId !== null) { + clearTimeout(this.closeTimeoutId); + this.closeTimeoutId = null; + } + return () => { + this.listeners.delete(listener); + if (this.listeners.size === 0) { + this.closeTimeoutId = window.setTimeout(() => { + if (this.socket) { + try { + this.socket.close(); + } catch (err) { + console.warn("Error closing user events socket", err); + } + } + this.socket = null; + this.closeTimeoutId = null; + }, 1000); + } + }; + } +} + +const sharedStore = new UserEventsStore(); + +export function UserEventsProvider({ + children, +}: { + children: React.ReactNode; +}) { + const auth = useAuth(); + const queryClient = useQueryClient(); + const tokenRef = useRef(null); + const detachRef = useRef<(() => void) | null>(null); + + useEffect(() => { + // Only tear down when the user is truly unauthenticated + if (auth.status === "unauthenticated") { + if (detachRef.current) { + try { + detachRef.current(); + } catch (err) { + console.warn("Error detaching UserEvents listener", err); + } + detachRef.current = null; + } + tokenRef.current = null; + return; + } + + // During loading/refreshing, keep the existing connection intact + if (auth.status !== "authenticated") { + return; + } + + // Authenticated: pin the initial token for the lifetime of this WS connection + if (!tokenRef.current && auth.accessToken) { + tokenRef.current = auth.accessToken; + } + const pinnedToken = tokenRef.current; + const url = `${WEBSOCKET_URL}/v1/events`; + + // Ensure a single shared connection + sharedStore.ensureConnection( + url, + pinnedToken ? ["bearer", pinnedToken] : undefined, + ); + + // Subscribe once; avoid re-subscribing during transient status changes + if (!detachRef.current) { + const onMessage = (event: MessageEvent) => { + try { + const msg = UserEvent.parse(JSON.parse(event.data)); + const eventName = msg.event; + + const invalidateList = () => invalidateTranscriptLists(queryClient); + + switch (eventName) { + case "TRANSCRIPT_CREATED": + case "TRANSCRIPT_DELETED": + case "TRANSCRIPT_STATUS": + case "TRANSCRIPT_FINAL_TITLE": + case "TRANSCRIPT_DURATION": + invalidateList().then(() => {}); + break; + + default: + // Ignore other content events for list updates + break; + } + } catch (err) { + console.warn("Invalid user event message", event.data); + } + }; + + const unsubscribe = sharedStore.subscribe(onMessage); + detachRef.current = unsubscribe; + } + }, [auth.status, queryClient]); + + // On unmount, detach the listener and clear the pinned token + useEffect(() => { + return () => { + if (detachRef.current) { + try { + detachRef.current(); + } catch (err) { + console.warn("Error detaching UserEvents listener on unmount", err); + } + detachRef.current = null; + } + tokenRef.current = null; + }; + }, []); + + return <>{children}; +} diff --git a/www/app/lib/apiHooks.ts b/www/app/lib/apiHooks.ts index c5b4f9b9..726e5441 100644 --- a/www/app/lib/apiHooks.ts +++ b/www/app/lib/apiHooks.ts @@ -2,7 +2,7 @@ import { $api } from "./apiClient"; import { useError } from "../(errors)/errorContext"; -import { useQueryClient } from "@tanstack/react-query"; +import { QueryClient, useQueryClient } from "@tanstack/react-query"; import type { components } from "../reflector-api"; import { useAuth } from "./AuthProvider"; @@ -40,6 +40,13 @@ export function useRoomsList(page: number = 1) { type SourceKind = components["schemas"]["SourceKind"]; +export const TRANSCRIPT_SEARCH_URL = "/v1/transcripts/search" as const; + +export const invalidateTranscriptLists = (queryClient: QueryClient) => + queryClient.invalidateQueries({ + queryKey: ["get", TRANSCRIPT_SEARCH_URL], + }); + export function useTranscriptsSearch( q: string = "", options: { @@ -51,7 +58,7 @@ export function useTranscriptsSearch( ) { return $api.useQuery( "get", - "/v1/transcripts/search", + TRANSCRIPT_SEARCH_URL, { params: { query: { @@ -76,7 +83,7 @@ export function useTranscriptDelete() { return $api.useMutation("delete", "/v1/transcripts/{transcript_id}", { onSuccess: () => { return queryClient.invalidateQueries({ - queryKey: ["get", "/v1/transcripts/search"], + queryKey: ["get", TRANSCRIPT_SEARCH_URL], }); }, onError: (error) => { @@ -613,7 +620,7 @@ export function useTranscriptCreate() { return $api.useMutation("post", "/v1/transcripts", { onSuccess: () => { return queryClient.invalidateQueries({ - queryKey: ["get", "/v1/transcripts/search"], + queryKey: ["get", TRANSCRIPT_SEARCH_URL], }); }, onError: (error) => { diff --git a/www/app/providers.tsx b/www/app/providers.tsx index 37b37a0e..6e689812 100644 --- a/www/app/providers.tsx +++ b/www/app/providers.tsx @@ -11,6 +11,7 @@ import { queryClient } from "./lib/queryClient"; import { AuthProvider } from "./lib/AuthProvider"; import { SessionProvider as SessionProviderNextAuth } from "next-auth/react"; import { RecordingConsentProvider } from "./recordingConsentContext"; +import { UserEventsProvider } from "./lib/UserEventsProvider"; const WherebyProvider = dynamic( () => @@ -28,10 +29,12 @@ export function Providers({ children }: { children: React.ReactNode }) { - - {children} - - + + + {children} + + +