fix: update transcript list on reprocess (#676)

* Update transcript list on reprocess

* Fix transcript create

* Fix multiple sockets issue

* Pass token in sec websocket protocol

* userEvent parse example

* transcript list invalidation non-abstraction

* Emit only relevant events to the user room

* Add ws close code const

* Refactor user websocket endpoint

* Refactor user events provider

---------

Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
This commit is contained in:
2025-10-07 19:11:30 +02:00
committed by GitHub
parent eef6dc3903
commit 9a71af145e
11 changed files with 449 additions and 12 deletions

View File

@@ -26,6 +26,7 @@ from reflector.views.transcripts_upload import router as transcripts_upload_rout
from reflector.views.transcripts_webrtc import router as transcripts_webrtc_router from reflector.views.transcripts_webrtc import router as transcripts_webrtc_router
from reflector.views.transcripts_websocket import router as transcripts_websocket_router from reflector.views.transcripts_websocket import router as transcripts_websocket_router
from reflector.views.user import router as user_router from reflector.views.user import router as user_router
from reflector.views.user_websocket import router as user_ws_router
from reflector.views.whereby import router as whereby_router from reflector.views.whereby import router as whereby_router
from reflector.views.zulip import router as zulip_router from reflector.views.zulip import router as zulip_router
@@ -90,6 +91,7 @@ app.include_router(transcripts_websocket_router, prefix="/v1")
app.include_router(transcripts_webrtc_router, prefix="/v1") app.include_router(transcripts_webrtc_router, prefix="/v1")
app.include_router(transcripts_process_router, prefix="/v1") app.include_router(transcripts_process_router, prefix="/v1")
app.include_router(user_router, prefix="/v1") app.include_router(user_router, prefix="/v1")
app.include_router(user_ws_router, prefix="/v1")
app.include_router(zulip_router, prefix="/v1") app.include_router(zulip_router, prefix="/v1")
app.include_router(whereby_router, prefix="/v1") app.include_router(whereby_router, prefix="/v1")
add_pagination(app) add_pagination(app)

View File

@@ -131,7 +131,7 @@ class PipelineMainFile(PipelineMainBase):
self.logger.info("File pipeline complete") self.logger.info("File pipeline complete")
await transcripts_controller.set_status(transcript.id, "ended") await self.set_status(transcript.id, "ended")
async def extract_and_write_audio( async def extract_and_write_audio(
self, file_path: Path, transcript: Transcript self, file_path: Path, transcript: Transcript

View File

@@ -85,6 +85,20 @@ def broadcast_to_sockets(func):
message=resp.model_dump(mode="json"), message=resp.model_dump(mode="json"),
) )
transcript = await transcripts_controller.get_by_id(self.transcript_id)
if transcript and transcript.user_id:
# Emit only relevant events to the user room to avoid noisy updates.
# Allowed: STATUS, FINAL_TITLE, DURATION. All are prefixed with TRANSCRIPT_
allowed_user_events = {"STATUS", "FINAL_TITLE", "DURATION"}
if resp.event in allowed_user_events:
await self.ws_manager.send_json(
room_id=f"user:{transcript.user_id}",
message={
"event": f"TRANSCRIPT_{resp.event}",
"data": {"id": self.transcript_id, **resp.data},
},
)
return wrapper return wrapper

View File

@@ -32,6 +32,7 @@ from reflector.db.transcripts import (
from reflector.processors.types import Transcript as ProcessorTranscript from reflector.processors.types import Transcript as ProcessorTranscript
from reflector.processors.types import Word from reflector.processors.types import Word
from reflector.settings import settings from reflector.settings import settings
from reflector.ws_manager import get_ws_manager
from reflector.zulip import ( from reflector.zulip import (
InvalidMessageError, InvalidMessageError,
get_zulip_message, get_zulip_message,
@@ -211,7 +212,7 @@ async def transcripts_create(
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
return await transcripts_controller.add( transcript = await transcripts_controller.add(
info.name, info.name,
source_kind=info.source_kind or SourceKind.LIVE, source_kind=info.source_kind or SourceKind.LIVE,
source_language=info.source_language, source_language=info.source_language,
@@ -219,6 +220,14 @@ async def transcripts_create(
user_id=user_id, user_id=user_id,
) )
if user_id:
await get_ws_manager().send_json(
room_id=f"user:{user_id}",
message={"event": "TRANSCRIPT_CREATED", "data": {"id": transcript.id}},
)
return transcript
# ============================================================== # ==============================================================
# Single transcript # Single transcript
@@ -368,6 +377,10 @@ async def transcript_delete(
raise HTTPException(status_code=403, detail="Not authorized") raise HTTPException(status_code=403, detail="Not authorized")
await transcripts_controller.remove_by_id(transcript.id, user_id=user_id) await transcripts_controller.remove_by_id(transcript.id, user_id=user_id)
await get_ws_manager().send_json(
room_id=f"user:{user_id}",
message={"event": "TRANSCRIPT_DELETED", "data": {"id": transcript.id}},
)
return DeletionStatus(status="ok") return DeletionStatus(status="ok")

View File

@@ -0,0 +1,53 @@
from typing import Optional
from fastapi import APIRouter, WebSocket
from reflector.auth.auth_jwt import JWTAuth # type: ignore
from reflector.ws_manager import get_ws_manager
router = APIRouter()
# Close code for unauthorized WebSocket connections
UNAUTHORISED = 4401
@router.websocket("/events")
async def user_events_websocket(websocket: WebSocket):
# Browser can't send Authorization header for WS; use subprotocol: ["bearer", token]
raw_subprotocol = websocket.headers.get("sec-websocket-protocol") or ""
parts = [p.strip() for p in raw_subprotocol.split(",") if p.strip()]
token: Optional[str] = None
negotiated_subprotocol: Optional[str] = None
if len(parts) >= 2 and parts[0].lower() == "bearer":
negotiated_subprotocol = "bearer"
token = parts[1]
user_id: Optional[str] = None
if not token:
await websocket.close(code=UNAUTHORISED)
return
try:
payload = JWTAuth().verify_token(token)
user_id = payload.get("sub")
except Exception:
await websocket.close(code=UNAUTHORISED)
return
if not user_id:
await websocket.close(code=UNAUTHORISED)
return
room_id = f"user:{user_id}"
ws_manager = get_ws_manager()
await ws_manager.add_user_to_room(
room_id, websocket, subprotocol=negotiated_subprotocol
)
try:
while True:
await websocket.receive()
finally:
if room_id:
await ws_manager.remove_user_from_room(room_id, websocket)

View File

@@ -65,8 +65,13 @@ class WebsocketManager:
self.tasks: dict = {} self.tasks: dict = {}
self.pubsub_client = pubsub_client self.pubsub_client = pubsub_client
async def add_user_to_room(self, room_id: str, websocket: WebSocket) -> None: async def add_user_to_room(
await websocket.accept() self, room_id: str, websocket: WebSocket, subprotocol: str | None = None
) -> None:
if subprotocol:
await websocket.accept(subprotocol=subprotocol)
else:
await websocket.accept()
if room_id in self.rooms: if room_id in self.rooms:
self.rooms[room_id].append(websocket) self.rooms[room_id].append(websocket)

View File

@@ -398,6 +398,10 @@ async def ws_manager_in_memory(monkeypatch):
monkeypatch.setattr( monkeypatch.setattr(
"reflector.views.transcripts_websocket.get_ws_manager", _get_ws_manager "reflector.views.transcripts_websocket.get_ws_manager", _get_ws_manager
) )
monkeypatch.setattr(
"reflector.views.user_websocket.get_ws_manager", _get_ws_manager
)
monkeypatch.setattr("reflector.views.transcripts.get_ws_manager", _get_ws_manager)
# Websocket auth: avoid OAuth2 on websocket dependencies; allow anonymous # Websocket auth: avoid OAuth2 on websocket dependencies; allow anonymous
import reflector.auth as auth import reflector.auth as auth

View File

@@ -0,0 +1,156 @@
import asyncio
import threading
import time
import pytest
from httpx_ws import aconnect_ws
from uvicorn import Config, Server
@pytest.fixture
def appserver_ws_user(setup_database):
from reflector.app import app
from reflector.db import get_database
host = "127.0.0.1"
port = 1257
server_started = threading.Event()
server_exception = None
server_instance = None
def run_server():
nonlocal server_exception, server_instance
try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
config = Config(app=app, host=host, port=port, loop=loop)
server_instance = Server(config)
async def start_server():
database = get_database()
await database.connect()
try:
await server_instance.serve()
finally:
await database.disconnect()
server_started.set()
loop.run_until_complete(start_server())
except Exception as e:
server_exception = e
server_started.set()
finally:
loop.close()
server_thread = threading.Thread(target=run_server, daemon=True)
server_thread.start()
server_started.wait(timeout=30)
if server_exception:
raise server_exception
time.sleep(0.5)
yield host, port
if server_instance:
server_instance.should_exit = True
server_thread.join(timeout=30)
@pytest.fixture(autouse=True)
def patch_jwt_verification(monkeypatch):
"""Patch JWT verification to accept HS256 tokens signed with SECRET_KEY for tests."""
from jose import jwt
from reflector.settings import settings
def _verify_token(self, token: str):
# Do not validate audience in tests
return jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"]) # type: ignore[arg-type]
monkeypatch.setattr(
"reflector.auth.auth_jwt.JWTAuth.verify_token", _verify_token, raising=True
)
def _make_dummy_jwt(sub: str = "user123") -> str:
# Create a short HS256 JWT using the app secret to pass verification in tests
from datetime import datetime, timedelta, timezone
from jose import jwt
from reflector.settings import settings
payload = {
"sub": sub,
"email": f"{sub}@example.com",
"exp": datetime.now(timezone.utc) + timedelta(minutes=5),
}
# Note: production uses RS256 public key verification; tests can sign with SECRET_KEY
return jwt.encode(payload, settings.SECRET_KEY, algorithm="HS256")
@pytest.mark.asyncio
async def test_user_ws_rejects_missing_subprotocol(appserver_ws_user):
host, port = appserver_ws_user
base_ws = f"http://{host}:{port}/v1/events"
# No subprotocol/header with token
with pytest.raises(Exception):
async with aconnect_ws(base_ws) as ws: # type: ignore
# Should close during handshake; if not, close explicitly
await ws.close()
@pytest.mark.asyncio
async def test_user_ws_rejects_invalid_token(appserver_ws_user):
host, port = appserver_ws_user
base_ws = f"http://{host}:{port}/v1/events"
# Send wrong token via WebSocket subprotocols
protocols = ["bearer", "totally-invalid-token"]
with pytest.raises(Exception):
async with aconnect_ws(base_ws, subprotocols=protocols) as ws: # type: ignore
await ws.close()
@pytest.mark.asyncio
async def test_user_ws_accepts_valid_token_and_receives_events(appserver_ws_user):
host, port = appserver_ws_user
base_ws = f"http://{host}:{port}/v1/events"
token = _make_dummy_jwt("user-abc")
subprotocols = ["bearer", token]
# Connect and then trigger an event via HTTP create
async with aconnect_ws(base_ws, subprotocols=subprotocols) as ws:
# Emit an event to the user's room via a standard HTTP action
from httpx import AsyncClient
from reflector.app import app
from reflector.auth import current_user, current_user_optional
# Override auth dependencies so HTTP request is performed as the same user
app.dependency_overrides[current_user] = lambda: {
"sub": "user-abc",
"email": "user-abc@example.com",
}
app.dependency_overrides[current_user_optional] = lambda: {
"sub": "user-abc",
"email": "user-abc@example.com",
}
async with AsyncClient(app=app, base_url=f"http://{host}:{port}/v1") as ac:
# Create a transcript as this user so that the server publishes TRANSCRIPT_CREATED to user room
resp = await ac.post("/transcripts", json={"name": "WS Test"})
assert resp.status_code == 200
# Receive the published event
msg = await ws.receive_json()
assert msg["event"] == "TRANSCRIPT_CREATED"
assert "id" in msg["data"]
# Clean overrides
del app.dependency_overrides[current_user]
del app.dependency_overrides[current_user_optional]

View File

@@ -0,0 +1,180 @@
"use client";
import React, { useEffect, useRef } from "react";
import { useQueryClient } from "@tanstack/react-query";
import { WEBSOCKET_URL } from "./apiClient";
import { useAuth } from "./AuthProvider";
import { z } from "zod";
import { invalidateTranscriptLists, TRANSCRIPT_SEARCH_URL } from "./apiHooks";
const UserEvent = z.object({
event: z.string(),
});
type UserEvent = z.TypeOf<typeof UserEvent>;
class UserEventsStore {
private socket: WebSocket | null = null;
private listeners: Set<(event: MessageEvent) => void> = new Set();
private closeTimeoutId: number | null = null;
private isConnecting = false;
ensureConnection(url: string, subprotocols?: string[]) {
if (typeof window === "undefined") return;
if (this.closeTimeoutId !== null) {
clearTimeout(this.closeTimeoutId);
this.closeTimeoutId = null;
}
if (this.isConnecting) return;
if (
this.socket &&
(this.socket.readyState === WebSocket.OPEN ||
this.socket.readyState === WebSocket.CONNECTING)
) {
return;
}
this.isConnecting = true;
const ws = new WebSocket(url, subprotocols || []);
this.socket = ws;
ws.onmessage = (event: MessageEvent) => {
this.listeners.forEach((listener) => {
try {
listener(event);
} catch (err) {
console.error("UserEvents listener error", err);
}
});
};
ws.onopen = () => {
if (this.socket === ws) this.isConnecting = false;
};
ws.onclose = () => {
if (this.socket === ws) {
this.socket = null;
this.isConnecting = false;
}
};
ws.onerror = () => {
if (this.socket === ws) this.isConnecting = false;
};
}
subscribe(listener: (event: MessageEvent) => void): () => void {
this.listeners.add(listener);
if (this.closeTimeoutId !== null) {
clearTimeout(this.closeTimeoutId);
this.closeTimeoutId = null;
}
return () => {
this.listeners.delete(listener);
if (this.listeners.size === 0) {
this.closeTimeoutId = window.setTimeout(() => {
if (this.socket) {
try {
this.socket.close();
} catch (err) {
console.warn("Error closing user events socket", err);
}
}
this.socket = null;
this.closeTimeoutId = null;
}, 1000);
}
};
}
}
const sharedStore = new UserEventsStore();
export function UserEventsProvider({
children,
}: {
children: React.ReactNode;
}) {
const auth = useAuth();
const queryClient = useQueryClient();
const tokenRef = useRef<string | null>(null);
const detachRef = useRef<(() => void) | null>(null);
useEffect(() => {
// Only tear down when the user is truly unauthenticated
if (auth.status === "unauthenticated") {
if (detachRef.current) {
try {
detachRef.current();
} catch (err) {
console.warn("Error detaching UserEvents listener", err);
}
detachRef.current = null;
}
tokenRef.current = null;
return;
}
// During loading/refreshing, keep the existing connection intact
if (auth.status !== "authenticated") {
return;
}
// Authenticated: pin the initial token for the lifetime of this WS connection
if (!tokenRef.current && auth.accessToken) {
tokenRef.current = auth.accessToken;
}
const pinnedToken = tokenRef.current;
const url = `${WEBSOCKET_URL}/v1/events`;
// Ensure a single shared connection
sharedStore.ensureConnection(
url,
pinnedToken ? ["bearer", pinnedToken] : undefined,
);
// Subscribe once; avoid re-subscribing during transient status changes
if (!detachRef.current) {
const onMessage = (event: MessageEvent) => {
try {
const msg = UserEvent.parse(JSON.parse(event.data));
const eventName = msg.event;
const invalidateList = () => invalidateTranscriptLists(queryClient);
switch (eventName) {
case "TRANSCRIPT_CREATED":
case "TRANSCRIPT_DELETED":
case "TRANSCRIPT_STATUS":
case "TRANSCRIPT_FINAL_TITLE":
case "TRANSCRIPT_DURATION":
invalidateList().then(() => {});
break;
default:
// Ignore other content events for list updates
break;
}
} catch (err) {
console.warn("Invalid user event message", event.data);
}
};
const unsubscribe = sharedStore.subscribe(onMessage);
detachRef.current = unsubscribe;
}
}, [auth.status, queryClient]);
// On unmount, detach the listener and clear the pinned token
useEffect(() => {
return () => {
if (detachRef.current) {
try {
detachRef.current();
} catch (err) {
console.warn("Error detaching UserEvents listener on unmount", err);
}
detachRef.current = null;
}
tokenRef.current = null;
};
}, []);
return <>{children}</>;
}

View File

@@ -2,7 +2,7 @@
import { $api } from "./apiClient"; import { $api } from "./apiClient";
import { useError } from "../(errors)/errorContext"; import { useError } from "../(errors)/errorContext";
import { useQueryClient } from "@tanstack/react-query"; import { QueryClient, useQueryClient } from "@tanstack/react-query";
import type { components } from "../reflector-api"; import type { components } from "../reflector-api";
import { useAuth } from "./AuthProvider"; import { useAuth } from "./AuthProvider";
@@ -40,6 +40,13 @@ export function useRoomsList(page: number = 1) {
type SourceKind = components["schemas"]["SourceKind"]; type SourceKind = components["schemas"]["SourceKind"];
export const TRANSCRIPT_SEARCH_URL = "/v1/transcripts/search" as const;
export const invalidateTranscriptLists = (queryClient: QueryClient) =>
queryClient.invalidateQueries({
queryKey: ["get", TRANSCRIPT_SEARCH_URL],
});
export function useTranscriptsSearch( export function useTranscriptsSearch(
q: string = "", q: string = "",
options: { options: {
@@ -51,7 +58,7 @@ export function useTranscriptsSearch(
) { ) {
return $api.useQuery( return $api.useQuery(
"get", "get",
"/v1/transcripts/search", TRANSCRIPT_SEARCH_URL,
{ {
params: { params: {
query: { query: {
@@ -76,7 +83,7 @@ export function useTranscriptDelete() {
return $api.useMutation("delete", "/v1/transcripts/{transcript_id}", { return $api.useMutation("delete", "/v1/transcripts/{transcript_id}", {
onSuccess: () => { onSuccess: () => {
return queryClient.invalidateQueries({ return queryClient.invalidateQueries({
queryKey: ["get", "/v1/transcripts/search"], queryKey: ["get", TRANSCRIPT_SEARCH_URL],
}); });
}, },
onError: (error) => { onError: (error) => {
@@ -613,7 +620,7 @@ export function useTranscriptCreate() {
return $api.useMutation("post", "/v1/transcripts", { return $api.useMutation("post", "/v1/transcripts", {
onSuccess: () => { onSuccess: () => {
return queryClient.invalidateQueries({ return queryClient.invalidateQueries({
queryKey: ["get", "/v1/transcripts/search"], queryKey: ["get", TRANSCRIPT_SEARCH_URL],
}); });
}, },
onError: (error) => { onError: (error) => {

View File

@@ -11,6 +11,7 @@ import { queryClient } from "./lib/queryClient";
import { AuthProvider } from "./lib/AuthProvider"; import { AuthProvider } from "./lib/AuthProvider";
import { SessionProvider as SessionProviderNextAuth } from "next-auth/react"; import { SessionProvider as SessionProviderNextAuth } from "next-auth/react";
import { RecordingConsentProvider } from "./recordingConsentContext"; import { RecordingConsentProvider } from "./recordingConsentContext";
import { UserEventsProvider } from "./lib/UserEventsProvider";
const WherebyProvider = dynamic( const WherebyProvider = dynamic(
() => () =>
@@ -28,10 +29,12 @@ export function Providers({ children }: { children: React.ReactNode }) {
<AuthProvider> <AuthProvider>
<ChakraProvider value={system}> <ChakraProvider value={system}>
<RecordingConsentProvider> <RecordingConsentProvider>
<WherebyProvider> <UserEventsProvider>
{children} <WherebyProvider>
<Toaster /> {children}
</WherebyProvider> <Toaster />
</WherebyProvider>
</UserEventsProvider>
</RecordingConsentProvider> </RecordingConsentProvider>
</ChakraProvider> </ChakraProvider>
</AuthProvider> </AuthProvider>