Compare commits

11 Commits

Author SHA1 Message Date
Igor Loskutov
b84fd1fc24 cleanup 2026-01-13 12:46:03 -05:00
Igor Loskutov
3652de9fca md 2026-01-13 12:44:43 -05:00
Igor Loskutov
68df825734 test: fix WebSocket chat tests using async approach
Replaced TestClient-based tests with proper async WebSocket testing
using httpx_ws and a threaded-server pattern. TestClient has event-loop
issues with WebSocket connections that were causing all tests to fail.

Changes:
- Rewrote all WebSocket tests to use aconnect_ws from httpx_ws
- Added chat_appserver fixture using threaded Uvicorn server
- Tests now use separate event loop in server thread
- All 6 tests now pass without asyncio/event loop errors
- Matches existing pattern from test_transcripts_rtc_ws.py

Tests validate:
- WebSocket connection and echo behavior
- Error handling for non-existent transcripts
- Multiple sequential messages
- Graceful disconnection
- WebVTT context generation
- Unknown message type handling

Closes fn-1.8 (End-to-end testing)
2026-01-12 20:17:42 -05:00
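A minimal sketch of the pattern this message describes, assuming a server is already listening (the fixture that actually starts one appears in the test file diff further down); the transcript id and port are placeholders:

import pytest
from httpx_ws import aconnect_ws

@pytest.mark.asyncio
async def test_echo_sketch():
    # Talk to the endpoint over a real socket instead of TestClient,
    # so the server keeps its own event loop in a separate thread
    async with aconnect_ws("ws://127.0.0.1:1256/v1/transcripts/<id>/chat") as ws:
        await ws.send_json({"type": "ping"})  # unknown type is echoed back
        reply = await ws.receive_json()
        assert reply["type"] == "echo"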
Igor Loskutov
8ca5324c1a feat: integrate TranscriptChatModal and button into transcript page 2026-01-12 20:08:59 -05:00
Igor Loskutov
39e0b89e67 feat: add TranscriptChatModal and TranscriptChatButton components 2026-01-12 19:59:01 -05:00
Igor Loskutov
544793a24f chore: mark fn-1.5 as done (Frontend WebSocket hook)
Task fn-1.5 completed: the useTranscriptChat React hook was already implemented in commit 2dfe82af.

Hook provides:
- WebSocket connection to /v1/transcripts/{id}/chat endpoint
- Token streaming with ref-based accumulation
- Message history management (user + assistant)
- Memory leak prevention with isMountedRef
- TypeScript type safety
- Proper WebSocket lifecycle and cleanup

Updated task documentation with acceptance criteria and evidence.
2026-01-12 19:49:14 -05:00
Igor Loskutov
088451645a chore: mark fn-1.4 as done (WebSocket route registration) 2026-01-12 19:42:26 -05:00
Igor Loskutov
2dfe82afbc feat: add useTranscriptChat WebSocket hook
Task 5: Frontend WebSocket Hook
- Creates React hook for bidirectional chat WebSocket
- Handles token streaming with proper state accumulation
- Manages conversation history (user + assistant messages)
- Prevents memory leaks with isMounted check
- Proper cleanup on unmount
- Type-safe Message interface

Validated:
- No React dependency issues (removed currentStreamingText from deps)
- No stale closure bugs (using ref for streaming text)
- Proper mounted state tracking
- Lint passes with no errors
- TypeScript types correctly defined
- WebSocket cleanup on unmount

~100 lines
2026-01-12 18:44:09 -05:00
Igor Loskutov
b461ebb488 feat: register transcript chat WebSocket route
- Import transcripts_chat router
- Register /v1/transcripts/{id}/chat endpoint
- Completes LLM streaming integration (fn-1.3)
2026-01-12 18:41:11 -05:00
Igor Loskutov
0b5112cabc feat: add LLM streaming integration to transcript chat
Task 3: LLM Streaming Integration

- Import Settings, ChatMessage, MessageRole from llama-index
- Configure LLM with temperature 0.7 on connection
- Build system message with WebVTT transcript context (max 15k chars)
- Initialize conversation history with system message
- Handle 'message' type from client to trigger LLM streaming
- Stream LLM response using Settings.llm.astream_chat()
- Send tokens incrementally via 'token' messages
- Send 'done' message when streaming completes
- Maintain conversation history across multiple messages
- Add error handling with 'error' message type
- Add message protocol validation test

Implements Tasks 3 & 4 from TASKS.md
2026-01-12 18:28:43 -05:00
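Sketch of the wire protocol those bullets describe, as the JSON frames exchanged over the socket; payload values are illustrative:

# client -> server
{"type": "message", "text": "What decisions were made?"}
# server -> client, one frame per streamed token
{"type": "token", "text": "The"}
{"type": "token", "text": " team"}
# server -> client, when streaming completes
{"type": "done"}
# server -> client, on failure
{"type": "error", "message": "<description>"}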
Igor Loskutov
316f7b316d feat: add WebVTT context generation to chat WebSocket endpoint
- Import topics_to_webvtt_named and recordings controller
- Add _get_is_multitrack helper function
- Generate WebVTT context on WebSocket connection
- Add get_context message type to retrieve WebVTT
- Maintain backward compatibility with echo for other messages
- Add test fixture and test for WebVTT context generation

Implements task fn-1.2: WebVTT context generation for transcript chat
2026-01-12 18:24:47 -05:00
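For reference, the named-cue WebVTT this generates looks roughly like the following; speaker names and cue text are taken from the test fixture at the end of this compare, cue timings are illustrative:

WEBVTT

00:00:00.000 --> 00:00:02.000
<v Alice>Hello everyone.

00:00:02.000 --> 00:00:04.000
<v Bob>Hi there!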
19 changed files with 757 additions and 219 deletions

View File

@@ -1,12 +1,5 @@
# Changelog
## [0.28.0](https://github.com/Monadical-SAS/reflector/compare/v0.27.0...v0.28.0) (2026-01-20)
### Features
* worker affinity ([#819](https://github.com/Monadical-SAS/reflector/issues/819)) ([3b6540e](https://github.com/Monadical-SAS/reflector/commit/3b6540eae5b597449f98661bdf15483b77be3268))
## [0.27.0](https://github.com/Monadical-SAS/reflector/compare/v0.26.0...v0.27.0) (2025-12-26)

View File

@@ -34,7 +34,7 @@ services:
environment:
ENTRYPOINT: beat
hatchet-worker-cpu:
hatchet-worker:
build:
context: server
volumes:
@@ -43,20 +43,7 @@ services:
env_file:
- ./server/.env
environment:
ENTRYPOINT: hatchet-worker-cpu
depends_on:
hatchet:
condition: service_healthy
hatchet-worker-llm:
build:
context: server
volumes:
- ./server/:/app/
- /app/.venv
env_file:
- ./server/.env
environment:
ENTRYPOINT: hatchet-worker-llm
ENTRYPOINT: hatchet-worker
depends_on:
hatchet:
condition: service_healthy

View File: reflector/app.py

@@ -18,6 +18,7 @@ from reflector.views.rooms import router as rooms_router
from reflector.views.rtc_offer import router as rtc_offer_router
from reflector.views.transcripts import router as transcripts_router
from reflector.views.transcripts_audio import router as transcripts_audio_router
from reflector.views.transcripts_chat import router as transcripts_chat_router
from reflector.views.transcripts_participants import (
router as transcripts_participants_router,
)
@@ -90,6 +91,7 @@ app.include_router(transcripts_participants_router, prefix="/v1")
app.include_router(transcripts_speaker_router, prefix="/v1")
app.include_router(transcripts_upload_router, prefix="/v1")
app.include_router(transcripts_websocket_router, prefix="/v1")
app.include_router(transcripts_chat_router, prefix="/v1")
app.include_router(transcripts_webrtc_router, prefix="/v1")
app.include_router(transcripts_process_router, prefix="/v1")
app.include_router(user_router, prefix="/v1")

View File: reflector/hatchet/run_workers.py

@@ -0,0 +1,77 @@
"""
Run Hatchet workers for the multitrack pipeline.
Runs as a separate process, just like Celery workers.
Usage:
uv run -m reflector.hatchet.run_workers
# Or via docker:
docker compose exec server uv run -m reflector.hatchet.run_workers
"""
import signal
import sys
from hatchet_sdk.rate_limit import RateLimitDuration
from reflector.hatchet.constants import LLM_RATE_LIMIT_KEY, LLM_RATE_LIMIT_PER_SECOND
from reflector.logger import logger
from reflector.settings import settings
def main() -> None:
"""Start Hatchet worker polling."""
if not settings.HATCHET_ENABLED:
logger.error("HATCHET_ENABLED is False, not starting workers")
sys.exit(1)
if not settings.HATCHET_CLIENT_TOKEN:
logger.error("HATCHET_CLIENT_TOKEN is not set")
sys.exit(1)
logger.info(
"Starting Hatchet workers",
debug=settings.HATCHET_DEBUG,
)
# Import here (not top-level) - workflow modules call HatchetClientManager.get_client()
# at module level because Hatchet SDK decorators (@workflow.task) bind at import time.
# Can't use lazy init: decorators need the client object when function is defined.
from reflector.hatchet.client import HatchetClientManager # noqa: PLC0415
from reflector.hatchet.workflows import ( # noqa: PLC0415
daily_multitrack_pipeline,
subject_workflow,
topic_chunk_workflow,
track_workflow,
)
hatchet = HatchetClientManager.get_client()
hatchet.rate_limits.put(
LLM_RATE_LIMIT_KEY, LLM_RATE_LIMIT_PER_SECOND, RateLimitDuration.SECOND
)
worker = hatchet.worker(
"reflector-pipeline-worker",
workflows=[
daily_multitrack_pipeline,
subject_workflow,
topic_chunk_workflow,
track_workflow,
],
)
def shutdown_handler(signum: int, frame) -> None:
logger.info("Received shutdown signal, stopping workers...")
# Worker cleanup happens automatically on exit
sys.exit(0)
signal.signal(signal.SIGINT, shutdown_handler)
signal.signal(signal.SIGTERM, shutdown_handler)
logger.info("Starting Hatchet worker polling...")
worker.start()
if __name__ == "__main__":
main()

View File: reflector/hatchet/run_workers_cpu.py

@@ -1,48 +0,0 @@
"""
CPU-heavy worker pool for audio processing tasks.
Handles ONLY: mixdown_tracks
Configuration:
- slots=1: Only mixdown (already serialized globally with max_runs=1)
- Worker affinity: pool=cpu-heavy
"""
from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
daily_multitrack_pipeline,
)
from reflector.logger import logger
from reflector.settings import settings
def main():
if not settings.HATCHET_ENABLED:
logger.error("HATCHET_ENABLED is False, not starting CPU workers")
return
hatchet = HatchetClientManager.get_client()
logger.info(
"Starting Hatchet CPU worker pool (mixdown only)",
worker_name="cpu-worker-pool",
slots=1,
labels={"pool": "cpu-heavy"},
)
cpu_worker = hatchet.worker(
"cpu-worker-pool",
slots=1, # Only 1 mixdown at a time (already serialized globally)
labels={
"pool": "cpu-heavy",
},
workflows=[daily_multitrack_pipeline],
)
try:
cpu_worker.start()
except KeyboardInterrupt:
logger.info("Received shutdown signal, stopping CPU workers...")
if __name__ == "__main__":
main()

View File: reflector/hatchet/run_workers_llm.py

@@ -1,56 +0,0 @@
"""
LLM/I/O worker pool for all non-CPU tasks.
Handles: all tasks except mixdown_tracks (transcription, LLM inference, orchestration)
"""
from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
daily_multitrack_pipeline,
)
from reflector.hatchet.workflows.subject_processing import subject_workflow
from reflector.hatchet.workflows.topic_chunk_processing import topic_chunk_workflow
from reflector.hatchet.workflows.track_processing import track_workflow
from reflector.logger import logger
from reflector.settings import settings
SLOTS = 10
WORKER_NAME = "llm-worker-pool"
POOL = "llm-io"
def main():
if not settings.HATCHET_ENABLED:
logger.error("HATCHET_ENABLED is False, not starting LLM workers")
return
hatchet = HatchetClientManager.get_client()
logger.info(
"Starting Hatchet LLM worker pool (all tasks except mixdown)",
worker_name=WORKER_NAME,
slots=SLOTS,
labels={"pool": POOL},
)
llm_worker = hatchet.worker(
WORKER_NAME,
slots=SLOTS,  # probably not all slots are used
labels={
"pool": POOL,
},
workflows=[
daily_multitrack_pipeline,
topic_chunk_workflow,
subject_workflow,
track_workflow,
],
)
try:
llm_worker.start()
except KeyboardInterrupt:
logger.info("Received shutdown signal, stopping LLM workers...")
if __name__ == "__main__":
main()

View File: reflector/hatchet/workflows/daily_multitrack_pipeline.py

@@ -23,12 +23,7 @@ from pathlib import Path
from typing import Any, Callable, Coroutine, Protocol, TypeVar
import httpx
from hatchet_sdk import (
ConcurrencyExpression,
ConcurrencyLimitStrategy,
Context,
)
from hatchet_sdk.labels import DesiredWorkerLabel
from hatchet_sdk import Context
from pydantic import BaseModel
from reflector.dailyco_api.client import DailyApiClient
@@ -472,20 +467,6 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
parents=[process_tracks],
execution_timeout=timedelta(seconds=TIMEOUT_AUDIO),
retries=3,
desired_worker_labels={
"pool": DesiredWorkerLabel(
value="cpu-heavy",
required=True,
weight=100,
),
},
concurrency=[
ConcurrencyExpression(
expression="'mixdown-global'",
max_runs=1, # serialize mixdown to prevent resource contention
limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, # Queue
)
],
)
@with_error_handling(TaskName.MIXDOWN_TRACKS)
async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:

View File: reflector/hatchet/workflows/topic_chunk_processing.py

@@ -7,11 +7,7 @@ Spawned dynamically by detect_topics via aio_run_many() for parallel processing.
from datetime import timedelta
from hatchet_sdk import (
ConcurrencyExpression,
ConcurrencyLimitStrategy,
Context,
)
from hatchet_sdk import ConcurrencyExpression, ConcurrencyLimitStrategy, Context
from hatchet_sdk.rate_limit import RateLimit
from pydantic import BaseModel
@@ -38,13 +34,11 @@ hatchet = HatchetClientManager.get_client()
topic_chunk_workflow = hatchet.workflow(
name="TopicChunkProcessing",
input_validator=TopicChunkInput,
concurrency=[
ConcurrencyExpression(
expression="'global'", # constant string = global limit across all runs
max_runs=20,
limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,
)
],
concurrency=ConcurrencyExpression(
expression="'global'", # constant string = global limit across all runs
max_runs=20,
limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,
),
)

View File: reflector/views/transcripts_chat.py

@@ -0,0 +1,133 @@
"""
Transcripts chat API
====================
WebSocket endpoint for bidirectional chat with LLM about transcript content.
"""
from typing import Optional
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from llama_index.core import Settings
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from reflector.auth.auth_jwt import JWTAuth
from reflector.db.recordings import recordings_controller
from reflector.db.transcripts import transcripts_controller
from reflector.db.users import user_controller
from reflector.llm import LLM
from reflector.settings import settings
from reflector.utils.transcript_formats import topics_to_webvtt_named
router = APIRouter()
async def _get_is_multitrack(transcript) -> bool:
"""Detect if transcript is from multitrack recording."""
if not transcript.recording_id:
return False
recording = await recordings_controller.get_by_id(transcript.recording_id)
return recording is not None and recording.is_multitrack
@router.websocket("/transcripts/{transcript_id}/chat")
async def transcript_chat_websocket(
transcript_id: str,
websocket: WebSocket,
):
"""WebSocket endpoint for chatting with LLM about transcript content."""
# 1. Auth check (optional) - extract token from WebSocket subprotocol header
# Browser can't send Authorization header for WS; use subprotocol: ["bearer", token]
raw_subprotocol = websocket.headers.get("sec-websocket-protocol") or ""
parts = [p.strip() for p in raw_subprotocol.split(",") if p.strip()]
token: Optional[str] = None
negotiated_subprotocol: Optional[str] = None
if len(parts) >= 2 and parts[0].lower() == "bearer":
negotiated_subprotocol = "bearer"
token = parts[1]
user_id: Optional[str] = None
if token:
try:
payload = JWTAuth().verify_token(token)
authentik_uid = payload.get("sub")
if authentik_uid:
user = await user_controller.get_by_authentik_uid(authentik_uid)
if user:
user_id = user.id
except Exception:
# Auth failed - continue as anonymous
pass
# Get transcript (respects user_id for private transcripts)
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
if not transcript:
await websocket.close(code=1008) # Policy violation (not found/unauthorized)
return
# 2. Accept connection (with negotiated subprotocol if present)
await websocket.accept(subprotocol=negotiated_subprotocol)
# 3. Generate WebVTT context
is_multitrack = await _get_is_multitrack(transcript)
webvtt = topics_to_webvtt_named(
transcript.topics, transcript.participants, is_multitrack
)
# Truncate if needed (15k char limit for POC)
webvtt_truncated = webvtt[:15000] if len(webvtt) > 15000 else webvtt
# 4. Configure LLM
llm = LLM(settings=settings, temperature=0.7)
# 5. System message with transcript context
system_msg = f"""You are analyzing this meeting transcript (WebVTT):
{webvtt_truncated}
Answer questions about content, speakers, timeline. Include timestamps when relevant."""
# 6. Conversation history
conversation_history = [ChatMessage(role=MessageRole.SYSTEM, content=system_msg)]
try:
# 7. Message loop
while True:
data = await websocket.receive_json()
if data.get("type") == "get_context":
# Return WebVTT context (for debugging/testing)
await websocket.send_json({"type": "context", "webvtt": webvtt})
continue
if data.get("type") != "message":
# Echo unknown types for backward compatibility
await websocket.send_json({"type": "echo", "data": data})
continue
# Add user message to history
user_msg = ChatMessage(role=MessageRole.USER, content=data.get("text", ""))
conversation_history.append(user_msg)
# Stream LLM response
assistant_msg = ""
chat_stream = await Settings.llm.astream_chat(conversation_history)
async for chunk in chat_stream:
token = chunk.delta or ""
if token:
await websocket.send_json({"type": "token", "text": token})
assistant_msg += token
# Save assistant response to history
conversation_history.append(
ChatMessage(role=MessageRole.ASSISTANT, content=assistant_msg)
)
await websocket.send_json({"type": "done"})
except WebSocketDisconnect:
pass
except Exception as e:
await websocket.send_json({"type": "error", "message": str(e)})
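Browsers cannot attach an Authorization header to a WebSocket upgrade, which is why the endpoint above reads the token from the subprotocol list. A minimal client sketch of that handshake, using the same httpx_ws API as the tests below; host, port, and token are placeholders:

from httpx_ws import aconnect_ws

async def chat(transcript_id: str, token: str) -> None:
    url = f"ws://127.0.0.1:8000/v1/transcripts/{transcript_id}/chat"
    # ["bearer", <token>] mirrors what the endpoint parses from sec-websocket-protocol
    async with aconnect_ws(url, subprotocols=["bearer", token]) as ws:
        await ws.send_json({"type": "message", "text": "Summarize the meeting"})
        while True:
            msg = await ws.receive_json()
            if msg["type"] == "token":
                print(msg["text"], end="", flush=True)
            elif msg["type"] in ("done", "error"):
                break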

View File: reflector/ws_manager.py

@@ -11,6 +11,7 @@ broadcast messages to all connected websockets.
import asyncio
import json
import threading
import redis.asyncio as redis
from fastapi import WebSocket
@@ -97,10 +98,8 @@ class WebsocketManager:
async def _pubsub_data_reader(self, pubsub_subscriber):
while True:
# timeout=1.0 prevents tight CPU loop when no messages available
message = await pubsub_subscriber.get_message(
ignore_subscribe_messages=True,
timeout=1.0,
ignore_subscribe_messages=True
)
if message is not None:
room_id = message["channel"].decode("utf-8")
@@ -110,38 +109,29 @@ class WebsocketManager:
await socket.send_json(data)
# Process-global singleton to ensure only one WebsocketManager instance exists.
# Multiple instances would cause resource leaks and CPU issues.
_ws_manager: WebsocketManager | None = None
def get_ws_manager() -> WebsocketManager:
"""
Returns the global WebsocketManager singleton.
Returns the WebsocketManager instance for managing websockets.
Creates instance on first call, subsequent calls return cached instance.
Thread-safe via GIL. Concurrent initialization may create duplicate
instances but last write wins (acceptable for this use case).
This function initializes and returns the WebsocketManager instance,
which is responsible for managing websockets and handling websocket
connections.
Returns:
WebsocketManager: The global WebsocketManager instance.
WebsocketManager: The initialized WebsocketManager instance.
Raises:
ImportError: If the 'reflector.settings' module cannot be imported.
RedisConnectionError: If there is an error connecting to the Redis server.
"""
global _ws_manager
local = threading.local()
if hasattr(local, "ws_manager"):
return local.ws_manager
if _ws_manager is not None:
return _ws_manager
# No lock needed - GIL makes this safe enough
# Worst case: race creates two instances, last assignment wins
pubsub_client = RedisPubSubManager(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
)
_ws_manager = WebsocketManager(pubsub_client=pubsub_client)
return _ws_manager
def reset_ws_manager() -> None:
"""Reset singleton for testing. DO NOT use in production."""
global _ws_manager
_ws_manager = None
ws_manager = WebsocketManager(pubsub_client=pubsub_client)
local.ws_manager = ws_manager
return ws_manager
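This hunk interleaves two variants of the accessor; untangled, the process-global singleton variant reads roughly as below (names taken from the hunk, imports as in the module):

_ws_manager: WebsocketManager | None = None  # process-global instance

def get_ws_manager() -> WebsocketManager:
    """Return the global WebsocketManager singleton, creating it on first call."""
    global _ws_manager
    if _ws_manager is None:
        # No lock: the GIL makes this safe enough; a race may create two
        # instances, but the last assignment wins (acceptable here)
        _ws_manager = WebsocketManager(
            pubsub_client=RedisPubSubManager(
                host=settings.REDIS_HOST,
                port=settings.REDIS_PORT,
            )
        )
    return _ws_manager

def reset_ws_manager() -> None:
    """Reset the singleton for test isolation. DO NOT use in production."""
    global _ws_manager
    _ws_manager = None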

View File

@@ -7,10 +7,8 @@ elif [ "${ENTRYPOINT}" = "worker" ]; then
uv run celery -A reflector.worker.app worker --loglevel=info
elif [ "${ENTRYPOINT}" = "beat" ]; then
uv run celery -A reflector.worker.app beat --loglevel=info
elif [ "${ENTRYPOINT}" = "hatchet-worker-cpu" ]; then
uv run python -m reflector.hatchet.run_workers_cpu
elif [ "${ENTRYPOINT}" = "hatchet-worker-llm" ]; then
uv run python -m reflector.hatchet.run_workers_llm
elif [ "${ENTRYPOINT}" = "hatchet-worker" ]; then
uv run python -m reflector.hatchet.run_workers
else
echo "Unknown command"
fi

View File: tests/conftest.py

@@ -1,5 +1,6 @@
import os
from contextlib import asynccontextmanager
from tempfile import NamedTemporaryFile
from unittest.mock import patch
import pytest
@@ -332,18 +333,11 @@ def celery_enable_logging():
@pytest.fixture(scope="session")
def celery_config():
# Use Redis for chord/group task execution (memory:// broker doesn't support chords)
# Redis must be running - start with: docker compose up -d redis
import os
redis_host = os.environ.get("REDIS_HOST", "localhost")
redis_port = os.environ.get("REDIS_PORT", "6379")
# Use db 2 to avoid conflicts with main app
redis_url = f"redis://{redis_host}:{redis_port}/2"
yield {
"broker_url": redis_url,
"result_backend": redis_url,
}
with NamedTemporaryFile() as f:
yield {
"broker_url": "memory://",
"result_backend": f"db+sqlite:///{f.name}",
}
@pytest.fixture(scope="session")
@@ -376,12 +370,9 @@ async def ws_manager_in_memory(monkeypatch):
def __init__(self, queue: asyncio.Queue):
self.queue = queue
async def get_message(
self, ignore_subscribe_messages: bool = True, timeout: float | None = None
):
wait_timeout = timeout if timeout is not None else 0.05
async def get_message(self, ignore_subscribe_messages: bool = True):
try:
return await asyncio.wait_for(self.queue.get(), timeout=wait_timeout)
return await asyncio.wait_for(self.queue.get(), timeout=0.05)
except Exception:
return None

View File

@@ -0,0 +1,234 @@
"""Tests for transcript chat WebSocket endpoint."""
import asyncio
import threading
import time
from pathlib import Path
import pytest
from httpx_ws import aconnect_ws
from uvicorn import Config, Server
from reflector.db.transcripts import (
SourceKind,
TranscriptParticipant,
TranscriptTopic,
transcripts_controller,
)
from reflector.processors.types import Word
@pytest.fixture
def chat_appserver(tmpdir, setup_database):
"""Start a real HTTP server for WebSocket testing."""
from reflector.app import app
from reflector.db import get_database
from reflector.settings import settings
DATA_DIR = settings.DATA_DIR
settings.DATA_DIR = Path(tmpdir)
# Start server in separate thread with its own event loop
host = "127.0.0.1"
port = 1256 # Different port from rtc tests
server_started = threading.Event()
server_exception = None
server_instance = None
def run_server():
nonlocal server_exception, server_instance
try:
# Create new event loop for this thread
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
config = Config(app=app, host=host, port=port, loop=loop)
server_instance = Server(config)
async def start_server():
# Initialize database connection in this event loop
database = get_database()
await database.connect()
try:
await server_instance.serve()
finally:
await database.disconnect()
# Signal that server is starting
server_started.set()
loop.run_until_complete(start_server())
except Exception as e:
server_exception = e
server_started.set()
finally:
loop.close()
server_thread = threading.Thread(target=run_server, daemon=True)
server_thread.start()
# Wait for server to start
server_started.wait(timeout=30)
if server_exception:
raise server_exception
# Wait for server to be fully ready
time.sleep(1)
yield server_instance, host, port
# Stop server
if server_instance:
server_instance.should_exit = True
server_thread.join(timeout=30)
settings.DATA_DIR = DATA_DIR
@pytest.fixture
async def test_transcript(setup_database):
"""Create a test transcript for WebSocket tests."""
transcript = await transcripts_controller.add(
name="Test Transcript for Chat", source_kind=SourceKind.FILE
)
return transcript
@pytest.fixture
async def test_transcript_with_content(setup_database):
"""Create a test transcript with actual content for WebVTT generation."""
transcript = await transcripts_controller.add(
name="Test Transcript with Content", source_kind=SourceKind.FILE
)
# Add participants
await transcripts_controller.update(
transcript,
{
"participants": [
TranscriptParticipant(id="1", speaker=0, name="Alice").model_dump(),
TranscriptParticipant(id="2", speaker=1, name="Bob").model_dump(),
]
},
)
# Add topic with words
await transcripts_controller.upsert_topic(
transcript,
TranscriptTopic(
title="Introduction",
summary="Opening remarks",
timestamp=0.0,
words=[
Word(text="Hello ", start=0.0, end=1.0, speaker=0),
Word(text="everyone.", start=1.0, end=2.0, speaker=0),
Word(text="Hi ", start=2.0, end=3.0, speaker=1),
Word(text="there!", start=3.0, end=4.0, speaker=1),
],
),
)
return transcript
@pytest.mark.asyncio
async def test_chat_websocket_connection_success(test_transcript, chat_appserver):
"""Test successful WebSocket connection to chat endpoint."""
server, host, port = chat_appserver
base_url = f"ws://{host}:{port}/v1"
async with aconnect_ws(f"{base_url}/transcripts/{test_transcript.id}/chat") as ws:
# Send unknown message type to test echo behavior
await ws.send_json({"type": "test", "text": "Hello"})
# Should receive echo for unknown types
response = await ws.receive_json()
assert response["type"] == "echo"
assert response["data"]["type"] == "test"
@pytest.mark.asyncio
async def test_chat_websocket_nonexistent_transcript(chat_appserver):
"""Test WebSocket connection fails for nonexistent transcript."""
server, host, port = chat_appserver
base_url = f"ws://{host}:{port}/v1"
# Connection should fail or disconnect immediately for non-existent transcript
# Different behavior from successful connection
with pytest.raises(Exception): # Will raise on connection or first operation
async with aconnect_ws(f"{base_url}/transcripts/nonexistent-id/chat") as ws:
await ws.send_json({"type": "message", "text": "Hello"})
await ws.receive_json()
@pytest.mark.asyncio
async def test_chat_websocket_multiple_messages(test_transcript, chat_appserver):
"""Test sending multiple messages through WebSocket."""
server, host, port = chat_appserver
base_url = f"ws://{host}:{port}/v1"
async with aconnect_ws(f"{base_url}/transcripts/{test_transcript.id}/chat") as ws:
# Send multiple unknown message types (testing echo behavior)
messages = ["First message", "Second message", "Third message"]
for i, msg in enumerate(messages):
await ws.send_json({"type": f"test{i}", "text": msg})
response = await ws.receive_json()
assert response["type"] == "echo"
assert response["data"]["type"] == f"test{i}"
assert response["data"]["text"] == msg
@pytest.mark.asyncio
async def test_chat_websocket_disconnect_graceful(test_transcript, chat_appserver):
"""Test WebSocket disconnects gracefully."""
server, host, port = chat_appserver
base_url = f"ws://{host}:{port}/v1"
async with aconnect_ws(f"{base_url}/transcripts/{test_transcript.id}/chat") as ws:
await ws.send_json({"type": "message", "text": "Hello"})
await ws.receive_json()
# Close handled by context manager - should not raise
@pytest.mark.asyncio
async def test_chat_websocket_context_generation(
test_transcript_with_content, chat_appserver
):
"""Test WebVTT context is generated on connection."""
server, host, port = chat_appserver
base_url = f"ws://{host}:{port}/v1"
async with aconnect_ws(
f"{base_url}/transcripts/{test_transcript_with_content.id}/chat"
) as ws:
# Request context
await ws.send_json({"type": "get_context"})
# Receive context response
response = await ws.receive_json()
assert response["type"] == "context"
assert "webvtt" in response
# Verify WebVTT format
webvtt = response["webvtt"]
assert webvtt.startswith("WEBVTT")
assert "<v Alice>" in webvtt
assert "<v Bob>" in webvtt
assert "Hello everyone." in webvtt
assert "Hi there!" in webvtt
@pytest.mark.asyncio
async def test_chat_websocket_unknown_message_type(test_transcript, chat_appserver):
"""Test unknown message types are echoed back."""
server, host, port = chat_appserver
base_url = f"ws://{host}:{port}/v1"
async with aconnect_ws(f"{base_url}/transcripts/{test_transcript.id}/chat") as ws:
# Send unknown message type
await ws.send_json({"type": "unknown", "data": "test"})
# Should receive echo
response = await ws.receive_json()
assert response["type"] == "echo"
assert response["data"]["type"] == "unknown"

View File

@@ -115,7 +115,9 @@ def appserver(tmpdir, setup_database, celery_session_app, celery_session_worker)
settings.DATA_DIR = DATA_DIR
# Using celery_includes from conftest.py which includes both pipelines
@pytest.fixture(scope="session")
def celery_includes():
return ["reflector.pipelines.main_live_pipeline"]
@pytest.mark.usefixtures("setup_database")

View File

@@ -56,12 +56,7 @@ def appserver_ws_user(setup_database):
if server_instance:
server_instance.should_exit = True
server_thread.join(timeout=2.0)
# Reset global singleton for test isolation
from reflector.ws_manager import reset_ws_manager
reset_ws_manager()
server_thread.join(timeout=30)
@pytest.fixture(autouse=True)
@@ -138,11 +133,6 @@ async def test_user_ws_accepts_valid_token_and_receives_events(appserver_ws_user
# Connect and then trigger an event via HTTP create
async with aconnect_ws(base_ws, subprotocols=subprotocols) as ws:
# Give Redis pubsub time to establish subscription before publishing
import asyncio
await asyncio.sleep(0.2)
# Emit an event to the user's room via a standard HTTP action
from httpx import AsyncClient
@@ -160,7 +150,6 @@ async def test_user_ws_accepts_valid_token_and_receives_events(appserver_ws_user
"email": "user-abc@example.com",
}
# Use in-memory client (global singleton makes it share ws_manager)
async with AsyncClient(app=app, base_url=f"http://{host}:{port}/v1") as ac:
# Create a transcript as this user so that the server publishes TRANSCRIPT_CREATED to user room
resp = await ac.post("/transcripts", json={"name": "WS Test"})

View File: server/uv.lock (generated, 20 additions)

@@ -330,6 +330,26 @@ name = "av"
version = "14.4.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/86/f6/0b473dab52dfdea05f28f3578b1c56b6c796ce85e76951bab7c4e38d5a74/av-14.4.0.tar.gz", hash = "sha256:3ecbf803a7fdf67229c0edada0830d6bfaea4d10bfb24f0c3f4e607cd1064b42", size = 3892203 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/18/8a/d57418b686ffd05fabd5a0a9cfa97e63b38c35d7101af00e87c51c8cc43c/av-14.4.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:5b21d5586a88b9fce0ab78e26bd1c38f8642f8e2aad5b35e619f4d202217c701", size = 19965048 },
{ url = "https://files.pythonhosted.org/packages/f5/aa/3f878b0301efe587e9b07bb773dd6b47ef44ca09a3cffb4af50c08a170f3/av-14.4.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:cf8762d90b0f94a20c9f6e25a94f1757db5a256707964dfd0b1d4403e7a16835", size = 23750064 },
{ url = "https://files.pythonhosted.org/packages/9a/b4/6fe94a31f9ed3a927daa72df67c7151968587106f30f9f8fcd792b186633/av-14.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c0ac9f08920c7bbe0795319689d901e27cb3d7870b9a0acae3f26fc9daa801a6", size = 33648775 },
{ url = "https://files.pythonhosted.org/packages/6c/f3/7f3130753521d779450c935aec3f4beefc8d4645471159f27b54e896470c/av-14.4.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a56d9ad2afdb638ec0404e962dc570960aae7e08ae331ad7ff70fbe99a6cf40e", size = 32216915 },
{ url = "https://files.pythonhosted.org/packages/f8/9a/8ffabfcafb42154b4b3a67d63f9b69e68fa8c34cb39ddd5cb813dd049ed4/av-14.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6bed513cbcb3437d0ae47743edc1f5b4a113c0b66cdd4e1aafc533abf5b2fbf2", size = 35287279 },
{ url = "https://files.pythonhosted.org/packages/ad/11/7023ba0a2ca94a57aedf3114ab8cfcecb0819b50c30982a4c5be4d31df41/av-14.4.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d030c2d3647931e53d51f2f6e0fcf465263e7acf9ec6e4faa8dbfc77975318c3", size = 36294683 },
{ url = "https://files.pythonhosted.org/packages/3d/fa/b8ac9636bd5034e2b899354468bef9f4dadb067420a16d8a493a514b7817/av-14.4.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1cc21582a4f606271d8c2036ec7a6247df0831050306c55cf8a905701d0f0474", size = 34552391 },
{ url = "https://files.pythonhosted.org/packages/fb/29/0db48079c207d1cba7a2783896db5aec3816e17de55942262c244dffbc0f/av-14.4.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ce7c9cd452153d36f1b1478f904ed5f9ab191d76db873bdd3a597193290805d4", size = 37265250 },
{ url = "https://files.pythonhosted.org/packages/1c/55/715858c3feb7efa4d667ce83a829c8e6ee3862e297fb2b568da3f968639d/av-14.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:fd261e31cc6b43ca722f80656c39934199d8f2eb391e0147e704b6226acebc29", size = 27925845 },
{ url = "https://files.pythonhosted.org/packages/a6/75/b8641653780336c90ba89e5352cac0afa6256a86a150c7703c0b38851c6d/av-14.4.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:a53e682b239dd23b4e3bc9568cfb1168fc629ab01925fdb2e7556eb426339e94", size = 19954125 },
{ url = "https://files.pythonhosted.org/packages/99/e6/37fe6fa5853a48d54d749526365780a63a4bc530be6abf2115e3a21e292a/av-14.4.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:5aa0b901751a32703fa938d2155d56ce3faf3630e4a48d238b35d2f7e49e5395", size = 23751479 },
{ url = "https://files.pythonhosted.org/packages/f7/75/9a5f0e6bda5f513b62bafd1cff2b495441a8b07ab7fb7b8e62f0c0d1683f/av-14.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a3b316fed3597675fe2aacfed34e25fc9d5bb0196dc8c0b014ae5ed4adda48de", size = 33801401 },
{ url = "https://files.pythonhosted.org/packages/6a/c9/e4df32a2ad1cb7f3a112d0ed610c5e43c89da80b63c60d60e3dc23793ec0/av-14.4.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a587b5c5014c3c0e16143a0f8d99874e46b5d0c50db6111aa0b54206b5687c81", size = 32364330 },
{ url = "https://files.pythonhosted.org/packages/ca/f0/64e7444a41817fde49a07d0239c033f7e9280bec4a4bb4784f5c79af95e6/av-14.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10d53f75e8ac1ec8877a551c0db32a83c0aaeae719d05285281eaaba211bbc30", size = 35519508 },
{ url = "https://files.pythonhosted.org/packages/c2/a8/a370099daa9033a3b6f9b9bd815304b3d8396907a14d09845f27467ba138/av-14.4.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:c8558cfde79dd8fc92d97c70e0f0fa8c94c7a66f68ae73afdf58598f0fe5e10d", size = 36448593 },
{ url = "https://files.pythonhosted.org/packages/27/bb/edb6ceff8fa7259cb6330c51dbfbc98dd1912bd6eb5f7bc05a4bb14a9d6e/av-14.4.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:455b6410dea0ab2d30234ffb28df7d62ca3cdf10708528e247bec3a4cdcced09", size = 34701485 },
{ url = "https://files.pythonhosted.org/packages/a7/8a/957da1f581aa1faa9a5dfa8b47ca955edb47f2b76b949950933b457bfa1d/av-14.4.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1661efbe9d975f927b8512d654704223d936f39016fad2ddab00aee7c40f412c", size = 37521981 },
{ url = "https://files.pythonhosted.org/packages/28/76/3f1cf0568592f100fd68eb40ed8c491ce95ca3c1378cc2d4c1f6d1bd295d/av-14.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:fbbeef1f421a3461086853d6464ad5526b56ffe8ccb0ab3fd0a1f121dfbf26ad", size = 27925944 },
]
[[package]]
name = "banks"

View File: TranscriptChatModal.tsx

@@ -0,0 +1,103 @@
"use client";
import { useState } from "react";
import { Box, Dialog, Input, IconButton } from "@chakra-ui/react";
import { MessageCircle } from "lucide-react";
import Markdown from "react-markdown";
import "../../styles/markdown.css";
import type { Message } from "./useTranscriptChat";
interface TranscriptChatModalProps {
open: boolean;
onClose: () => void;
messages: Message[];
sendMessage: (text: string) => void;
isStreaming: boolean;
currentStreamingText: string;
}
export function TranscriptChatModal({
open,
onClose,
messages,
sendMessage,
isStreaming,
currentStreamingText,
}: TranscriptChatModalProps) {
const [input, setInput] = useState("");
const handleSend = () => {
if (!input.trim()) return;
sendMessage(input);
setInput("");
};
return (
<Dialog.Root open={open} onOpenChange={(e) => !e.open && onClose()}>
<Dialog.Backdrop />
<Dialog.Positioner>
<Dialog.Content maxW="500px" h="600px">
<Dialog.Header>Transcript Chat</Dialog.Header>
<Dialog.Body overflowY="auto">
{messages.map((msg) => (
<Box
key={msg.id}
p={3}
mb={2}
bg={msg.role === "user" ? "blue.50" : "gray.50"}
borderRadius="md"
>
{msg.role === "user" ? (
msg.text
) : (
<div className="markdown">
<Markdown>{msg.text}</Markdown>
</div>
)}
</Box>
))}
{isStreaming && (
<Box p={3} bg="gray.50" borderRadius="md">
<div className="markdown">
<Markdown>{currentStreamingText}</Markdown>
</div>
<Box as="span" className="animate-pulse">
</Box>
</Box>
)}
</Dialog.Body>
<Dialog.Footer>
<Input
value={input}
onChange={(e) => setInput(e.target.value)}
onKeyDown={(e) => e.key === "Enter" && handleSend()}
placeholder="Ask about transcript..."
disabled={isStreaming}
/>
</Dialog.Footer>
</Dialog.Content>
</Dialog.Positioner>
</Dialog.Root>
);
}
export function TranscriptChatButton({ onClick }: { onClick: () => void }) {
return (
<IconButton
position="fixed"
bottom="24px"
right="24px"
onClick={onClick}
size="lg"
colorPalette="blue"
borderRadius="full"
aria-label="Open chat"
>
<MessageCircle />
</IconButton>
);
}

View File

@@ -18,9 +18,15 @@ import {
Skeleton,
Text,
Spinner,
useDisclosure,
} from "@chakra-ui/react";
import { useTranscriptGet } from "../../../lib/apiHooks";
import { TranscriptStatus } from "../../../lib/transcript";
import {
TranscriptChatModal,
TranscriptChatButton,
} from "../TranscriptChatModal";
import { useTranscriptChat } from "../useTranscriptChat";
type TranscriptDetails = {
params: Promise<{
@@ -53,6 +59,9 @@ export default function TranscriptDetails(details: TranscriptDetails) {
const [finalSummaryElement, setFinalSummaryElement] =
useState<HTMLDivElement | null>(null);
const { open, onOpen, onClose } = useDisclosure();
const chat = useTranscriptChat(transcriptId);
useEffect(() => {
if (!waiting || !transcript.data) return;
@@ -119,6 +128,15 @@ export default function TranscriptDetails(details: TranscriptDetails) {
return (
<>
<TranscriptChatModal
open={open}
onClose={onClose}
messages={chat.messages}
sendMessage={chat.sendMessage}
isStreaming={chat.isStreaming}
currentStreamingText={chat.currentStreamingText}
/>
<TranscriptChatButton onClick={onOpen} />
<Grid
templateColumns="1fr"
templateRows="auto minmax(0, 1fr)"

View File: useTranscriptChat.ts

@@ -0,0 +1,130 @@
"use client";
import { useEffect, useState, useRef } from "react";
import { getSession } from "next-auth/react";
import { WEBSOCKET_URL } from "../../lib/apiClient";
import { assertExtendedToken } from "../../lib/types";
export type Message = {
id: string;
role: "user" | "assistant";
text: string;
timestamp: Date;
};
export type UseTranscriptChat = {
messages: Message[];
sendMessage: (text: string) => void;
isStreaming: boolean;
currentStreamingText: string;
};
export const useTranscriptChat = (transcriptId: string): UseTranscriptChat => {
const [messages, setMessages] = useState<Message[]>([]);
const [isStreaming, setIsStreaming] = useState(false);
const [currentStreamingText, setCurrentStreamingText] = useState("");
const wsRef = useRef<WebSocket | null>(null);
const streamingTextRef = useRef<string>("");
const isMountedRef = useRef<boolean>(true);
useEffect(() => {
isMountedRef.current = true;
const connectWebSocket = async () => {
const url = `${WEBSOCKET_URL}/v1/transcripts/${transcriptId}/chat`;
// Get auth token for WebSocket subprotocol
let protocols: string[] | undefined;
try {
const session = await getSession();
if (session) {
const token = assertExtendedToken(session).accessToken;
// Pass token via subprotocol: ["bearer", token]
protocols = ["bearer", token];
}
} catch (error) {
console.warn("Failed to get auth token for WebSocket:", error);
}
const ws = new WebSocket(url, protocols);
wsRef.current = ws;
ws.onopen = () => {
console.log("Chat WebSocket connected");
};
ws.onmessage = (event) => {
if (!isMountedRef.current) return;
const msg = JSON.parse(event.data);
switch (msg.type) {
case "token":
setIsStreaming(true);
streamingTextRef.current += msg.text;
setCurrentStreamingText(streamingTextRef.current);
break;
case "done":
// CRITICAL: Save the text BEFORE resetting the ref
// The setMessages callback may execute later, after ref is reset
const finalText = streamingTextRef.current;
setMessages((prev) => [
...prev,
{
id: Date.now().toString(),
role: "assistant",
text: finalText,
timestamp: new Date(),
},
]);
streamingTextRef.current = "";
setCurrentStreamingText("");
setIsStreaming(false);
break;
case "error":
console.error("Chat error:", msg.message);
setIsStreaming(false);
break;
}
};
ws.onerror = (error) => {
console.error("WebSocket error:", error);
};
ws.onclose = () => {
console.log("Chat WebSocket closed");
};
};
connectWebSocket();
return () => {
isMountedRef.current = false;
if (wsRef.current) {
wsRef.current.close();
}
};
}, [transcriptId]);
const sendMessage = (text: string) => {
if (!wsRef.current) return;
setMessages((prev) => [
...prev,
{
id: Date.now().toString(),
role: "user",
text,
timestamp: new Date(),
},
]);
wsRef.current.send(JSON.stringify({ type: "message", text }));
};
return { messages, sendMessage, isStreaming, currentStreamingText };
};