This commit is contained in:
Igor Loskutov
2026-01-13 12:44:43 -05:00
parent 68df825734
commit 3652de9fca
5 changed files with 156 additions and 69 deletions

View File

@@ -7,13 +7,14 @@ WebSocket endpoint for bidirectional chat with LLM about transcript content.
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from llama_index.core import Settings
from llama_index.core.base.llms.types import ChatMessage, MessageRole
import reflector.auth as auth
from reflector.auth.auth_jwt import JWTAuth
from reflector.db.recordings import recordings_controller
from reflector.db.transcripts import transcripts_controller
from reflector.db.users import user_controller
from reflector.llm import LLM
from reflector.settings import settings
from reflector.utils.transcript_formats import topics_to_webvtt_named
@@ -33,19 +34,42 @@ async def _get_is_multitrack(transcript) -> bool:
async def transcript_chat_websocket(
transcript_id: str,
websocket: WebSocket,
user: Optional[auth.UserInfo] = Depends(auth.current_user_optional),
):
"""WebSocket endpoint for chatting with LLM about transcript content."""
# 1. Auth check
user_id = user["sub"] if user else None
# 1. Auth check (optional) - extract token from WebSocket subprotocol header
# Browser can't send Authorization header for WS; use subprotocol: ["bearer", token]
raw_subprotocol = websocket.headers.get("sec-websocket-protocol") or ""
parts = [p.strip() for p in raw_subprotocol.split(",") if p.strip()]
token: Optional[str] = None
negotiated_subprotocol: Optional[str] = None
if len(parts) >= 2 and parts[0].lower() == "bearer":
negotiated_subprotocol = "bearer"
token = parts[1]
user_id: Optional[str] = None
if token:
try:
payload = JWTAuth().verify_token(token)
authentik_uid = payload.get("sub")
if authentik_uid:
user = await user_controller.get_by_authentik_uid(authentik_uid)
if user:
user_id = user.id
except Exception:
# Auth failed - continue as anonymous
pass
# Get transcript (respects user_id for private transcripts)
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
await websocket.close(code=1008) # Policy violation (not found/unauthorized)
return
# 2. Accept connection
await websocket.accept()
# 2. Accept connection (with negotiated subprotocol if present)
await websocket.accept(subprotocol=negotiated_subprotocol)
# 3. Generate WebVTT context
is_multitrack = await _get_is_multitrack(transcript)
@@ -90,7 +114,8 @@ Answer questions about content, speakers, timeline. Include timestamps when rele
# Stream LLM response
assistant_msg = ""
async for chunk in Settings.llm.astream_chat(conversation_history):
chat_stream = await Settings.llm.astream_chat(conversation_history)
async for chunk in chat_stream:
token = chunk.delta or ""
if token:
await websocket.send_json({"type": "token", "text": token})