mirror of https://github.com/Monadical-SAS/reflector.git
@@ -7,13 +7,14 @@ WebSocket endpoint for bidirectional chat with LLM about transcript content.

 from typing import Optional

-from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect
+from fastapi import APIRouter, WebSocket, WebSocketDisconnect
 from llama_index.core import Settings
 from llama_index.core.base.llms.types import ChatMessage, MessageRole

-import reflector.auth as auth
+from reflector.auth.auth_jwt import JWTAuth
 from reflector.db.recordings import recordings_controller
 from reflector.db.transcripts import transcripts_controller
+from reflector.db.users import user_controller
 from reflector.llm import LLM
 from reflector.settings import settings
 from reflector.utils.transcript_formats import topics_to_webvtt_named
@@ -33,19 +34,42 @@ async def _get_is_multitrack(transcript) -> bool:
 async def transcript_chat_websocket(
     transcript_id: str,
     websocket: WebSocket,
-    user: Optional[auth.UserInfo] = Depends(auth.current_user_optional),
 ):
     """WebSocket endpoint for chatting with LLM about transcript content."""
-    # 1. Auth check
-    user_id = user["sub"] if user else None
+    # 1. Auth check (optional) - extract token from WebSocket subprotocol header
+    # Browser can't send Authorization header for WS; use subprotocol: ["bearer", token]
+    raw_subprotocol = websocket.headers.get("sec-websocket-protocol") or ""
+    parts = [p.strip() for p in raw_subprotocol.split(",") if p.strip()]
+    token: Optional[str] = None
+    negotiated_subprotocol: Optional[str] = None
+    if len(parts) >= 2 and parts[0].lower() == "bearer":
+        negotiated_subprotocol = "bearer"
+        token = parts[1]
+
+    user_id: Optional[str] = None
+    if token:
+        try:
+            payload = JWTAuth().verify_token(token)
+            authentik_uid = payload.get("sub")
+
+            if authentik_uid:
+                user = await user_controller.get_by_authentik_uid(authentik_uid)
+                if user:
+                    user_id = user.id
+        except Exception:
+            # Auth failed - continue as anonymous
+            pass
+
     # Get transcript (respects user_id for private transcripts)
     transcript = await transcripts_controller.get_by_id_for_http(
         transcript_id, user_id=user_id
     )
     if not transcript:
-        raise HTTPException(status_code=404, detail="Transcript not found")
+        await websocket.close(code=1008)  # Policy violation (not found/unauthorized)
+        return

-    # 2. Accept connection
-    await websocket.accept()
+    # 2. Accept connection (with negotiated subprotocol if present)
+    await websocket.accept(subprotocol=negotiated_subprotocol)

     # 3. Generate WebVTT context
     is_multitrack = await _get_is_multitrack(transcript)
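For reference, a minimal client-side sketch of the handshake this hunk expects: the JWT travels as the second entry of the Sec-WebSocket-Protocol list and the server replies with "bearer". Nothing below is part of the commit; the `websockets` package, the ws://.../chat route path, and the outgoing JSON message shape are assumptions for illustration.

import asyncio
import json

import websockets  # third-party client library, assumed for this sketch


async def chat(transcript_id: str, jwt_token: str) -> None:
    # Offer ["bearer", <token>]; the endpoint parses the raw Sec-WebSocket-Protocol
    # header and, when the transcript is visible, accepts with subprotocol="bearer".
    url = f"ws://localhost:8000/transcripts/{transcript_id}/chat"  # hypothetical route
    async with websockets.connect(url, subprotocols=["bearer", jwt_token]) as ws:
        # Assumed outgoing message shape; the diff only shows the server-to-client frames.
        await ws.send(json.dumps({"type": "message", "text": "Summarize this transcript"}))
        async for frame in ws:
            print(frame)  # server streams frames like {"type": "token", "text": ...}


# asyncio.run(chat("TRANSCRIPT_ID", "JWT_TOKEN"))

Note that when the transcript lookup fails, the handler now closes with code 1008 before accepting, so a client should expect the handshake or the first receive to fail rather than getting an HTTP 404.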
@@ -90,7 +114,8 @@ Answer questions about content, speakers, timeline. Include timestamps when rele

             # Stream LLM response
             assistant_msg = ""
-            async for chunk in Settings.llm.astream_chat(conversation_history):
+            chat_stream = await Settings.llm.astream_chat(conversation_history)
+            async for chunk in chat_stream:
                 token = chunk.delta or ""
                 if token:
                     await websocket.send_json({"type": "token", "text": token})
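The hunk above fixes the streaming call: `astream_chat` on a llama_index LLM is a coroutine that resolves to an async generator of ChatResponse chunks, so it must be awaited before it can be iterated. Below is a standalone sketch of the corrected pattern, assuming the OpenAI wrapper with an illustrative model name; the endpoint itself streams from whatever `Settings.llm` is configured to be.

import asyncio

from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.llms.openai import OpenAI  # assumed backend, requires OPENAI_API_KEY


async def main() -> None:
    llm = OpenAI(model="gpt-4o-mini")  # illustrative model, not taken from the commit
    history = [ChatMessage(role=MessageRole.USER, content="Say hello in one sentence.")]
    # Await the coroutine to obtain the async generator, then iterate it;
    # each ChatResponse chunk carries the incremental text in .delta.
    stream = await llm.astream_chat(history)
    async for chunk in stream:
        print(chunk.delta or "", end="", flush=True)
    print()


asyncio.run(main())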