mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-02-05 18:36:45 +00:00
Compare commits
5 Commits
feat/trans
...
fix/websoc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
454aecf270 | ||
|
|
dee1555807 | ||
| 8a293882ad | |||
| d83c4a30b4 | |||
| 3b6540eae5 |
@@ -1,5 +1,12 @@
|
||||
# Changelog
|
||||
|
||||
## [0.28.0](https://github.com/Monadical-SAS/reflector/compare/v0.27.0...v0.28.0) (2026-01-20)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* worker affinity ([#819](https://github.com/Monadical-SAS/reflector/issues/819)) ([3b6540e](https://github.com/Monadical-SAS/reflector/commit/3b6540eae5b597449f98661bdf15483b77be3268))
|
||||
|
||||
## [0.27.0](https://github.com/Monadical-SAS/reflector/compare/v0.26.0...v0.27.0) (2025-12-26)
|
||||
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ services:
|
||||
environment:
|
||||
ENTRYPOINT: beat
|
||||
|
||||
hatchet-worker:
|
||||
hatchet-worker-cpu:
|
||||
build:
|
||||
context: server
|
||||
volumes:
|
||||
@@ -43,7 +43,20 @@ services:
|
||||
env_file:
|
||||
- ./server/.env
|
||||
environment:
|
||||
ENTRYPOINT: hatchet-worker
|
||||
ENTRYPOINT: hatchet-worker-cpu
|
||||
depends_on:
|
||||
hatchet:
|
||||
condition: service_healthy
|
||||
hatchet-worker-llm:
|
||||
build:
|
||||
context: server
|
||||
volumes:
|
||||
- ./server/:/app/
|
||||
- /app/.venv
|
||||
env_file:
|
||||
- ./server/.env
|
||||
environment:
|
||||
ENTRYPOINT: hatchet-worker-llm
|
||||
depends_on:
|
||||
hatchet:
|
||||
condition: service_healthy
|
||||
|
||||
@@ -18,7 +18,6 @@ from reflector.views.rooms import router as rooms_router
|
||||
from reflector.views.rtc_offer import router as rtc_offer_router
|
||||
from reflector.views.transcripts import router as transcripts_router
|
||||
from reflector.views.transcripts_audio import router as transcripts_audio_router
|
||||
from reflector.views.transcripts_chat import router as transcripts_chat_router
|
||||
from reflector.views.transcripts_participants import (
|
||||
router as transcripts_participants_router,
|
||||
)
|
||||
@@ -91,7 +90,6 @@ app.include_router(transcripts_participants_router, prefix="/v1")
|
||||
app.include_router(transcripts_speaker_router, prefix="/v1")
|
||||
app.include_router(transcripts_upload_router, prefix="/v1")
|
||||
app.include_router(transcripts_websocket_router, prefix="/v1")
|
||||
app.include_router(transcripts_chat_router, prefix="/v1")
|
||||
app.include_router(transcripts_webrtc_router, prefix="/v1")
|
||||
app.include_router(transcripts_process_router, prefix="/v1")
|
||||
app.include_router(user_router, prefix="/v1")
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
"""
|
||||
Run Hatchet workers for the multitrack pipeline.
|
||||
Runs as a separate process, just like Celery workers.
|
||||
|
||||
Usage:
|
||||
uv run -m reflector.hatchet.run_workers
|
||||
|
||||
# Or via docker:
|
||||
docker compose exec server uv run -m reflector.hatchet.run_workers
|
||||
"""
|
||||
|
||||
import signal
|
||||
import sys
|
||||
|
||||
from hatchet_sdk.rate_limit import RateLimitDuration
|
||||
|
||||
from reflector.hatchet.constants import LLM_RATE_LIMIT_KEY, LLM_RATE_LIMIT_PER_SECOND
|
||||
from reflector.logger import logger
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Start Hatchet worker polling."""
|
||||
if not settings.HATCHET_ENABLED:
|
||||
logger.error("HATCHET_ENABLED is False, not starting workers")
|
||||
sys.exit(1)
|
||||
|
||||
if not settings.HATCHET_CLIENT_TOKEN:
|
||||
logger.error("HATCHET_CLIENT_TOKEN is not set")
|
||||
sys.exit(1)
|
||||
|
||||
logger.info(
|
||||
"Starting Hatchet workers",
|
||||
debug=settings.HATCHET_DEBUG,
|
||||
)
|
||||
|
||||
# Import here (not top-level) - workflow modules call HatchetClientManager.get_client()
|
||||
# at module level because Hatchet SDK decorators (@workflow.task) bind at import time.
|
||||
# Can't use lazy init: decorators need the client object when function is defined.
|
||||
from reflector.hatchet.client import HatchetClientManager # noqa: PLC0415
|
||||
from reflector.hatchet.workflows import ( # noqa: PLC0415
|
||||
daily_multitrack_pipeline,
|
||||
subject_workflow,
|
||||
topic_chunk_workflow,
|
||||
track_workflow,
|
||||
)
|
||||
|
||||
hatchet = HatchetClientManager.get_client()
|
||||
|
||||
hatchet.rate_limits.put(
|
||||
LLM_RATE_LIMIT_KEY, LLM_RATE_LIMIT_PER_SECOND, RateLimitDuration.SECOND
|
||||
)
|
||||
|
||||
worker = hatchet.worker(
|
||||
"reflector-pipeline-worker",
|
||||
workflows=[
|
||||
daily_multitrack_pipeline,
|
||||
subject_workflow,
|
||||
topic_chunk_workflow,
|
||||
track_workflow,
|
||||
],
|
||||
)
|
||||
|
||||
def shutdown_handler(signum: int, frame) -> None:
|
||||
logger.info("Received shutdown signal, stopping workers...")
|
||||
# Worker cleanup happens automatically on exit
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGINT, shutdown_handler)
|
||||
signal.signal(signal.SIGTERM, shutdown_handler)
|
||||
|
||||
logger.info("Starting Hatchet worker polling...")
|
||||
worker.start()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
48
server/reflector/hatchet/run_workers_cpu.py
Normal file
48
server/reflector/hatchet/run_workers_cpu.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""
|
||||
CPU-heavy worker pool for audio processing tasks.
|
||||
Handles ONLY: mixdown_tracks
|
||||
|
||||
Configuration:
|
||||
- slots=1: Only mixdown (already serialized globally with max_runs=1)
|
||||
- Worker affinity: pool=cpu-heavy
|
||||
"""
|
||||
|
||||
from reflector.hatchet.client import HatchetClientManager
|
||||
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||
daily_multitrack_pipeline,
|
||||
)
|
||||
from reflector.logger import logger
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
def main():
|
||||
if not settings.HATCHET_ENABLED:
|
||||
logger.error("HATCHET_ENABLED is False, not starting CPU workers")
|
||||
return
|
||||
|
||||
hatchet = HatchetClientManager.get_client()
|
||||
|
||||
logger.info(
|
||||
"Starting Hatchet CPU worker pool (mixdown only)",
|
||||
worker_name="cpu-worker-pool",
|
||||
slots=1,
|
||||
labels={"pool": "cpu-heavy"},
|
||||
)
|
||||
|
||||
cpu_worker = hatchet.worker(
|
||||
"cpu-worker-pool",
|
||||
slots=1, # Only 1 mixdown at a time (already serialized globally)
|
||||
labels={
|
||||
"pool": "cpu-heavy",
|
||||
},
|
||||
workflows=[daily_multitrack_pipeline],
|
||||
)
|
||||
|
||||
try:
|
||||
cpu_worker.start()
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Received shutdown signal, stopping CPU workers...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
56
server/reflector/hatchet/run_workers_llm.py
Normal file
56
server/reflector/hatchet/run_workers_llm.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""
|
||||
LLM/I/O worker pool for all non-CPU tasks.
|
||||
Handles: all tasks except mixdown_tracks (transcription, LLM inference, orchestration)
|
||||
"""
|
||||
|
||||
from reflector.hatchet.client import HatchetClientManager
|
||||
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||
daily_multitrack_pipeline,
|
||||
)
|
||||
from reflector.hatchet.workflows.subject_processing import subject_workflow
|
||||
from reflector.hatchet.workflows.topic_chunk_processing import topic_chunk_workflow
|
||||
from reflector.hatchet.workflows.track_processing import track_workflow
|
||||
from reflector.logger import logger
|
||||
from reflector.settings import settings
|
||||
|
||||
SLOTS = 10
|
||||
WORKER_NAME = "llm-worker-pool"
|
||||
POOL = "llm-io"
|
||||
|
||||
|
||||
def main():
|
||||
if not settings.HATCHET_ENABLED:
|
||||
logger.error("HATCHET_ENABLED is False, not starting LLM workers")
|
||||
return
|
||||
|
||||
hatchet = HatchetClientManager.get_client()
|
||||
|
||||
logger.info(
|
||||
"Starting Hatchet LLM worker pool (all tasks except mixdown)",
|
||||
worker_name=WORKER_NAME,
|
||||
slots=SLOTS,
|
||||
labels={"pool": POOL},
|
||||
)
|
||||
|
||||
llm_worker = hatchet.worker(
|
||||
WORKER_NAME,
|
||||
slots=SLOTS, # not all slots are probably used
|
||||
labels={
|
||||
"pool": POOL,
|
||||
},
|
||||
workflows=[
|
||||
daily_multitrack_pipeline,
|
||||
topic_chunk_workflow,
|
||||
subject_workflow,
|
||||
track_workflow,
|
||||
],
|
||||
)
|
||||
|
||||
try:
|
||||
llm_worker.start()
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Received shutdown signal, stopping LLM workers...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -23,7 +23,12 @@ from pathlib import Path
|
||||
from typing import Any, Callable, Coroutine, Protocol, TypeVar
|
||||
|
||||
import httpx
|
||||
from hatchet_sdk import Context
|
||||
from hatchet_sdk import (
|
||||
ConcurrencyExpression,
|
||||
ConcurrencyLimitStrategy,
|
||||
Context,
|
||||
)
|
||||
from hatchet_sdk.labels import DesiredWorkerLabel
|
||||
from pydantic import BaseModel
|
||||
|
||||
from reflector.dailyco_api.client import DailyApiClient
|
||||
@@ -467,6 +472,20 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
|
||||
parents=[process_tracks],
|
||||
execution_timeout=timedelta(seconds=TIMEOUT_AUDIO),
|
||||
retries=3,
|
||||
desired_worker_labels={
|
||||
"pool": DesiredWorkerLabel(
|
||||
value="cpu-heavy",
|
||||
required=True,
|
||||
weight=100,
|
||||
),
|
||||
},
|
||||
concurrency=[
|
||||
ConcurrencyExpression(
|
||||
expression="'mixdown-global'",
|
||||
max_runs=1, # serialize mixdown to prevent resource contention
|
||||
limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, # Queue
|
||||
)
|
||||
],
|
||||
)
|
||||
@with_error_handling(TaskName.MIXDOWN_TRACKS)
|
||||
async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
|
||||
|
||||
@@ -7,7 +7,11 @@ Spawned dynamically by detect_topics via aio_run_many() for parallel processing.
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
from hatchet_sdk import ConcurrencyExpression, ConcurrencyLimitStrategy, Context
|
||||
from hatchet_sdk import (
|
||||
ConcurrencyExpression,
|
||||
ConcurrencyLimitStrategy,
|
||||
Context,
|
||||
)
|
||||
from hatchet_sdk.rate_limit import RateLimit
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -34,11 +38,13 @@ hatchet = HatchetClientManager.get_client()
|
||||
topic_chunk_workflow = hatchet.workflow(
|
||||
name="TopicChunkProcessing",
|
||||
input_validator=TopicChunkInput,
|
||||
concurrency=ConcurrencyExpression(
|
||||
concurrency=[
|
||||
ConcurrencyExpression(
|
||||
expression="'global'", # constant string = global limit across all runs
|
||||
max_runs=20,
|
||||
limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,133 +0,0 @@
|
||||
"""
|
||||
Transcripts chat API
|
||||
====================
|
||||
|
||||
WebSocket endpoint for bidirectional chat with LLM about transcript content.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
from llama_index.core import Settings
|
||||
from llama_index.core.base.llms.types import ChatMessage, MessageRole
|
||||
|
||||
from reflector.auth.auth_jwt import JWTAuth
|
||||
from reflector.db.recordings import recordings_controller
|
||||
from reflector.db.transcripts import transcripts_controller
|
||||
from reflector.db.users import user_controller
|
||||
from reflector.llm import LLM
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.transcript_formats import topics_to_webvtt_named
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def _get_is_multitrack(transcript) -> bool:
|
||||
"""Detect if transcript is from multitrack recording."""
|
||||
if not transcript.recording_id:
|
||||
return False
|
||||
recording = await recordings_controller.get_by_id(transcript.recording_id)
|
||||
return recording is not None and recording.is_multitrack
|
||||
|
||||
|
||||
@router.websocket("/transcripts/{transcript_id}/chat")
|
||||
async def transcript_chat_websocket(
|
||||
transcript_id: str,
|
||||
websocket: WebSocket,
|
||||
):
|
||||
"""WebSocket endpoint for chatting with LLM about transcript content."""
|
||||
# 1. Auth check (optional) - extract token from WebSocket subprotocol header
|
||||
# Browser can't send Authorization header for WS; use subprotocol: ["bearer", token]
|
||||
raw_subprotocol = websocket.headers.get("sec-websocket-protocol") or ""
|
||||
parts = [p.strip() for p in raw_subprotocol.split(",") if p.strip()]
|
||||
token: Optional[str] = None
|
||||
negotiated_subprotocol: Optional[str] = None
|
||||
if len(parts) >= 2 and parts[0].lower() == "bearer":
|
||||
negotiated_subprotocol = "bearer"
|
||||
token = parts[1]
|
||||
|
||||
user_id: Optional[str] = None
|
||||
if token:
|
||||
try:
|
||||
payload = JWTAuth().verify_token(token)
|
||||
authentik_uid = payload.get("sub")
|
||||
|
||||
if authentik_uid:
|
||||
user = await user_controller.get_by_authentik_uid(authentik_uid)
|
||||
if user:
|
||||
user_id = user.id
|
||||
except Exception:
|
||||
# Auth failed - continue as anonymous
|
||||
pass
|
||||
|
||||
# Get transcript (respects user_id for private transcripts)
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
)
|
||||
if not transcript:
|
||||
await websocket.close(code=1008) # Policy violation (not found/unauthorized)
|
||||
return
|
||||
|
||||
# 2. Accept connection (with negotiated subprotocol if present)
|
||||
await websocket.accept(subprotocol=negotiated_subprotocol)
|
||||
|
||||
# 3. Generate WebVTT context
|
||||
is_multitrack = await _get_is_multitrack(transcript)
|
||||
webvtt = topics_to_webvtt_named(
|
||||
transcript.topics, transcript.participants, is_multitrack
|
||||
)
|
||||
|
||||
# Truncate if needed (15k char limit for POC)
|
||||
webvtt_truncated = webvtt[:15000] if len(webvtt) > 15000 else webvtt
|
||||
|
||||
# 4. Configure LLM
|
||||
llm = LLM(settings=settings, temperature=0.7)
|
||||
|
||||
# 5. System message with transcript context
|
||||
system_msg = f"""You are analyzing this meeting transcript (WebVTT):
|
||||
|
||||
{webvtt_truncated}
|
||||
|
||||
Answer questions about content, speakers, timeline. Include timestamps when relevant."""
|
||||
|
||||
# 6. Conversation history
|
||||
conversation_history = [ChatMessage(role=MessageRole.SYSTEM, content=system_msg)]
|
||||
|
||||
try:
|
||||
# 7. Message loop
|
||||
while True:
|
||||
data = await websocket.receive_json()
|
||||
|
||||
if data.get("type") == "get_context":
|
||||
# Return WebVTT context (for debugging/testing)
|
||||
await websocket.send_json({"type": "context", "webvtt": webvtt})
|
||||
continue
|
||||
|
||||
if data.get("type") != "message":
|
||||
# Echo unknown types for backward compatibility
|
||||
await websocket.send_json({"type": "echo", "data": data})
|
||||
continue
|
||||
|
||||
# Add user message to history
|
||||
user_msg = ChatMessage(role=MessageRole.USER, content=data.get("text", ""))
|
||||
conversation_history.append(user_msg)
|
||||
|
||||
# Stream LLM response
|
||||
assistant_msg = ""
|
||||
chat_stream = await Settings.llm.astream_chat(conversation_history)
|
||||
async for chunk in chat_stream:
|
||||
token = chunk.delta or ""
|
||||
if token:
|
||||
await websocket.send_json({"type": "token", "text": token})
|
||||
assistant_msg += token
|
||||
|
||||
# Save assistant response to history
|
||||
conversation_history.append(
|
||||
ChatMessage(role=MessageRole.ASSISTANT, content=assistant_msg)
|
||||
)
|
||||
await websocket.send_json({"type": "done"})
|
||||
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
except Exception as e:
|
||||
await websocket.send_json({"type": "error", "message": str(e)})
|
||||
@@ -11,7 +11,6 @@ broadcast messages to all connected websockets.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import threading
|
||||
|
||||
import redis.asyncio as redis
|
||||
from fastapi import WebSocket
|
||||
@@ -98,8 +97,10 @@ class WebsocketManager:
|
||||
|
||||
async def _pubsub_data_reader(self, pubsub_subscriber):
|
||||
while True:
|
||||
# timeout=1.0 prevents tight CPU loop when no messages available
|
||||
message = await pubsub_subscriber.get_message(
|
||||
ignore_subscribe_messages=True
|
||||
ignore_subscribe_messages=True,
|
||||
timeout=1.0,
|
||||
)
|
||||
if message is not None:
|
||||
room_id = message["channel"].decode("utf-8")
|
||||
@@ -109,29 +110,38 @@ class WebsocketManager:
|
||||
await socket.send_json(data)
|
||||
|
||||
|
||||
# Process-global singleton to ensure only one WebsocketManager instance exists.
|
||||
# Multiple instances would cause resource leaks and CPU issues.
|
||||
_ws_manager: WebsocketManager | None = None
|
||||
|
||||
|
||||
def get_ws_manager() -> WebsocketManager:
|
||||
"""
|
||||
Returns the WebsocketManager instance for managing websockets.
|
||||
Returns the global WebsocketManager singleton.
|
||||
|
||||
This function initializes and returns the WebsocketManager instance,
|
||||
which is responsible for managing websockets and handling websocket
|
||||
connections.
|
||||
Creates instance on first call, subsequent calls return cached instance.
|
||||
Thread-safe via GIL. Concurrent initialization may create duplicate
|
||||
instances but last write wins (acceptable for this use case).
|
||||
|
||||
Returns:
|
||||
WebsocketManager: The initialized WebsocketManager instance.
|
||||
|
||||
Raises:
|
||||
ImportError: If the 'reflector.settings' module cannot be imported.
|
||||
RedisConnectionError: If there is an error connecting to the Redis server.
|
||||
WebsocketManager: The global WebsocketManager instance.
|
||||
"""
|
||||
local = threading.local()
|
||||
if hasattr(local, "ws_manager"):
|
||||
return local.ws_manager
|
||||
global _ws_manager
|
||||
|
||||
if _ws_manager is not None:
|
||||
return _ws_manager
|
||||
|
||||
# No lock needed - GIL makes this safe enough
|
||||
# Worst case: race creates two instances, last assignment wins
|
||||
pubsub_client = RedisPubSubManager(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
)
|
||||
ws_manager = WebsocketManager(pubsub_client=pubsub_client)
|
||||
local.ws_manager = ws_manager
|
||||
return ws_manager
|
||||
_ws_manager = WebsocketManager(pubsub_client=pubsub_client)
|
||||
return _ws_manager
|
||||
|
||||
|
||||
def reset_ws_manager() -> None:
|
||||
"""Reset singleton for testing. DO NOT use in production."""
|
||||
global _ws_manager
|
||||
_ws_manager = None
|
||||
|
||||
@@ -7,8 +7,10 @@ elif [ "${ENTRYPOINT}" = "worker" ]; then
|
||||
uv run celery -A reflector.worker.app worker --loglevel=info
|
||||
elif [ "${ENTRYPOINT}" = "beat" ]; then
|
||||
uv run celery -A reflector.worker.app beat --loglevel=info
|
||||
elif [ "${ENTRYPOINT}" = "hatchet-worker" ]; then
|
||||
uv run python -m reflector.hatchet.run_workers
|
||||
elif [ "${ENTRYPOINT}" = "hatchet-worker-cpu" ]; then
|
||||
uv run python -m reflector.hatchet.run_workers_cpu
|
||||
elif [ "${ENTRYPOINT}" = "hatchet-worker-llm" ]; then
|
||||
uv run python -m reflector.hatchet.run_workers_llm
|
||||
else
|
||||
echo "Unknown command"
|
||||
fi
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from tempfile import NamedTemporaryFile
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -333,10 +332,17 @@ def celery_enable_logging():
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def celery_config():
|
||||
with NamedTemporaryFile() as f:
|
||||
# Use Redis for chord/group task execution (memory:// broker doesn't support chords)
|
||||
# Redis must be running - start with: docker compose up -d redis
|
||||
import os
|
||||
|
||||
redis_host = os.environ.get("REDIS_HOST", "localhost")
|
||||
redis_port = os.environ.get("REDIS_PORT", "6379")
|
||||
# Use db 2 to avoid conflicts with main app
|
||||
redis_url = f"redis://{redis_host}:{redis_port}/2"
|
||||
yield {
|
||||
"broker_url": "memory://",
|
||||
"result_backend": f"db+sqlite:///{f.name}",
|
||||
"broker_url": redis_url,
|
||||
"result_backend": redis_url,
|
||||
}
|
||||
|
||||
|
||||
@@ -370,9 +376,12 @@ async def ws_manager_in_memory(monkeypatch):
|
||||
def __init__(self, queue: asyncio.Queue):
|
||||
self.queue = queue
|
||||
|
||||
async def get_message(self, ignore_subscribe_messages: bool = True):
|
||||
async def get_message(
|
||||
self, ignore_subscribe_messages: bool = True, timeout: float | None = None
|
||||
):
|
||||
wait_timeout = timeout if timeout is not None else 0.05
|
||||
try:
|
||||
return await asyncio.wait_for(self.queue.get(), timeout=0.05)
|
||||
return await asyncio.wait_for(self.queue.get(), timeout=wait_timeout)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@@ -1,234 +0,0 @@
|
||||
"""Tests for transcript chat WebSocket endpoint."""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from httpx_ws import aconnect_ws
|
||||
from uvicorn import Config, Server
|
||||
|
||||
from reflector.db.transcripts import (
|
||||
SourceKind,
|
||||
TranscriptParticipant,
|
||||
TranscriptTopic,
|
||||
transcripts_controller,
|
||||
)
|
||||
from reflector.processors.types import Word
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def chat_appserver(tmpdir, setup_database):
|
||||
"""Start a real HTTP server for WebSocket testing."""
|
||||
from reflector.app import app
|
||||
from reflector.db import get_database
|
||||
from reflector.settings import settings
|
||||
|
||||
DATA_DIR = settings.DATA_DIR
|
||||
settings.DATA_DIR = Path(tmpdir)
|
||||
|
||||
# Start server in separate thread with its own event loop
|
||||
host = "127.0.0.1"
|
||||
port = 1256 # Different port from rtc tests
|
||||
server_started = threading.Event()
|
||||
server_exception = None
|
||||
server_instance = None
|
||||
|
||||
def run_server():
|
||||
nonlocal server_exception, server_instance
|
||||
try:
|
||||
# Create new event loop for this thread
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
config = Config(app=app, host=host, port=port, loop=loop)
|
||||
server_instance = Server(config)
|
||||
|
||||
async def start_server():
|
||||
# Initialize database connection in this event loop
|
||||
database = get_database()
|
||||
await database.connect()
|
||||
try:
|
||||
await server_instance.serve()
|
||||
finally:
|
||||
await database.disconnect()
|
||||
|
||||
# Signal that server is starting
|
||||
server_started.set()
|
||||
loop.run_until_complete(start_server())
|
||||
except Exception as e:
|
||||
server_exception = e
|
||||
server_started.set()
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
server_thread = threading.Thread(target=run_server, daemon=True)
|
||||
server_thread.start()
|
||||
|
||||
# Wait for server to start
|
||||
server_started.wait(timeout=30)
|
||||
if server_exception:
|
||||
raise server_exception
|
||||
|
||||
# Wait for server to be fully ready
|
||||
time.sleep(1)
|
||||
|
||||
yield server_instance, host, port
|
||||
|
||||
# Stop server
|
||||
if server_instance:
|
||||
server_instance.should_exit = True
|
||||
server_thread.join(timeout=30)
|
||||
|
||||
settings.DATA_DIR = DATA_DIR
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_transcript(setup_database):
|
||||
"""Create a test transcript for WebSocket tests."""
|
||||
transcript = await transcripts_controller.add(
|
||||
name="Test Transcript for Chat", source_kind=SourceKind.FILE
|
||||
)
|
||||
return transcript
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_transcript_with_content(setup_database):
|
||||
"""Create a test transcript with actual content for WebVTT generation."""
|
||||
transcript = await transcripts_controller.add(
|
||||
name="Test Transcript with Content", source_kind=SourceKind.FILE
|
||||
)
|
||||
|
||||
# Add participants
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"participants": [
|
||||
TranscriptParticipant(id="1", speaker=0, name="Alice").model_dump(),
|
||||
TranscriptParticipant(id="2", speaker=1, name="Bob").model_dump(),
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
# Add topic with words
|
||||
await transcripts_controller.upsert_topic(
|
||||
transcript,
|
||||
TranscriptTopic(
|
||||
title="Introduction",
|
||||
summary="Opening remarks",
|
||||
timestamp=0.0,
|
||||
words=[
|
||||
Word(text="Hello ", start=0.0, end=1.0, speaker=0),
|
||||
Word(text="everyone.", start=1.0, end=2.0, speaker=0),
|
||||
Word(text="Hi ", start=2.0, end=3.0, speaker=1),
|
||||
Word(text="there!", start=3.0, end=4.0, speaker=1),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
return transcript
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_websocket_connection_success(test_transcript, chat_appserver):
|
||||
"""Test successful WebSocket connection to chat endpoint."""
|
||||
server, host, port = chat_appserver
|
||||
base_url = f"ws://{host}:{port}/v1"
|
||||
|
||||
async with aconnect_ws(f"{base_url}/transcripts/{test_transcript.id}/chat") as ws:
|
||||
# Send unknown message type to test echo behavior
|
||||
await ws.send_json({"type": "test", "text": "Hello"})
|
||||
|
||||
# Should receive echo for unknown types
|
||||
response = await ws.receive_json()
|
||||
assert response["type"] == "echo"
|
||||
assert response["data"]["type"] == "test"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_websocket_nonexistent_transcript(chat_appserver):
|
||||
"""Test WebSocket connection fails for nonexistent transcript."""
|
||||
server, host, port = chat_appserver
|
||||
base_url = f"ws://{host}:{port}/v1"
|
||||
|
||||
# Connection should fail or disconnect immediately for non-existent transcript
|
||||
# Different behavior from successful connection
|
||||
with pytest.raises(Exception): # Will raise on connection or first operation
|
||||
async with aconnect_ws(f"{base_url}/transcripts/nonexistent-id/chat") as ws:
|
||||
await ws.send_json({"type": "message", "text": "Hello"})
|
||||
await ws.receive_json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_websocket_multiple_messages(test_transcript, chat_appserver):
|
||||
"""Test sending multiple messages through WebSocket."""
|
||||
server, host, port = chat_appserver
|
||||
base_url = f"ws://{host}:{port}/v1"
|
||||
|
||||
async with aconnect_ws(f"{base_url}/transcripts/{test_transcript.id}/chat") as ws:
|
||||
# Send multiple unknown message types (testing echo behavior)
|
||||
messages = ["First message", "Second message", "Third message"]
|
||||
|
||||
for i, msg in enumerate(messages):
|
||||
await ws.send_json({"type": f"test{i}", "text": msg})
|
||||
response = await ws.receive_json()
|
||||
assert response["type"] == "echo"
|
||||
assert response["data"]["type"] == f"test{i}"
|
||||
assert response["data"]["text"] == msg
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_websocket_disconnect_graceful(test_transcript, chat_appserver):
|
||||
"""Test WebSocket disconnects gracefully."""
|
||||
server, host, port = chat_appserver
|
||||
base_url = f"ws://{host}:{port}/v1"
|
||||
|
||||
async with aconnect_ws(f"{base_url}/transcripts/{test_transcript.id}/chat") as ws:
|
||||
await ws.send_json({"type": "message", "text": "Hello"})
|
||||
await ws.receive_json()
|
||||
# Close handled by context manager - should not raise
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_websocket_context_generation(
|
||||
test_transcript_with_content, chat_appserver
|
||||
):
|
||||
"""Test WebVTT context is generated on connection."""
|
||||
server, host, port = chat_appserver
|
||||
base_url = f"ws://{host}:{port}/v1"
|
||||
|
||||
async with aconnect_ws(
|
||||
f"{base_url}/transcripts/{test_transcript_with_content.id}/chat"
|
||||
) as ws:
|
||||
# Request context
|
||||
await ws.send_json({"type": "get_context"})
|
||||
|
||||
# Receive context response
|
||||
response = await ws.receive_json()
|
||||
assert response["type"] == "context"
|
||||
assert "webvtt" in response
|
||||
|
||||
# Verify WebVTT format
|
||||
webvtt = response["webvtt"]
|
||||
assert webvtt.startswith("WEBVTT")
|
||||
assert "<v Alice>" in webvtt
|
||||
assert "<v Bob>" in webvtt
|
||||
assert "Hello everyone." in webvtt
|
||||
assert "Hi there!" in webvtt
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_websocket_unknown_message_type(test_transcript, chat_appserver):
|
||||
"""Test unknown message types are echoed back."""
|
||||
server, host, port = chat_appserver
|
||||
base_url = f"ws://{host}:{port}/v1"
|
||||
|
||||
async with aconnect_ws(f"{base_url}/transcripts/{test_transcript.id}/chat") as ws:
|
||||
# Send unknown message type
|
||||
await ws.send_json({"type": "unknown", "data": "test"})
|
||||
|
||||
# Should receive echo
|
||||
response = await ws.receive_json()
|
||||
assert response["type"] == "echo"
|
||||
assert response["data"]["type"] == "unknown"
|
||||
@@ -115,9 +115,7 @@ def appserver(tmpdir, setup_database, celery_session_app, celery_session_worker)
|
||||
settings.DATA_DIR = DATA_DIR
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def celery_includes():
|
||||
return ["reflector.pipelines.main_live_pipeline"]
|
||||
# Using celery_includes from conftest.py which includes both pipelines
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("setup_database")
|
||||
|
||||
@@ -56,7 +56,12 @@ def appserver_ws_user(setup_database):
|
||||
|
||||
if server_instance:
|
||||
server_instance.should_exit = True
|
||||
server_thread.join(timeout=30)
|
||||
server_thread.join(timeout=2.0)
|
||||
|
||||
# Reset global singleton for test isolation
|
||||
from reflector.ws_manager import reset_ws_manager
|
||||
|
||||
reset_ws_manager()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -133,6 +138,11 @@ async def test_user_ws_accepts_valid_token_and_receives_events(appserver_ws_user
|
||||
|
||||
# Connect and then trigger an event via HTTP create
|
||||
async with aconnect_ws(base_ws, subprotocols=subprotocols) as ws:
|
||||
# Give Redis pubsub time to establish subscription before publishing
|
||||
import asyncio
|
||||
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# Emit an event to the user's room via a standard HTTP action
|
||||
from httpx import AsyncClient
|
||||
|
||||
@@ -150,6 +160,7 @@ async def test_user_ws_accepts_valid_token_and_receives_events(appserver_ws_user
|
||||
"email": "user-abc@example.com",
|
||||
}
|
||||
|
||||
# Use in-memory client (global singleton makes it share ws_manager)
|
||||
async with AsyncClient(app=app, base_url=f"http://{host}:{port}/v1") as ac:
|
||||
# Create a transcript as this user so that the server publishes TRANSCRIPT_CREATED to user room
|
||||
resp = await ac.post("/transcripts", json={"name": "WS Test"})
|
||||
|
||||
20
server/uv.lock
generated
20
server/uv.lock
generated
@@ -330,26 +330,6 @@ name = "av"
|
||||
version = "14.4.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/86/f6/0b473dab52dfdea05f28f3578b1c56b6c796ce85e76951bab7c4e38d5a74/av-14.4.0.tar.gz", hash = "sha256:3ecbf803a7fdf67229c0edada0830d6bfaea4d10bfb24f0c3f4e607cd1064b42", size = 3892203 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/18/8a/d57418b686ffd05fabd5a0a9cfa97e63b38c35d7101af00e87c51c8cc43c/av-14.4.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:5b21d5586a88b9fce0ab78e26bd1c38f8642f8e2aad5b35e619f4d202217c701", size = 19965048 },
|
||||
{ url = "https://files.pythonhosted.org/packages/f5/aa/3f878b0301efe587e9b07bb773dd6b47ef44ca09a3cffb4af50c08a170f3/av-14.4.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:cf8762d90b0f94a20c9f6e25a94f1757db5a256707964dfd0b1d4403e7a16835", size = 23750064 },
|
||||
{ url = "https://files.pythonhosted.org/packages/9a/b4/6fe94a31f9ed3a927daa72df67c7151968587106f30f9f8fcd792b186633/av-14.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c0ac9f08920c7bbe0795319689d901e27cb3d7870b9a0acae3f26fc9daa801a6", size = 33648775 },
|
||||
{ url = "https://files.pythonhosted.org/packages/6c/f3/7f3130753521d779450c935aec3f4beefc8d4645471159f27b54e896470c/av-14.4.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a56d9ad2afdb638ec0404e962dc570960aae7e08ae331ad7ff70fbe99a6cf40e", size = 32216915 },
|
||||
{ url = "https://files.pythonhosted.org/packages/f8/9a/8ffabfcafb42154b4b3a67d63f9b69e68fa8c34cb39ddd5cb813dd049ed4/av-14.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6bed513cbcb3437d0ae47743edc1f5b4a113c0b66cdd4e1aafc533abf5b2fbf2", size = 35287279 },
|
||||
{ url = "https://files.pythonhosted.org/packages/ad/11/7023ba0a2ca94a57aedf3114ab8cfcecb0819b50c30982a4c5be4d31df41/av-14.4.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d030c2d3647931e53d51f2f6e0fcf465263e7acf9ec6e4faa8dbfc77975318c3", size = 36294683 },
|
||||
{ url = "https://files.pythonhosted.org/packages/3d/fa/b8ac9636bd5034e2b899354468bef9f4dadb067420a16d8a493a514b7817/av-14.4.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1cc21582a4f606271d8c2036ec7a6247df0831050306c55cf8a905701d0f0474", size = 34552391 },
|
||||
{ url = "https://files.pythonhosted.org/packages/fb/29/0db48079c207d1cba7a2783896db5aec3816e17de55942262c244dffbc0f/av-14.4.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ce7c9cd452153d36f1b1478f904ed5f9ab191d76db873bdd3a597193290805d4", size = 37265250 },
|
||||
{ url = "https://files.pythonhosted.org/packages/1c/55/715858c3feb7efa4d667ce83a829c8e6ee3862e297fb2b568da3f968639d/av-14.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:fd261e31cc6b43ca722f80656c39934199d8f2eb391e0147e704b6226acebc29", size = 27925845 },
|
||||
{ url = "https://files.pythonhosted.org/packages/a6/75/b8641653780336c90ba89e5352cac0afa6256a86a150c7703c0b38851c6d/av-14.4.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:a53e682b239dd23b4e3bc9568cfb1168fc629ab01925fdb2e7556eb426339e94", size = 19954125 },
|
||||
{ url = "https://files.pythonhosted.org/packages/99/e6/37fe6fa5853a48d54d749526365780a63a4bc530be6abf2115e3a21e292a/av-14.4.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:5aa0b901751a32703fa938d2155d56ce3faf3630e4a48d238b35d2f7e49e5395", size = 23751479 },
|
||||
{ url = "https://files.pythonhosted.org/packages/f7/75/9a5f0e6bda5f513b62bafd1cff2b495441a8b07ab7fb7b8e62f0c0d1683f/av-14.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a3b316fed3597675fe2aacfed34e25fc9d5bb0196dc8c0b014ae5ed4adda48de", size = 33801401 },
|
||||
{ url = "https://files.pythonhosted.org/packages/6a/c9/e4df32a2ad1cb7f3a112d0ed610c5e43c89da80b63c60d60e3dc23793ec0/av-14.4.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a587b5c5014c3c0e16143a0f8d99874e46b5d0c50db6111aa0b54206b5687c81", size = 32364330 },
|
||||
{ url = "https://files.pythonhosted.org/packages/ca/f0/64e7444a41817fde49a07d0239c033f7e9280bec4a4bb4784f5c79af95e6/av-14.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10d53f75e8ac1ec8877a551c0db32a83c0aaeae719d05285281eaaba211bbc30", size = 35519508 },
|
||||
{ url = "https://files.pythonhosted.org/packages/c2/a8/a370099daa9033a3b6f9b9bd815304b3d8396907a14d09845f27467ba138/av-14.4.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:c8558cfde79dd8fc92d97c70e0f0fa8c94c7a66f68ae73afdf58598f0fe5e10d", size = 36448593 },
|
||||
{ url = "https://files.pythonhosted.org/packages/27/bb/edb6ceff8fa7259cb6330c51dbfbc98dd1912bd6eb5f7bc05a4bb14a9d6e/av-14.4.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:455b6410dea0ab2d30234ffb28df7d62ca3cdf10708528e247bec3a4cdcced09", size = 34701485 },
|
||||
{ url = "https://files.pythonhosted.org/packages/a7/8a/957da1f581aa1faa9a5dfa8b47ca955edb47f2b76b949950933b457bfa1d/av-14.4.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1661efbe9d975f927b8512d654704223d936f39016fad2ddab00aee7c40f412c", size = 37521981 },
|
||||
{ url = "https://files.pythonhosted.org/packages/28/76/3f1cf0568592f100fd68eb40ed8c491ce95ca3c1378cc2d4c1f6d1bd295d/av-14.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:fbbeef1f421a3461086853d6464ad5526b56ffe8ccb0ab3fd0a1f121dfbf26ad", size = 27925944 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "banks"
|
||||
|
||||
@@ -1,103 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { Box, Dialog, Input, IconButton } from "@chakra-ui/react";
|
||||
import { MessageCircle } from "lucide-react";
|
||||
import Markdown from "react-markdown";
|
||||
import "../../styles/markdown.css";
|
||||
import type { Message } from "./useTranscriptChat";
|
||||
|
||||
interface TranscriptChatModalProps {
|
||||
open: boolean;
|
||||
onClose: () => void;
|
||||
messages: Message[];
|
||||
sendMessage: (text: string) => void;
|
||||
isStreaming: boolean;
|
||||
currentStreamingText: string;
|
||||
}
|
||||
|
||||
export function TranscriptChatModal({
|
||||
open,
|
||||
onClose,
|
||||
messages,
|
||||
sendMessage,
|
||||
isStreaming,
|
||||
currentStreamingText,
|
||||
}: TranscriptChatModalProps) {
|
||||
const [input, setInput] = useState("");
|
||||
|
||||
const handleSend = () => {
|
||||
if (!input.trim()) return;
|
||||
sendMessage(input);
|
||||
setInput("");
|
||||
};
|
||||
|
||||
return (
|
||||
<Dialog.Root open={open} onOpenChange={(e) => !e.open && onClose()}>
|
||||
<Dialog.Backdrop />
|
||||
<Dialog.Positioner>
|
||||
<Dialog.Content maxW="500px" h="600px">
|
||||
<Dialog.Header>Transcript Chat</Dialog.Header>
|
||||
|
||||
<Dialog.Body overflowY="auto">
|
||||
{messages.map((msg) => (
|
||||
<Box
|
||||
key={msg.id}
|
||||
p={3}
|
||||
mb={2}
|
||||
bg={msg.role === "user" ? "blue.50" : "gray.50"}
|
||||
borderRadius="md"
|
||||
>
|
||||
{msg.role === "user" ? (
|
||||
msg.text
|
||||
) : (
|
||||
<div className="markdown">
|
||||
<Markdown>{msg.text}</Markdown>
|
||||
</div>
|
||||
)}
|
||||
</Box>
|
||||
))}
|
||||
|
||||
{isStreaming && (
|
||||
<Box p={3} bg="gray.50" borderRadius="md">
|
||||
<div className="markdown">
|
||||
<Markdown>{currentStreamingText}</Markdown>
|
||||
</div>
|
||||
<Box as="span" className="animate-pulse">
|
||||
▊
|
||||
</Box>
|
||||
</Box>
|
||||
)}
|
||||
</Dialog.Body>
|
||||
|
||||
<Dialog.Footer>
|
||||
<Input
|
||||
value={input}
|
||||
onChange={(e) => setInput(e.target.value)}
|
||||
onKeyDown={(e) => e.key === "Enter" && handleSend()}
|
||||
placeholder="Ask about transcript..."
|
||||
disabled={isStreaming}
|
||||
/>
|
||||
</Dialog.Footer>
|
||||
</Dialog.Content>
|
||||
</Dialog.Positioner>
|
||||
</Dialog.Root>
|
||||
);
|
||||
}
|
||||
|
||||
export function TranscriptChatButton({ onClick }: { onClick: () => void }) {
|
||||
return (
|
||||
<IconButton
|
||||
position="fixed"
|
||||
bottom="24px"
|
||||
right="24px"
|
||||
onClick={onClick}
|
||||
size="lg"
|
||||
colorPalette="blue"
|
||||
borderRadius="full"
|
||||
aria-label="Open chat"
|
||||
>
|
||||
<MessageCircle />
|
||||
</IconButton>
|
||||
);
|
||||
}
|
||||
@@ -18,15 +18,9 @@ import {
|
||||
Skeleton,
|
||||
Text,
|
||||
Spinner,
|
||||
useDisclosure,
|
||||
} from "@chakra-ui/react";
|
||||
import { useTranscriptGet } from "../../../lib/apiHooks";
|
||||
import { TranscriptStatus } from "../../../lib/transcript";
|
||||
import {
|
||||
TranscriptChatModal,
|
||||
TranscriptChatButton,
|
||||
} from "../TranscriptChatModal";
|
||||
import { useTranscriptChat } from "../useTranscriptChat";
|
||||
|
||||
type TranscriptDetails = {
|
||||
params: Promise<{
|
||||
@@ -59,9 +53,6 @@ export default function TranscriptDetails(details: TranscriptDetails) {
|
||||
const [finalSummaryElement, setFinalSummaryElement] =
|
||||
useState<HTMLDivElement | null>(null);
|
||||
|
||||
const { open, onOpen, onClose } = useDisclosure();
|
||||
const chat = useTranscriptChat(transcriptId);
|
||||
|
||||
useEffect(() => {
|
||||
if (!waiting || !transcript.data) return;
|
||||
|
||||
@@ -128,15 +119,6 @@ export default function TranscriptDetails(details: TranscriptDetails) {
|
||||
|
||||
return (
|
||||
<>
|
||||
<TranscriptChatModal
|
||||
open={open}
|
||||
onClose={onClose}
|
||||
messages={chat.messages}
|
||||
sendMessage={chat.sendMessage}
|
||||
isStreaming={chat.isStreaming}
|
||||
currentStreamingText={chat.currentStreamingText}
|
||||
/>
|
||||
<TranscriptChatButton onClick={onOpen} />
|
||||
<Grid
|
||||
templateColumns="1fr"
|
||||
templateRows="auto minmax(0, 1fr)"
|
||||
|
||||
@@ -1,130 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useEffect, useState, useRef } from "react";
|
||||
import { getSession } from "next-auth/react";
|
||||
import { WEBSOCKET_URL } from "../../lib/apiClient";
|
||||
import { assertExtendedToken } from "../../lib/types";
|
||||
|
||||
export type Message = {
|
||||
id: string;
|
||||
role: "user" | "assistant";
|
||||
text: string;
|
||||
timestamp: Date;
|
||||
};
|
||||
|
||||
export type UseTranscriptChat = {
|
||||
messages: Message[];
|
||||
sendMessage: (text: string) => void;
|
||||
isStreaming: boolean;
|
||||
currentStreamingText: string;
|
||||
};
|
||||
|
||||
export const useTranscriptChat = (transcriptId: string): UseTranscriptChat => {
|
||||
const [messages, setMessages] = useState<Message[]>([]);
|
||||
const [isStreaming, setIsStreaming] = useState(false);
|
||||
const [currentStreamingText, setCurrentStreamingText] = useState("");
|
||||
const wsRef = useRef<WebSocket | null>(null);
|
||||
const streamingTextRef = useRef<string>("");
|
||||
const isMountedRef = useRef<boolean>(true);
|
||||
|
||||
useEffect(() => {
|
||||
isMountedRef.current = true;
|
||||
|
||||
const connectWebSocket = async () => {
|
||||
const url = `${WEBSOCKET_URL}/v1/transcripts/${transcriptId}/chat`;
|
||||
|
||||
// Get auth token for WebSocket subprotocol
|
||||
let protocols: string[] | undefined;
|
||||
try {
|
||||
const session = await getSession();
|
||||
if (session) {
|
||||
const token = assertExtendedToken(session).accessToken;
|
||||
// Pass token via subprotocol: ["bearer", token]
|
||||
protocols = ["bearer", token];
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn("Failed to get auth token for WebSocket:", error);
|
||||
}
|
||||
|
||||
const ws = new WebSocket(url, protocols);
|
||||
wsRef.current = ws;
|
||||
|
||||
ws.onopen = () => {
|
||||
console.log("Chat WebSocket connected");
|
||||
};
|
||||
|
||||
ws.onmessage = (event) => {
|
||||
if (!isMountedRef.current) return;
|
||||
|
||||
const msg = JSON.parse(event.data);
|
||||
|
||||
switch (msg.type) {
|
||||
case "token":
|
||||
setIsStreaming(true);
|
||||
streamingTextRef.current += msg.text;
|
||||
setCurrentStreamingText(streamingTextRef.current);
|
||||
break;
|
||||
|
||||
case "done":
|
||||
// CRITICAL: Save the text BEFORE resetting the ref
|
||||
// The setMessages callback may execute later, after ref is reset
|
||||
const finalText = streamingTextRef.current;
|
||||
|
||||
setMessages((prev) => [
|
||||
...prev,
|
||||
{
|
||||
id: Date.now().toString(),
|
||||
role: "assistant",
|
||||
text: finalText,
|
||||
timestamp: new Date(),
|
||||
},
|
||||
]);
|
||||
streamingTextRef.current = "";
|
||||
setCurrentStreamingText("");
|
||||
setIsStreaming(false);
|
||||
break;
|
||||
|
||||
case "error":
|
||||
console.error("Chat error:", msg.message);
|
||||
setIsStreaming(false);
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
ws.onerror = (error) => {
|
||||
console.error("WebSocket error:", error);
|
||||
};
|
||||
|
||||
ws.onclose = () => {
|
||||
console.log("Chat WebSocket closed");
|
||||
};
|
||||
};
|
||||
|
||||
connectWebSocket();
|
||||
|
||||
return () => {
|
||||
isMountedRef.current = false;
|
||||
if (wsRef.current) {
|
||||
wsRef.current.close();
|
||||
}
|
||||
};
|
||||
}, [transcriptId]);
|
||||
|
||||
const sendMessage = (text: string) => {
|
||||
if (!wsRef.current) return;
|
||||
|
||||
setMessages((prev) => [
|
||||
...prev,
|
||||
{
|
||||
id: Date.now().toString(),
|
||||
role: "user",
|
||||
text,
|
||||
timestamp: new Date(),
|
||||
},
|
||||
]);
|
||||
|
||||
wsRef.current.send(JSON.stringify({ type: "message", text }));
|
||||
};
|
||||
|
||||
return { messages, sendMessage, isStreaming, currentStreamingText };
|
||||
};
|
||||
Reference in New Issue
Block a user