Compare commits

..

11 Commits

Author SHA1 Message Date
Igor Loskutov
b84fd1fc24 cleanup 2026-01-13 12:46:03 -05:00
Igor Loskutov
3652de9fca md 2026-01-13 12:44:43 -05:00
Igor Loskutov
68df825734 test: fix WebSocket chat tests using async approach
Replaced TestClient-based tests with proper async WebSocket testing
using httpx_ws and a threaded server pattern. TestClient has event-loop
issues with WebSocket connections that were causing all tests to fail.

Changes:
- Rewrote all WebSocket tests to use aconnect_ws from httpx_ws
- Added chat_appserver fixture using threaded Uvicorn server
- Tests now use separate event loop in server thread
- All 6 tests now pass without asyncio/event loop errors
- Matches existing pattern from test_transcripts_rtc_ws.py

Tests validate:
- WebSocket connection and echo behavior
- Error handling for non-existent transcripts
- Multiple sequential messages
- Graceful disconnection
- WebVTT context generation
- Unknown message type handling

Closes fn-1.8 (End-to-end testing)
2026-01-12 20:17:42 -05:00
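
A minimal sketch of the threaded-server pattern this commit describes, condensed from the chat_appserver fixture in the test diff further down (the helper name start_app_in_thread and the 30s timeout are illustrative):

import asyncio
import threading

from uvicorn import Config, Server

def start_app_in_thread(app, host: str = "127.0.0.1", port: int = 1256) -> Server:
    """Run the ASGI app in a daemon thread with its own event loop."""
    started = threading.Event()
    server = Server(Config(app=app, host=host, port=port))

    def run() -> None:
        loop = asyncio.new_event_loop()  # loop owned by the server thread, not the test
        asyncio.set_event_loop(loop)
        started.set()
        loop.run_until_complete(server.serve())

    threading.Thread(target=run, daemon=True).start()
    started.wait(timeout=30)
    return server  # tests connect via aconnect_ws; teardown sets server.should_exit = True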
Igor Loskutov
8ca5324c1a feat: integrate TranscriptChatModal and button into transcript page 2026-01-12 20:08:59 -05:00
Igor Loskutov
39e0b89e67 feat: add TranscriptChatModal and TranscriptChatButton components 2026-01-12 19:59:01 -05:00
Igor Loskutov
544793a24f chore: mark fn-1.5 as done (Frontend WebSocket hook)
Task fn-1.5 completed: the useTranscriptChat React hook was already implemented in commit 2dfe82af.

Hook provides:
- WebSocket connection to /v1/transcripts/{id}/chat endpoint
- Token streaming with ref-based accumulation
- Message history management (user + assistant)
- Memory leak prevention with isMountedRef
- TypeScript type safety
- Proper WebSocket lifecycle and cleanup

Updated task documentation with acceptance criteria and evidence.
2026-01-12 19:49:14 -05:00
Igor Loskutov
088451645a chore: mark fn-1.4 as done (WebSocket route registration) 2026-01-12 19:42:26 -05:00
Igor Loskutov
2dfe82afbc feat: add useTranscriptChat WebSocket hook
Task 5: Frontend WebSocket Hook
- Creates React hook for bidirectional chat WebSocket
- Handles token streaming with proper state accumulation
- Manages conversation history (user + assistant messages)
- Prevents memory leaks with isMounted check
- Proper cleanup on unmount
- Type-safe Message interface

Validated:
- No React dependency issues (removed currentStreamingText from deps)
- No stale closure bugs (using ref for streaming text)
- Proper mounted state tracking
- Lint passes with no errors
- TypeScript types correctly defined
- WebSocket cleanup on unmount

~100 lines
2026-01-12 18:44:09 -05:00
Igor Loskutov
b461ebb488 feat: register transcript chat WebSocket route
- Import transcripts_chat router
- Register /v1/transcripts/{id}/chat endpoint
- Completes LLM streaming integration (fn-1.3)
2026-01-12 18:41:11 -05:00
Igor Loskutov
0b5112cabc feat: add LLM streaming integration to transcript chat
Task 3: LLM Streaming Integration

- Import Settings, ChatMessage, MessageRole from llama-index
- Configure LLM with temperature 0.7 on connection
- Build system message with WebVTT transcript context (max 15k chars)
- Initialize conversation history with system message
- Handle 'message' type from client to trigger LLM streaming
- Stream LLM response using Settings.llm.astream_chat()
- Send tokens incrementally via 'token' messages
- Send 'done' message when streaming completes
- Maintain conversation history across multiple messages
- Add error handling with 'error' message type
- Add message protocol validation test

Implements Tasks 3 & 4 from TASKS.md
2026-01-12 18:28:43 -05:00
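
From the client side, the message protocol above works out to roughly the following (a sketch using httpx_ws, mirroring the endpoint diff below; the ask helper is hypothetical):

from httpx_ws import aconnect_ws

async def ask(base_url: str, transcript_id: str, question: str) -> str:
    """Send one question and collect the streamed answer."""
    answer = ""
    async with aconnect_ws(f"{base_url}/transcripts/{transcript_id}/chat") as ws:
        await ws.send_json({"type": "message", "text": question})
        while True:
            msg = await ws.receive_json()
            if msg["type"] == "token":      # incremental LLM output
                answer += msg["text"]
            elif msg["type"] == "done":     # streaming complete
                return answer
            elif msg["type"] == "error":
                raise RuntimeError(msg["message"])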
Igor Loskutov
316f7b316d feat: add WebVTT context generation to chat WebSocket endpoint
- Import topics_to_webvtt_named and recordings controller
- Add _get_is_multitrack helper function
- Generate WebVTT context on WebSocket connection
- Add get_context message type to retrieve WebVTT
- Maintain backward compatibility with echo for other messages
- Add test fixture and test for WebVTT context generation

Implements task fn-1.2: WebVTT context generation for transcript chat
2026-01-12 18:24:47 -05:00
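
For reference, the WebVTT context handed to the LLM has roughly this shape (inferred from the assertions in the chat WebSocket test diff below; cue timings are illustrative):

WEBVTT

00:00:00.000 --> 00:00:02.000
<v Alice>Hello everyone.

00:00:02.000 --> 00:00:04.000
<v Bob>Hi there!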
23 changed files with 878 additions and 402 deletions

View File

@@ -1,26 +1,5 @@
# Changelog
## [0.29.0](https://github.com/Monadical-SAS/reflector/compare/v0.28.1...v0.29.0) (2026-01-21)
### Features
* set hatchet as default for multitracks ([#822](https://github.com/Monadical-SAS/reflector/issues/822)) ([c723752](https://github.com/Monadical-SAS/reflector/commit/c723752b7e15aa48a41ad22856f147a5517d3f46))
## [0.28.1](https://github.com/Monadical-SAS/reflector/compare/v0.28.0...v0.28.1) (2026-01-21)
### Bug Fixes
* ics non-sync bugfix ([#823](https://github.com/Monadical-SAS/reflector/issues/823)) ([23d2bc2](https://github.com/Monadical-SAS/reflector/commit/23d2bc283d4d02187b250d2055103e0374ee93d6))
## [0.28.0](https://github.com/Monadical-SAS/reflector/compare/v0.27.0...v0.28.0) (2026-01-20)
### Features
* worker affinity ([#819](https://github.com/Monadical-SAS/reflector/issues/819)) ([3b6540e](https://github.com/Monadical-SAS/reflector/commit/3b6540eae5b597449f98661bdf15483b77be3268))
## [0.27.0](https://github.com/Monadical-SAS/reflector/compare/v0.26.0...v0.27.0) (2025-12-26)

View File

@@ -34,7 +34,7 @@ services:
environment:
ENTRYPOINT: beat
hatchet-worker-cpu:
hatchet-worker:
build:
context: server
volumes:
@@ -43,20 +43,7 @@ services:
env_file:
- ./server/.env
environment:
ENTRYPOINT: hatchet-worker-cpu
depends_on:
hatchet:
condition: service_healthy
hatchet-worker-llm:
build:
context: server
volumes:
- ./server/:/app/
- /app/.venv
env_file:
- ./server/.env
environment:
ENTRYPOINT: hatchet-worker-llm
ENTRYPOINT: hatchet-worker
depends_on:
hatchet:
condition: service_healthy

View File

@@ -1,44 +0,0 @@
"""replace_use_hatchet_with_use_celery
Revision ID: 80beb1ea3269
Revises: bd3a729bb379
Create Date: 2026-01-20 16:26:25.555869
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "80beb1ea3269"
down_revision: Union[str, None] = "bd3a729bb379"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
with op.batch_alter_table("room", schema=None) as batch_op:
batch_op.add_column(
sa.Column(
"use_celery",
sa.Boolean(),
server_default=sa.text("false"),
nullable=False,
)
)
batch_op.drop_column("use_hatchet")
def downgrade() -> None:
with op.batch_alter_table("room", schema=None) as batch_op:
batch_op.add_column(
sa.Column(
"use_hatchet",
sa.Boolean(),
server_default=sa.text("false"),
nullable=False,
)
)
batch_op.drop_column("use_celery")

View File

@@ -18,6 +18,7 @@ from reflector.views.rooms import router as rooms_router
from reflector.views.rtc_offer import router as rtc_offer_router
from reflector.views.transcripts import router as transcripts_router
from reflector.views.transcripts_audio import router as transcripts_audio_router
from reflector.views.transcripts_chat import router as transcripts_chat_router
from reflector.views.transcripts_participants import (
router as transcripts_participants_router,
)
@@ -90,6 +91,7 @@ app.include_router(transcripts_participants_router, prefix="/v1")
app.include_router(transcripts_speaker_router, prefix="/v1")
app.include_router(transcripts_upload_router, prefix="/v1")
app.include_router(transcripts_websocket_router, prefix="/v1")
app.include_router(transcripts_chat_router, prefix="/v1")
app.include_router(transcripts_webrtc_router, prefix="/v1")
app.include_router(transcripts_process_router, prefix="/v1")
app.include_router(user_router, prefix="/v1")

View File

@@ -58,7 +58,7 @@ rooms = sqlalchemy.Table(
nullable=False,
),
sqlalchemy.Column(
"use_celery",
"use_hatchet",
sqlalchemy.Boolean,
nullable=False,
server_default=false(),
@@ -97,7 +97,7 @@ class Room(BaseModel):
ics_last_sync: datetime | None = None
ics_last_etag: str | None = None
platform: Platform = Field(default_factory=lambda: settings.DEFAULT_VIDEO_PLATFORM)
use_celery: bool = False
use_hatchet: bool = False
skip_consent: bool = False

View File

@@ -0,0 +1,77 @@
"""
Run Hatchet workers for the multitrack pipeline.
Runs as a separate process, just like Celery workers.
Usage:
uv run -m reflector.hatchet.run_workers
# Or via docker:
docker compose exec server uv run -m reflector.hatchet.run_workers
"""
import signal
import sys
from hatchet_sdk.rate_limit import RateLimitDuration
from reflector.hatchet.constants import LLM_RATE_LIMIT_KEY, LLM_RATE_LIMIT_PER_SECOND
from reflector.logger import logger
from reflector.settings import settings
def main() -> None:
"""Start Hatchet worker polling."""
if not settings.HATCHET_ENABLED:
logger.error("HATCHET_ENABLED is False, not starting workers")
sys.exit(1)
if not settings.HATCHET_CLIENT_TOKEN:
logger.error("HATCHET_CLIENT_TOKEN is not set")
sys.exit(1)
logger.info(
"Starting Hatchet workers",
debug=settings.HATCHET_DEBUG,
)
# Import here (not top-level) - workflow modules call HatchetClientManager.get_client()
# at module level because Hatchet SDK decorators (@workflow.task) bind at import time.
# Can't use lazy init: decorators need the client object when function is defined.
from reflector.hatchet.client import HatchetClientManager # noqa: PLC0415
from reflector.hatchet.workflows import ( # noqa: PLC0415
daily_multitrack_pipeline,
subject_workflow,
topic_chunk_workflow,
track_workflow,
)
hatchet = HatchetClientManager.get_client()
hatchet.rate_limits.put(
LLM_RATE_LIMIT_KEY, LLM_RATE_LIMIT_PER_SECOND, RateLimitDuration.SECOND
)
worker = hatchet.worker(
"reflector-pipeline-worker",
workflows=[
daily_multitrack_pipeline,
subject_workflow,
topic_chunk_workflow,
track_workflow,
],
)
def shutdown_handler(signum: int, frame) -> None:
logger.info("Received shutdown signal, stopping workers...")
# Worker cleanup happens automatically on exit
sys.exit(0)
signal.signal(signal.SIGINT, shutdown_handler)
signal.signal(signal.SIGTERM, shutdown_handler)
logger.info("Starting Hatchet worker polling...")
worker.start()
if __name__ == "__main__":
main()

View File

@@ -1,43 +0,0 @@
"""
CPU-heavy worker pool for audio processing tasks.
Handles ONLY: mixdown_tracks
Configuration:
- slots=1: Only mixdown (already serialized globally with max_runs=1)
- Worker affinity: pool=cpu-heavy
"""
from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
daily_multitrack_pipeline,
)
from reflector.logger import logger
def main():
hatchet = HatchetClientManager.get_client()
logger.info(
"Starting Hatchet CPU worker pool (mixdown only)",
worker_name="cpu-worker-pool",
slots=1,
labels={"pool": "cpu-heavy"},
)
cpu_worker = hatchet.worker(
"cpu-worker-pool",
slots=1, # Only 1 mixdown at a time (already serialized globally)
labels={
"pool": "cpu-heavy",
},
workflows=[daily_multitrack_pipeline],
)
try:
cpu_worker.start()
except KeyboardInterrupt:
logger.info("Received shutdown signal, stopping CPU workers...")
if __name__ == "__main__":
main()

View File

@@ -1,51 +0,0 @@
"""
LLM/I/O worker pool for all non-CPU tasks.
Handles: all tasks except mixdown_tracks (transcription, LLM inference, orchestration)
"""
from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
daily_multitrack_pipeline,
)
from reflector.hatchet.workflows.subject_processing import subject_workflow
from reflector.hatchet.workflows.topic_chunk_processing import topic_chunk_workflow
from reflector.hatchet.workflows.track_processing import track_workflow
from reflector.logger import logger
SLOTS = 10
WORKER_NAME = "llm-worker-pool"
POOL = "llm-io"
def main():
hatchet = HatchetClientManager.get_client()
logger.info(
"Starting Hatchet LLM worker pool (all tasks except mixdown)",
worker_name=WORKER_NAME,
slots=SLOTS,
labels={"pool": POOL},
)
llm_worker = hatchet.worker(
WORKER_NAME,
slots=SLOTS,  # probably not all slots are used
labels={
"pool": POOL,
},
workflows=[
daily_multitrack_pipeline,
topic_chunk_workflow,
subject_workflow,
track_workflow,
],
)
try:
llm_worker.start()
except KeyboardInterrupt:
logger.info("Received shutdown signal, stopping LLM workers...")
if __name__ == "__main__":
main()

View File

@@ -23,12 +23,7 @@ from pathlib import Path
from typing import Any, Callable, Coroutine, Protocol, TypeVar
import httpx
from hatchet_sdk import (
ConcurrencyExpression,
ConcurrencyLimitStrategy,
Context,
)
from hatchet_sdk.labels import DesiredWorkerLabel
from hatchet_sdk import Context
from pydantic import BaseModel
from reflector.dailyco_api.client import DailyApiClient
@@ -472,20 +467,6 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
parents=[process_tracks],
execution_timeout=timedelta(seconds=TIMEOUT_AUDIO),
retries=3,
desired_worker_labels={
"pool": DesiredWorkerLabel(
value="cpu-heavy",
required=True,
weight=100,
),
},
concurrency=[
ConcurrencyExpression(
expression="'mixdown-global'",
max_runs=1, # serialize mixdown to prevent resource contention
limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, # Queue
)
],
)
@with_error_handling(TaskName.MIXDOWN_TRACKS)
async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:

View File

@@ -7,11 +7,7 @@ Spawned dynamically by detect_topics via aio_run_many() for parallel processing.
from datetime import timedelta
from hatchet_sdk import (
ConcurrencyExpression,
ConcurrencyLimitStrategy,
Context,
)
from hatchet_sdk import ConcurrencyExpression, ConcurrencyLimitStrategy, Context
from hatchet_sdk.rate_limit import RateLimit
from pydantic import BaseModel
@@ -38,13 +34,11 @@ hatchet = HatchetClientManager.get_client()
topic_chunk_workflow = hatchet.workflow(
name="TopicChunkProcessing",
input_validator=TopicChunkInput,
concurrency=[
ConcurrencyExpression(
expression="'global'", # constant string = global limit across all runs
max_runs=20,
limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,
)
],
concurrency=ConcurrencyExpression(
expression="'global'", # constant string = global limit across all runs
max_runs=20,
limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,
),
)

View File

@@ -319,6 +319,21 @@ class ICSSyncService:
calendar = self.fetch_service.parse_ics(ics_content)
content_hash = hashlib.md5(ics_content.encode()).hexdigest()
if room.ics_last_etag == content_hash:
logger.info("No changes in ICS for room", room_id=room.id)
room_url = f"{settings.UI_BASE_URL}/{room.name}"
events, total_events = self.fetch_service.extract_room_events(
calendar, room.name, room_url
)
return {
"status": SyncStatus.UNCHANGED,
"hash": content_hash,
"events_found": len(events),
"total_events": total_events,
"events_created": 0,
"events_updated": 0,
"events_deleted": 0,
}
# Extract matching events
room_url = f"{settings.UI_BASE_URL}/{room.name}"
@@ -356,44 +371,6 @@ class ICSSyncService:
time_since_sync = datetime.now(timezone.utc) - room.ics_last_sync
return time_since_sync.total_seconds() >= room.ics_fetch_interval
def _event_data_changed(self, existing: CalendarEvent, new_data: EventData) -> bool:
"""Check if event data has changed by comparing relevant fields.
IMPORTANT: When adding fields to CalendarEvent/EventData, update this method
and the _COMPARED_FIELDS set below for runtime validation.
"""
# Fields that come from ICS and should trigger updates when changed
_COMPARED_FIELDS = {
"title",
"description",
"start_time",
"end_time",
"location",
"attendees",
"ics_raw_data",
}
# Runtime exhaustiveness check: ensure we're comparing all EventData fields
event_data_fields = set(EventData.__annotations__.keys()) - {"ics_uid"}
if event_data_fields != _COMPARED_FIELDS:
missing = event_data_fields - _COMPARED_FIELDS
extra = _COMPARED_FIELDS - event_data_fields
raise RuntimeError(
f"_event_data_changed() field mismatch: "
f"missing={missing}, extra={extra}. "
f"Update the comparison logic when adding/removing fields."
)
return (
existing.title != new_data["title"]
or existing.description != new_data["description"]
or existing.start_time != new_data["start_time"]
or existing.end_time != new_data["end_time"]
or existing.location != new_data["location"]
or existing.attendees != new_data["attendees"]
or existing.ics_raw_data != new_data["ics_raw_data"]
)
async def _sync_events_to_database(
self, room_id: str, events: list[EventData]
) -> SyncStats:
@@ -409,14 +386,11 @@ class ICSSyncService:
)
if existing:
# Only count as updated if data actually changed
if self._event_data_changed(existing, event_data):
updated += 1
await calendar_events_controller.upsert(calendar_event)
updated += 1
else:
created += 1
await calendar_events_controller.upsert(calendar_event)
await calendar_events_controller.upsert(calendar_event)
current_ics_uids.append(event_data["ics_uid"])
# Soft delete events that are no longer in calendar

View File

@@ -23,6 +23,7 @@ from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
from reflector.pipelines.main_multitrack_pipeline import (
task_pipeline_multitrack_process,
)
from reflector.settings import settings
from reflector.utils.string import NonEmptyString
@@ -101,8 +102,8 @@ async def validate_transcript_for_processing(
if transcript.locked:
return ValidationLocked(detail="Recording is locked")
# Check if recording is ready for processing
if transcript.status == "idle" and not transcript.workflow_run_id:
# hatchet is idempotent anyway, and covers the case where dispatch didn't succeed
if transcript.status == "idle" and not settings.HATCHET_ENABLED:
return ValidationNotReady(detail="Recording is not ready for processing")
# Check Celery tasks
@@ -115,8 +116,7 @@ async def validate_transcript_for_processing(
):
return ValidationAlreadyScheduled(detail="already running")
# Check Hatchet workflow status if workflow_run_id exists
if transcript.workflow_run_id:
if settings.HATCHET_ENABLED and transcript.workflow_run_id:
try:
status = await HatchetClientManager.get_workflow_run_status(
transcript.workflow_run_id
@@ -181,16 +181,19 @@ async def dispatch_transcript_processing(
Returns AsyncResult for Celery tasks, None for Hatchet workflows.
"""
if isinstance(config, MultitrackProcessingConfig):
use_celery = False
# Check if room has use_hatchet=True (overrides env vars)
room_forces_hatchet = False
if config.room_id:
room = await rooms_controller.get_by_id(config.room_id)
use_celery = room.use_celery if room else False
room_forces_hatchet = room.use_hatchet if room else False
use_hatchet = not use_celery
# Start durable workflow if enabled (Hatchet)
# and if room has use_hatchet=True
use_hatchet = settings.HATCHET_ENABLED and room_forces_hatchet
if use_celery:
if room_forces_hatchet:
logger.info(
"Room uses legacy Celery processing",
"Room forces Hatchet workflow",
room_id=config.room_id,
transcript_id=config.transcript_id,
)

View File

@@ -158,10 +158,19 @@ class Settings(BaseSettings):
ZULIP_API_KEY: str | None = None
ZULIP_BOT_EMAIL: str | None = None
# Hatchet workflow orchestration (always enabled for multitrack processing)
# Durable workflow orchestration
# Provider: "hatchet" (or "none" to disable)
DURABLE_WORKFLOW_PROVIDER: str = "none"
# Hatchet workflow orchestration
HATCHET_CLIENT_TOKEN: str | None = None
HATCHET_CLIENT_TLS_STRATEGY: str = "none" # none, tls, mtls
HATCHET_DEBUG: bool = False
@property
def HATCHET_ENABLED(self) -> bool:
"""True if Hatchet is the active provider."""
return self.DURABLE_WORKFLOW_PROVIDER == "hatchet"
settings = Settings()
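
With this change, dispatch is gated on the provider string rather than a standalone flag; a small sketch of the intended behavior (assuming no other required settings, which may not hold in a real environment):

from reflector.settings import Settings

# "hatchet" turns the HATCHET_ENABLED property on; anything else leaves it off
assert Settings(DURABLE_WORKFLOW_PROVIDER="hatchet").HATCHET_ENABLED is True
assert Settings(DURABLE_WORKFLOW_PROVIDER="none").HATCHET_ENABLED is False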

View File

@@ -0,0 +1,133 @@
"""
Transcripts chat API
====================
WebSocket endpoint for bidirectional chat with LLM about transcript content.
"""
from typing import Optional
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from llama_index.core import Settings
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from reflector.auth.auth_jwt import JWTAuth
from reflector.db.recordings import recordings_controller
from reflector.db.transcripts import transcripts_controller
from reflector.db.users import user_controller
from reflector.llm import LLM
from reflector.settings import settings
from reflector.utils.transcript_formats import topics_to_webvtt_named
router = APIRouter()
async def _get_is_multitrack(transcript) -> bool:
"""Detect if transcript is from multitrack recording."""
if not transcript.recording_id:
return False
recording = await recordings_controller.get_by_id(transcript.recording_id)
return recording is not None and recording.is_multitrack
@router.websocket("/transcripts/{transcript_id}/chat")
async def transcript_chat_websocket(
transcript_id: str,
websocket: WebSocket,
):
"""WebSocket endpoint for chatting with LLM about transcript content."""
# 1. Auth check (optional) - extract token from WebSocket subprotocol header
# Browser can't send Authorization header for WS; use subprotocol: ["bearer", token]
raw_subprotocol = websocket.headers.get("sec-websocket-protocol") or ""
parts = [p.strip() for p in raw_subprotocol.split(",") if p.strip()]
token: Optional[str] = None
negotiated_subprotocol: Optional[str] = None
if len(parts) >= 2 and parts[0].lower() == "bearer":
negotiated_subprotocol = "bearer"
token = parts[1]
user_id: Optional[str] = None
if token:
try:
payload = JWTAuth().verify_token(token)
authentik_uid = payload.get("sub")
if authentik_uid:
user = await user_controller.get_by_authentik_uid(authentik_uid)
if user:
user_id = user.id
except Exception:
# Auth failed - continue as anonymous
pass
# Get transcript (respects user_id for private transcripts)
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
if not transcript:
await websocket.close(code=1008) # Policy violation (not found/unauthorized)
return
# 2. Accept connection (with negotiated subprotocol if present)
await websocket.accept(subprotocol=negotiated_subprotocol)
# 3. Generate WebVTT context
is_multitrack = await _get_is_multitrack(transcript)
webvtt = topics_to_webvtt_named(
transcript.topics, transcript.participants, is_multitrack
)
# Truncate if needed (15k char limit for POC)
webvtt_truncated = webvtt[:15000] if len(webvtt) > 15000 else webvtt
# 4. Configure LLM
llm = LLM(settings=settings, temperature=0.7)
# 5. System message with transcript context
system_msg = f"""You are analyzing this meeting transcript (WebVTT):
{webvtt_truncated}
Answer questions about content, speakers, timeline. Include timestamps when relevant."""
# 6. Conversation history
conversation_history = [ChatMessage(role=MessageRole.SYSTEM, content=system_msg)]
try:
# 7. Message loop
while True:
data = await websocket.receive_json()
if data.get("type") == "get_context":
# Return WebVTT context (for debugging/testing)
await websocket.send_json({"type": "context", "webvtt": webvtt})
continue
if data.get("type") != "message":
# Echo unknown types for backward compatibility
await websocket.send_json({"type": "echo", "data": data})
continue
# Add user message to history
user_msg = ChatMessage(role=MessageRole.USER, content=data.get("text", ""))
conversation_history.append(user_msg)
# Stream LLM response
assistant_msg = ""
chat_stream = await Settings.llm.astream_chat(conversation_history)
async for chunk in chat_stream:
token = chunk.delta or ""
if token:
await websocket.send_json({"type": "token", "text": token})
assistant_msg += token
# Save assistant response to history
conversation_history.append(
ChatMessage(role=MessageRole.ASSISTANT, content=assistant_msg)
)
await websocket.send_json({"type": "done"})
except WebSocketDisconnect:
pass
except Exception as e:
await websocket.send_json({"type": "error", "message": str(e)})

View File

@@ -287,12 +287,11 @@ async def _process_multitrack_recording_inner(
room_id=room.id,
)
use_celery = room and room.use_celery
use_hatchet = not use_celery
use_hatchet = settings.HATCHET_ENABLED and room and room.use_hatchet
if use_celery:
if room and room.use_hatchet and not settings.HATCHET_ENABLED:
logger.info(
"Room uses legacy Celery processing",
"Room forces Hatchet workflow",
room_id=room.id,
transcript_id=transcript.id,
)
@@ -811,6 +810,7 @@ async def reprocess_failed_daily_recordings():
)
continue
# Fetch room to check use_hatchet flag
room = None
if meeting.room_id:
room = await rooms_controller.get_by_id(meeting.room_id)
@@ -834,10 +834,10 @@ async def reprocess_failed_daily_recordings():
)
continue
use_celery = room and room.use_celery
use_hatchet = not use_celery
use_hatchet = settings.HATCHET_ENABLED and room and room.use_hatchet
if use_hatchet:
# Hatchet requires a transcript for workflow_run_id tracking
if not transcript:
logger.warning(
"No transcript for Hatchet reprocessing, skipping",

View File

@@ -7,10 +7,8 @@ elif [ "${ENTRYPOINT}" = "worker" ]; then
uv run celery -A reflector.worker.app worker --loglevel=info
elif [ "${ENTRYPOINT}" = "beat" ]; then
uv run celery -A reflector.worker.app beat --loglevel=info
elif [ "${ENTRYPOINT}" = "hatchet-worker-cpu" ]; then
uv run python -m reflector.hatchet.run_workers_cpu
elif [ "${ENTRYPOINT}" = "hatchet-worker-llm" ]; then
uv run python -m reflector.hatchet.run_workers_llm
elif [ "${ENTRYPOINT}" = "hatchet-worker" ]; then
uv run python -m reflector.hatchet.run_workers
else
echo "Unknown command"
fi

View File

@@ -2,9 +2,10 @@
Tests for Hatchet workflow dispatch and routing logic.
These tests verify:
1. Hatchet workflow validation and replay logic
2. Force flag to cancel and restart workflows
3. Validation prevents concurrent workflows
1. Routing to Hatchet when HATCHET_ENABLED=True
2. Replay logic for failed workflows
3. Force flag to cancel and restart
4. Validation prevents concurrent workflows
"""
from unittest.mock import AsyncMock, patch
@@ -33,22 +34,25 @@ async def test_hatchet_validation_blocks_running_workflow():
workflow_run_id="running-workflow-123",
)
with patch(
"reflector.services.transcript_process.HatchetClientManager"
) as mock_hatchet:
mock_hatchet.get_workflow_run_status = AsyncMock(
return_value=V1TaskStatus.RUNNING
)
with patch("reflector.services.transcript_process.settings") as mock_settings:
mock_settings.HATCHET_ENABLED = True
with patch(
"reflector.services.transcript_process.task_is_scheduled_or_active"
) as mock_celery_check:
mock_celery_check.return_value = False
"reflector.services.transcript_process.HatchetClientManager"
) as mock_hatchet:
mock_hatchet.get_workflow_run_status = AsyncMock(
return_value=V1TaskStatus.RUNNING
)
result = await validate_transcript_for_processing(mock_transcript)
with patch(
"reflector.services.transcript_process.task_is_scheduled_or_active"
) as mock_celery_check:
mock_celery_check.return_value = False
assert isinstance(result, ValidationAlreadyScheduled)
assert "running" in result.detail.lower()
result = await validate_transcript_for_processing(mock_transcript)
assert isinstance(result, ValidationAlreadyScheduled)
assert "running" in result.detail.lower()
@pytest.mark.usefixtures("setup_database")
@@ -68,21 +72,24 @@ async def test_hatchet_validation_blocks_queued_workflow():
workflow_run_id="queued-workflow-123",
)
with patch(
"reflector.services.transcript_process.HatchetClientManager"
) as mock_hatchet:
mock_hatchet.get_workflow_run_status = AsyncMock(
return_value=V1TaskStatus.QUEUED
)
with patch("reflector.services.transcript_process.settings") as mock_settings:
mock_settings.HATCHET_ENABLED = True
with patch(
"reflector.services.transcript_process.task_is_scheduled_or_active"
) as mock_celery_check:
mock_celery_check.return_value = False
"reflector.services.transcript_process.HatchetClientManager"
) as mock_hatchet:
mock_hatchet.get_workflow_run_status = AsyncMock(
return_value=V1TaskStatus.QUEUED
)
result = await validate_transcript_for_processing(mock_transcript)
with patch(
"reflector.services.transcript_process.task_is_scheduled_or_active"
) as mock_celery_check:
mock_celery_check.return_value = False
assert isinstance(result, ValidationAlreadyScheduled)
result = await validate_transcript_for_processing(mock_transcript)
assert isinstance(result, ValidationAlreadyScheduled)
@pytest.mark.usefixtures("setup_database")
@@ -103,22 +110,25 @@ async def test_hatchet_validation_allows_failed_workflow():
recording_id="test-recording-id",
)
with patch(
"reflector.services.transcript_process.HatchetClientManager"
) as mock_hatchet:
mock_hatchet.get_workflow_run_status = AsyncMock(
return_value=V1TaskStatus.FAILED
)
with patch("reflector.services.transcript_process.settings") as mock_settings:
mock_settings.HATCHET_ENABLED = True
with patch(
"reflector.services.transcript_process.task_is_scheduled_or_active"
) as mock_celery_check:
mock_celery_check.return_value = False
"reflector.services.transcript_process.HatchetClientManager"
) as mock_hatchet:
mock_hatchet.get_workflow_run_status = AsyncMock(
return_value=V1TaskStatus.FAILED
)
result = await validate_transcript_for_processing(mock_transcript)
with patch(
"reflector.services.transcript_process.task_is_scheduled_or_active"
) as mock_celery_check:
mock_celery_check.return_value = False
assert isinstance(result, ValidationOk)
assert result.transcript_id == "test-transcript-id"
result = await validate_transcript_for_processing(mock_transcript)
assert isinstance(result, ValidationOk)
assert result.transcript_id == "test-transcript-id"
@pytest.mark.usefixtures("setup_database")
@@ -139,21 +149,24 @@ async def test_hatchet_validation_allows_completed_workflow():
recording_id="test-recording-id",
)
with patch(
"reflector.services.transcript_process.HatchetClientManager"
) as mock_hatchet:
mock_hatchet.get_workflow_run_status = AsyncMock(
return_value=V1TaskStatus.COMPLETED
)
with patch("reflector.services.transcript_process.settings") as mock_settings:
mock_settings.HATCHET_ENABLED = True
with patch(
"reflector.services.transcript_process.task_is_scheduled_or_active"
) as mock_celery_check:
mock_celery_check.return_value = False
"reflector.services.transcript_process.HatchetClientManager"
) as mock_hatchet:
mock_hatchet.get_workflow_run_status = AsyncMock(
return_value=V1TaskStatus.COMPLETED
)
result = await validate_transcript_for_processing(mock_transcript)
with patch(
"reflector.services.transcript_process.task_is_scheduled_or_active"
) as mock_celery_check:
mock_celery_check.return_value = False
assert isinstance(result, ValidationOk)
result = await validate_transcript_for_processing(mock_transcript)
assert isinstance(result, ValidationOk)
@pytest.mark.usefixtures("setup_database")
@@ -174,23 +187,26 @@ async def test_hatchet_validation_allows_when_status_check_fails():
recording_id="test-recording-id",
)
with patch(
"reflector.services.transcript_process.HatchetClientManager"
) as mock_hatchet:
# Status check fails (workflow might be deleted)
mock_hatchet.get_workflow_run_status = AsyncMock(
side_effect=ApiException("Workflow not found")
)
with patch("reflector.services.transcript_process.settings") as mock_settings:
mock_settings.HATCHET_ENABLED = True
with patch(
"reflector.services.transcript_process.task_is_scheduled_or_active"
) as mock_celery_check:
mock_celery_check.return_value = False
"reflector.services.transcript_process.HatchetClientManager"
) as mock_hatchet:
# Status check fails (workflow might be deleted)
mock_hatchet.get_workflow_run_status = AsyncMock(
side_effect=ApiException("Workflow not found")
)
result = await validate_transcript_for_processing(mock_transcript)
with patch(
"reflector.services.transcript_process.task_is_scheduled_or_active"
) as mock_celery_check:
mock_celery_check.return_value = False
# Should allow processing when we can't get status
assert isinstance(result, ValidationOk)
result = await validate_transcript_for_processing(mock_transcript)
# Should allow processing when we can't get status
assert isinstance(result, ValidationOk)
@pytest.mark.usefixtures("setup_database")
@@ -211,11 +227,47 @@ async def test_hatchet_validation_skipped_when_no_workflow_id():
recording_id="test-recording-id",
)
with patch(
"reflector.services.transcript_process.HatchetClientManager"
) as mock_hatchet:
# Should not be called
mock_hatchet.get_workflow_run_status = AsyncMock()
with patch("reflector.services.transcript_process.settings") as mock_settings:
mock_settings.HATCHET_ENABLED = True
with patch(
"reflector.services.transcript_process.HatchetClientManager"
) as mock_hatchet:
# Should not be called
mock_hatchet.get_workflow_run_status = AsyncMock()
with patch(
"reflector.services.transcript_process.task_is_scheduled_or_active"
) as mock_celery_check:
mock_celery_check.return_value = False
result = await validate_transcript_for_processing(mock_transcript)
# Should not check Hatchet status
mock_hatchet.get_workflow_run_status.assert_not_called()
assert isinstance(result, ValidationOk)
@pytest.mark.usefixtures("setup_database")
@pytest.mark.asyncio
async def test_hatchet_validation_skipped_when_disabled():
"""Test that Hatchet validation is skipped when HATCHET_ENABLED is False."""
from reflector.services.transcript_process import (
ValidationOk,
validate_transcript_for_processing,
)
mock_transcript = Transcript(
id="test-transcript-id",
name="Test",
status="uploaded",
source_kind="room",
workflow_run_id="some-workflow-123",
recording_id="test-recording-id",
)
with patch("reflector.services.transcript_process.settings") as mock_settings:
mock_settings.HATCHET_ENABLED = False # Hatchet disabled
with patch(
"reflector.services.transcript_process.task_is_scheduled_or_active"
@@ -224,8 +276,7 @@ async def test_hatchet_validation_skipped_when_no_workflow_id():
result = await validate_transcript_for_processing(mock_transcript)
# Should not check Hatchet status
mock_hatchet.get_workflow_run_status.assert_not_called()
# Should not check Hatchet at all
assert isinstance(result, ValidationOk)

View File

@@ -189,17 +189,14 @@ async def test_ics_sync_service_sync_room_calendar():
assert events[0].ics_uid == "sync-event-1"
assert events[0].title == "Sync Test Meeting"
# Second sync with same content (calendar unchanged, but sync always runs)
# Second sync with same content (should be unchanged)
# Refresh room to get updated etag and force sync by setting old sync time
room = await rooms_controller.get_by_id(room.id)
await rooms_controller.update(
room, {"ics_last_sync": datetime.now(timezone.utc) - timedelta(minutes=10)}
)
result = await sync_service.sync_room_calendar(room)
assert result["status"] == "success"
assert result["events_created"] == 0
assert result["events_updated"] == 0
assert result["events_deleted"] == 0
assert result["status"] == "unchanged"
# Third sync with updated event
event["summary"] = "Updated Meeting Title"
@@ -291,43 +288,3 @@ async def test_ics_sync_service_error_handling():
result = await sync_service.sync_room_calendar(room)
assert result["status"] == "error"
assert "Network error" in result["error"]
@pytest.mark.asyncio
async def test_event_data_changed_exhaustiveness():
"""Test that _event_data_changed compares all EventData fields (except ics_uid).
This test ensures programmers don't forget to update the comparison logic
when adding new fields to EventData/CalendarEvent.
"""
from reflector.services.ics_sync import EventData
sync_service = ICSSyncService()
from reflector.db.calendar_events import CalendarEvent
now = datetime.now(timezone.utc)
event_data: EventData = {
"ics_uid": "test-123",
"title": "Test",
"description": "Desc",
"location": "Loc",
"start_time": now,
"end_time": now + timedelta(hours=1),
"attendees": [],
"ics_raw_data": "raw",
}
existing = CalendarEvent(
room_id="room1",
**event_data,
)
# Will raise RuntimeError if fields are missing from comparison
result = sync_service._event_data_changed(existing, event_data)
assert result is False
modified_data = event_data.copy()
modified_data["title"] = "Changed Title"
result = sync_service._event_data_changed(existing, modified_data)
assert result is True

View File

@@ -0,0 +1,234 @@
"""Tests for transcript chat WebSocket endpoint."""
import asyncio
import threading
import time
from pathlib import Path
import pytest
from httpx_ws import aconnect_ws
from uvicorn import Config, Server
from reflector.db.transcripts import (
SourceKind,
TranscriptParticipant,
TranscriptTopic,
transcripts_controller,
)
from reflector.processors.types import Word
@pytest.fixture
def chat_appserver(tmpdir, setup_database):
"""Start a real HTTP server for WebSocket testing."""
from reflector.app import app
from reflector.db import get_database
from reflector.settings import settings
DATA_DIR = settings.DATA_DIR
settings.DATA_DIR = Path(tmpdir)
# Start server in separate thread with its own event loop
host = "127.0.0.1"
port = 1256 # Different port from rtc tests
server_started = threading.Event()
server_exception = None
server_instance = None
def run_server():
nonlocal server_exception, server_instance
try:
# Create new event loop for this thread
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
config = Config(app=app, host=host, port=port, loop=loop)
server_instance = Server(config)
async def start_server():
# Initialize database connection in this event loop
database = get_database()
await database.connect()
try:
await server_instance.serve()
finally:
await database.disconnect()
# Signal that server is starting
server_started.set()
loop.run_until_complete(start_server())
except Exception as e:
server_exception = e
server_started.set()
finally:
loop.close()
server_thread = threading.Thread(target=run_server, daemon=True)
server_thread.start()
# Wait for server to start
server_started.wait(timeout=30)
if server_exception:
raise server_exception
# Wait for server to be fully ready
time.sleep(1)
yield server_instance, host, port
# Stop server
if server_instance:
server_instance.should_exit = True
server_thread.join(timeout=30)
settings.DATA_DIR = DATA_DIR
@pytest.fixture
async def test_transcript(setup_database):
"""Create a test transcript for WebSocket tests."""
transcript = await transcripts_controller.add(
name="Test Transcript for Chat", source_kind=SourceKind.FILE
)
return transcript
@pytest.fixture
async def test_transcript_with_content(setup_database):
"""Create a test transcript with actual content for WebVTT generation."""
transcript = await transcripts_controller.add(
name="Test Transcript with Content", source_kind=SourceKind.FILE
)
# Add participants
await transcripts_controller.update(
transcript,
{
"participants": [
TranscriptParticipant(id="1", speaker=0, name="Alice").model_dump(),
TranscriptParticipant(id="2", speaker=1, name="Bob").model_dump(),
]
},
)
# Add topic with words
await transcripts_controller.upsert_topic(
transcript,
TranscriptTopic(
title="Introduction",
summary="Opening remarks",
timestamp=0.0,
words=[
Word(text="Hello ", start=0.0, end=1.0, speaker=0),
Word(text="everyone.", start=1.0, end=2.0, speaker=0),
Word(text="Hi ", start=2.0, end=3.0, speaker=1),
Word(text="there!", start=3.0, end=4.0, speaker=1),
],
),
)
return transcript
@pytest.mark.asyncio
async def test_chat_websocket_connection_success(test_transcript, chat_appserver):
"""Test successful WebSocket connection to chat endpoint."""
server, host, port = chat_appserver
base_url = f"ws://{host}:{port}/v1"
async with aconnect_ws(f"{base_url}/transcripts/{test_transcript.id}/chat") as ws:
# Send unknown message type to test echo behavior
await ws.send_json({"type": "test", "text": "Hello"})
# Should receive echo for unknown types
response = await ws.receive_json()
assert response["type"] == "echo"
assert response["data"]["type"] == "test"
@pytest.mark.asyncio
async def test_chat_websocket_nonexistent_transcript(chat_appserver):
"""Test WebSocket connection fails for nonexistent transcript."""
server, host, port = chat_appserver
base_url = f"ws://{host}:{port}/v1"
# Connection should fail or disconnect immediately for non-existent transcript
# Different behavior from successful connection
with pytest.raises(Exception): # Will raise on connection or first operation
async with aconnect_ws(f"{base_url}/transcripts/nonexistent-id/chat") as ws:
await ws.send_json({"type": "message", "text": "Hello"})
await ws.receive_json()
@pytest.mark.asyncio
async def test_chat_websocket_multiple_messages(test_transcript, chat_appserver):
"""Test sending multiple messages through WebSocket."""
server, host, port = chat_appserver
base_url = f"ws://{host}:{port}/v1"
async with aconnect_ws(f"{base_url}/transcripts/{test_transcript.id}/chat") as ws:
# Send multiple unknown message types (testing echo behavior)
messages = ["First message", "Second message", "Third message"]
for i, msg in enumerate(messages):
await ws.send_json({"type": f"test{i}", "text": msg})
response = await ws.receive_json()
assert response["type"] == "echo"
assert response["data"]["type"] == f"test{i}"
assert response["data"]["text"] == msg
@pytest.mark.asyncio
async def test_chat_websocket_disconnect_graceful(test_transcript, chat_appserver):
"""Test WebSocket disconnects gracefully."""
server, host, port = chat_appserver
base_url = f"ws://{host}:{port}/v1"
async with aconnect_ws(f"{base_url}/transcripts/{test_transcript.id}/chat") as ws:
await ws.send_json({"type": "message", "text": "Hello"})
await ws.receive_json()
# Close handled by context manager - should not raise
@pytest.mark.asyncio
async def test_chat_websocket_context_generation(
test_transcript_with_content, chat_appserver
):
"""Test WebVTT context is generated on connection."""
server, host, port = chat_appserver
base_url = f"ws://{host}:{port}/v1"
async with aconnect_ws(
f"{base_url}/transcripts/{test_transcript_with_content.id}/chat"
) as ws:
# Request context
await ws.send_json({"type": "get_context"})
# Receive context response
response = await ws.receive_json()
assert response["type"] == "context"
assert "webvtt" in response
# Verify WebVTT format
webvtt = response["webvtt"]
assert webvtt.startswith("WEBVTT")
assert "<v Alice>" in webvtt
assert "<v Bob>" in webvtt
assert "Hello everyone." in webvtt
assert "Hi there!" in webvtt
@pytest.mark.asyncio
async def test_chat_websocket_unknown_message_type(test_transcript, chat_appserver):
"""Test unknown message types are echoed back."""
server, host, port = chat_appserver
base_url = f"ws://{host}:{port}/v1"
async with aconnect_ws(f"{base_url}/transcripts/{test_transcript.id}/chat") as ws:
# Send unknown message type
await ws.send_json({"type": "unknown", "data": "test"})
# Should receive echo
response = await ws.receive_json()
assert response["type"] == "echo"
assert response["data"]["type"] == "unknown"

View File

@@ -162,24 +162,9 @@ async def test_dailyco_recording_uses_multitrack_pipeline(client):
from datetime import datetime, timezone
from reflector.db.recordings import Recording, recordings_controller
from reflector.db.rooms import rooms_controller
from reflector.db.transcripts import transcripts_controller
room = await rooms_controller.add(
name="test-room",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
)
# Force Celery backend for test
await rooms_controller.update(room, {"use_celery": True})
# Create transcript with Daily.co multitrack recording
transcript = await transcripts_controller.add(
"",
source_kind="room",
@@ -187,7 +172,6 @@ async def test_dailyco_recording_uses_multitrack_pipeline(client):
target_language="en",
user_id="test-user",
share_mode="public",
room_id=room.id,
)
track_keys = [

View File

@@ -0,0 +1,103 @@
"use client";
import { useState } from "react";
import { Box, Dialog, Input, IconButton } from "@chakra-ui/react";
import { MessageCircle } from "lucide-react";
import Markdown from "react-markdown";
import "../../styles/markdown.css";
import type { Message } from "./useTranscriptChat";
interface TranscriptChatModalProps {
open: boolean;
onClose: () => void;
messages: Message[];
sendMessage: (text: string) => void;
isStreaming: boolean;
currentStreamingText: string;
}
export function TranscriptChatModal({
open,
onClose,
messages,
sendMessage,
isStreaming,
currentStreamingText,
}: TranscriptChatModalProps) {
const [input, setInput] = useState("");
const handleSend = () => {
if (!input.trim()) return;
sendMessage(input);
setInput("");
};
return (
<Dialog.Root open={open} onOpenChange={(e) => !e.open && onClose()}>
<Dialog.Backdrop />
<Dialog.Positioner>
<Dialog.Content maxW="500px" h="600px">
<Dialog.Header>Transcript Chat</Dialog.Header>
<Dialog.Body overflowY="auto">
{messages.map((msg) => (
<Box
key={msg.id}
p={3}
mb={2}
bg={msg.role === "user" ? "blue.50" : "gray.50"}
borderRadius="md"
>
{msg.role === "user" ? (
msg.text
) : (
<div className="markdown">
<Markdown>{msg.text}</Markdown>
</div>
)}
</Box>
))}
{isStreaming && (
<Box p={3} bg="gray.50" borderRadius="md">
<div className="markdown">
<Markdown>{currentStreamingText}</Markdown>
</div>
<Box as="span" className="animate-pulse">
</Box>
</Box>
)}
</Dialog.Body>
<Dialog.Footer>
<Input
value={input}
onChange={(e) => setInput(e.target.value)}
onKeyDown={(e) => e.key === "Enter" && handleSend()}
placeholder="Ask about transcript..."
disabled={isStreaming}
/>
</Dialog.Footer>
</Dialog.Content>
</Dialog.Positioner>
</Dialog.Root>
);
}
export function TranscriptChatButton({ onClick }: { onClick: () => void }) {
return (
<IconButton
position="fixed"
bottom="24px"
right="24px"
onClick={onClick}
size="lg"
colorPalette="blue"
borderRadius="full"
aria-label="Open chat"
>
<MessageCircle />
</IconButton>
);
}

View File

@@ -18,9 +18,15 @@ import {
Skeleton,
Text,
Spinner,
useDisclosure,
} from "@chakra-ui/react";
import { useTranscriptGet } from "../../../lib/apiHooks";
import { TranscriptStatus } from "../../../lib/transcript";
import {
TranscriptChatModal,
TranscriptChatButton,
} from "../TranscriptChatModal";
import { useTranscriptChat } from "../useTranscriptChat";
type TranscriptDetails = {
params: Promise<{
@@ -53,6 +59,9 @@ export default function TranscriptDetails(details: TranscriptDetails) {
const [finalSummaryElement, setFinalSummaryElement] =
useState<HTMLDivElement | null>(null);
const { open, onOpen, onClose } = useDisclosure();
const chat = useTranscriptChat(transcriptId);
useEffect(() => {
if (!waiting || !transcript.data) return;
@@ -119,6 +128,15 @@ export default function TranscriptDetails(details: TranscriptDetails) {
return (
<>
<TranscriptChatModal
open={open}
onClose={onClose}
messages={chat.messages}
sendMessage={chat.sendMessage}
isStreaming={chat.isStreaming}
currentStreamingText={chat.currentStreamingText}
/>
<TranscriptChatButton onClick={onOpen} />
<Grid
templateColumns="1fr"
templateRows="auto minmax(0, 1fr)"

View File

@@ -0,0 +1,130 @@
"use client";
import { useEffect, useState, useRef } from "react";
import { getSession } from "next-auth/react";
import { WEBSOCKET_URL } from "../../lib/apiClient";
import { assertExtendedToken } from "../../lib/types";
export type Message = {
id: string;
role: "user" | "assistant";
text: string;
timestamp: Date;
};
export type UseTranscriptChat = {
messages: Message[];
sendMessage: (text: string) => void;
isStreaming: boolean;
currentStreamingText: string;
};
export const useTranscriptChat = (transcriptId: string): UseTranscriptChat => {
const [messages, setMessages] = useState<Message[]>([]);
const [isStreaming, setIsStreaming] = useState(false);
const [currentStreamingText, setCurrentStreamingText] = useState("");
const wsRef = useRef<WebSocket | null>(null);
const streamingTextRef = useRef<string>("");
const isMountedRef = useRef<boolean>(true);
useEffect(() => {
isMountedRef.current = true;
const connectWebSocket = async () => {
const url = `${WEBSOCKET_URL}/v1/transcripts/${transcriptId}/chat`;
// Get auth token for WebSocket subprotocol
let protocols: string[] | undefined;
try {
const session = await getSession();
if (session) {
const token = assertExtendedToken(session).accessToken;
// Pass token via subprotocol: ["bearer", token]
protocols = ["bearer", token];
}
} catch (error) {
console.warn("Failed to get auth token for WebSocket:", error);
}
const ws = new WebSocket(url, protocols);
wsRef.current = ws;
ws.onopen = () => {
console.log("Chat WebSocket connected");
};
ws.onmessage = (event) => {
if (!isMountedRef.current) return;
const msg = JSON.parse(event.data);
switch (msg.type) {
case "token":
setIsStreaming(true);
streamingTextRef.current += msg.text;
setCurrentStreamingText(streamingTextRef.current);
break;
case "done":
// CRITICAL: Save the text BEFORE resetting the ref
// The setMessages callback may execute later, after ref is reset
const finalText = streamingTextRef.current;
setMessages((prev) => [
...prev,
{
id: Date.now().toString(),
role: "assistant",
text: finalText,
timestamp: new Date(),
},
]);
streamingTextRef.current = "";
setCurrentStreamingText("");
setIsStreaming(false);
break;
case "error":
console.error("Chat error:", msg.message);
setIsStreaming(false);
break;
}
};
ws.onerror = (error) => {
console.error("WebSocket error:", error);
};
ws.onclose = () => {
console.log("Chat WebSocket closed");
};
};
connectWebSocket();
return () => {
isMountedRef.current = false;
if (wsRef.current) {
wsRef.current.close();
}
};
}, [transcriptId]);
const sendMessage = (text: string) => {
if (!wsRef.current) return;
setMessages((prev) => [
...prev,
{
id: Date.now().toString(),
role: "user",
text,
timestamp: new Date(),
},
]);
wsRef.current.send(JSON.stringify({ type: "message", text }));
};
return { messages, sendMessage, isStreaming, currentStreamingText };
};