Compare commits


11 Commits

Author SHA1 Message Date
Igor Loskutov
b84fd1fc24 cleanup 2026-01-13 12:46:03 -05:00
Igor Loskutov
3652de9fca md 2026-01-13 12:44:43 -05:00
Igor Loskutov
68df825734 test: fix WebSocket chat tests using async approach
Replaced TestClient-based tests with proper async WebSocket testing
using httpx_ws and a threaded-server pattern. TestClient has event-loop
issues with WebSocket connections that were causing all tests to fail.

Changes:
- Rewrote all WebSocket tests to use aconnect_ws from httpx_ws
- Added chat_appserver fixture using threaded Uvicorn server
- Tests now use separate event loop in server thread
- All 6 tests now pass without asyncio/event loop errors
- Matches existing pattern from test_transcripts_rtc_ws.py

Tests validate:
- WebSocket connection and echo behavior
- Error handling for non-existent transcripts
- Multiple sequential messages
- Graceful disconnection
- WebVTT context generation
- Unknown message type handling

Closes fn-1.8 (End-to-end testing)
2026-01-12 20:17:42 -05:00
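In outline, the pattern this commit adopts (a minimal sketch, assuming the reflector app and the httpx_ws API used in the test file shown further down; the port, helper names, and transcript id are illustrative):

import asyncio
import threading
from httpx_ws import aconnect_ws
from uvicorn import Config, Server
from reflector.app import app

def run_server(server: Server) -> None:
    # Give the server thread its own event loop, so WebSocket handling
    # never shares a loop with the test's asyncio loop.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(server.serve())

server = Server(Config(app=app, host="127.0.0.1", port=1256))
threading.Thread(target=run_server, args=(server,), daemon=True).start()
# ... later, from test code (stop the server with server.should_exit = True):

async def ask(transcript_id: str) -> None:
    url = f"ws://127.0.0.1:1256/v1/transcripts/{transcript_id}/chat"
    async with aconnect_ws(url) as ws:
        await ws.send_json({"type": "get_context"})
        print(await ws.receive_json())  # {"type": "context", "webvtt": ...}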
Igor Loskutov
8ca5324c1a feat: integrate TranscriptChatModal and button into transcript page 2026-01-12 20:08:59 -05:00
Igor Loskutov
39e0b89e67 feat: add TranscriptChatModal and TranscriptChatButton components 2026-01-12 19:59:01 -05:00
Igor Loskutov
544793a24f chore: mark fn-1.5 as done (Frontend WebSocket hook)
Task fn-1.5 completed: the useTranscriptChat React hook was already implemented in commit 2dfe82af.

Hook provides:
- WebSocket connection to /v1/transcripts/{id}/chat endpoint
- Token streaming with ref-based accumulation
- Message history management (user + assistant)
- Memory leak prevention with isMountedRef
- TypeScript type safety
- Proper WebSocket lifecycle and cleanup

Updated task documentation with acceptance criteria and evidence.
2026-01-12 19:49:14 -05:00
Igor Loskutov
088451645a chore: mark fn-1.4 as done (WebSocket route registration) 2026-01-12 19:42:26 -05:00
Igor Loskutov
2dfe82afbc feat: add useTranscriptChat WebSocket hook
Task 5: Frontend WebSocket Hook
- Creates React hook for bidirectional chat WebSocket
- Handles token streaming with proper state accumulation
- Manages conversation history (user + assistant messages)
- Prevents memory leaks with isMounted check
- Proper cleanup on unmount
- Type-safe Message interface

Validated:
- No React dependency issues (removed currentStreamingText from deps)
- No stale closure bugs (using ref for streaming text)
- Proper mounted state tracking
- Lint passes with no errors
- TypeScript types correctly defined
- WebSocket cleanup on unmount

~100 lines
2026-01-12 18:44:09 -05:00
Igor Loskutov
b461ebb488 feat: register transcript chat WebSocket route
- Import transcripts_chat router
- Register /v1/transcripts/{id}/chat endpoint
- Completes LLM streaming integration (fn-1.3)
2026-01-12 18:41:11 -05:00
Igor Loskutov
0b5112cabc feat: add LLM streaming integration to transcript chat
Task 3: LLM Streaming Integration

- Import Settings, ChatMessage, MessageRole from llama-index
- Configure LLM with temperature 0.7 on connection
- Build system message with WebVTT transcript context (max 15k chars)
- Initialize conversation history with system message
- Handle 'message' type from client to trigger LLM streaming
- Stream LLM response using Settings.llm.astream_chat()
- Send tokens incrementally via 'token' messages
- Send 'done' message when streaming completes
- Maintain conversation history across multiple messages
- Add error handling with 'error' message type
- Add message protocol validation test

Implements Tasks 3 & 4 from TASKS.md
2026-01-12 18:28:43 -05:00
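One full turn of the resulting wire protocol, as a sketch (message types taken from the transcripts_chat.py diff further down; the payload text is invented):

# client -> server: one chat turn
await ws.send_json({"type": "message", "text": "Summarize the meeting"})
# server -> client: the reply arrives token by token
#   {"type": "token", "text": "The"}
#   {"type": "token", "text": " meeting"}
#   ...
#   {"type": "done"}
# Also: {"type": "get_context"} returns {"type": "context", "webvtt": "..."},
# unknown types are echoed back as {"type": "echo", "data": ...}, and
# failures arrive as {"type": "error", "message": "..."}.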
Igor Loskutov
316f7b316d feat: add WebVTT context generation to chat WebSocket endpoint
- Import topics_to_webvtt_named and recordings controller
- Add _get_is_multitrack helper function
- Generate WebVTT context on WebSocket connection
- Add get_context message type to retrieve WebVTT
- Maintain backward compatibility with echo for other messages
- Add test fixture and test for WebVTT context generation

Implements task fn-1.2: WebVTT context generation for transcript chat
2026-01-12 18:24:47 -05:00
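For orientation, the WebVTT context fed to the LLM looks roughly like this (speaker names and cue text are taken from the test fixture further down; timestamps and cue layout are illustrative, not verified output):

WEBVTT

00:00:00.000 --> 00:00:02.000
<v Alice>Hello everyone.

00:00:02.000 --> 00:00:04.000
<v Bob>Hi there!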
18 changed files with 734 additions and 751 deletions

View File

@@ -1,12 +1,5 @@
 # Changelog
-## [0.28.0](https://github.com/Monadical-SAS/reflector/compare/v0.27.0...v0.28.0) (2026-01-20)
-### Features
-* worker affinity ([#819](https://github.com/Monadical-SAS/reflector/issues/819)) ([3b6540e](https://github.com/Monadical-SAS/reflector/commit/3b6540eae5b597449f98661bdf15483b77be3268))
 ## [0.27.0](https://github.com/Monadical-SAS/reflector/compare/v0.26.0...v0.27.0) (2025-12-26)

View File

@@ -34,7 +34,7 @@ services:
     environment:
       ENTRYPOINT: beat
-  hatchet-worker-cpu:
+  hatchet-worker:
     build:
       context: server
     volumes:
@@ -43,20 +43,7 @@ services:
     env_file:
       - ./server/.env
     environment:
-      ENTRYPOINT: hatchet-worker-cpu
-    depends_on:
-      hatchet:
-        condition: service_healthy
-  hatchet-worker-llm:
-    build:
-      context: server
-    volumes:
-      - ./server/:/app/
-      - /app/.venv
-    env_file:
-      - ./server/.env
-    environment:
-      ENTRYPOINT: hatchet-worker-llm
+      ENTRYPOINT: hatchet-worker
     depends_on:
       hatchet:
         condition: service_healthy

View File

@@ -131,15 +131,6 @@ if [ -z "$DIARIZER_URL" ]; then
 fi
 echo " -> $DIARIZER_URL"
-echo ""
-echo "Deploying mixdown (CPU audio processing)..."
-MIXDOWN_URL=$(modal deploy reflector_mixdown.py 2>&1 | grep -o 'https://[^ ]*web.modal.run' | head -1)
-if [ -z "$MIXDOWN_URL" ]; then
-    echo "Error: Failed to deploy mixdown. Check Modal dashboard for details."
-    exit 1
-fi
-echo " -> $MIXDOWN_URL"
 # --- Output Configuration ---
 echo ""
 echo "=========================================="
@@ -156,8 +147,4 @@ echo ""
 echo "DIARIZATION_BACKEND=modal"
 echo "DIARIZATION_URL=$DIARIZER_URL"
 echo "DIARIZATION_MODAL_API_KEY=$API_KEY"
-echo ""
-echo "MIXDOWN_BACKEND=modal"
-echo "MIXDOWN_URL=$MIXDOWN_URL"
-echo "MIXDOWN_MODAL_API_KEY=$API_KEY"
 echo "# --- End Modal Configuration ---"

View File

@@ -1,379 +0,0 @@
"""
Reflector GPU backend - audio mixdown
======================================
CPU-intensive audio mixdown service for combining multiple audio tracks.
Uses PyAV filter graph (amix) for high-quality audio mixing.
"""
import os
import tempfile
import time
from fractions import Fraction
import modal
MIXDOWN_TIMEOUT = 900 # 15 minutes
SCALEDOWN_WINDOW = 60 # 1 minute idle before shutdown
app = modal.App("reflector-mixdown")
# CPU-based image (no GPU needed for audio processing)
image = (
modal.Image.debian_slim(python_version="3.12")
.apt_install("ffmpeg") # Required by PyAV
.pip_install(
"av==13.1.0", # PyAV for audio processing
"requests==2.32.3", # HTTP for presigned URL downloads/uploads
"fastapi==0.115.12", # API framework
)
)
@app.function(
cpu=4.0, # 4 CPU cores for audio processing
timeout=MIXDOWN_TIMEOUT,
scaledown_window=SCALEDOWN_WINDOW,
secrets=[modal.Secret.from_name("reflector-gpu")],
image=image,
)
@modal.concurrent(max_inputs=10)
@modal.asgi_app()
def web():
import logging
import secrets
import shutil
import av
import requests
from av.audio.resampler import AudioResampler
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel
# Setup logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
app = FastAPI()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# Validate API key exists at startup
API_KEY = os.environ.get("REFLECTOR_GPU_APIKEY")
if not API_KEY:
raise RuntimeError("REFLECTOR_GPU_APIKEY not configured in Modal secrets")
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
# Use constant-time comparison to prevent timing attacks
if secrets.compare_digest(apikey, API_KEY):
return
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
headers={"WWW-Authenticate": "Bearer"},
)
class MixdownRequest(BaseModel):
track_urls: list[str]
output_url: str
target_sample_rate: int = 48000
expected_duration_sec: float | None = None
class MixdownResponse(BaseModel):
duration_ms: float
tracks_mixed: int
audio_uploaded: bool
def download_track(url: str, temp_dir: str, index: int) -> str:
"""Download track from presigned URL to temp file using streaming."""
logger.info(f"Downloading track {index + 1}")
response = requests.get(url, stream=True, timeout=300)
if response.status_code == 404:
raise HTTPException(status_code=404, detail=f"Track {index} not found")
if response.status_code == 403:
raise HTTPException(
status_code=403, detail=f"Track {index} presigned URL expired"
)
response.raise_for_status()
temp_path = os.path.join(temp_dir, f"track_{index}.webm")
total_bytes = 0
with open(temp_path, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
total_bytes += len(chunk)
logger.info(f"Track {index + 1} downloaded: {total_bytes} bytes")
return temp_path
def mixdown_tracks_modal(
track_paths: list[str],
output_path: str,
target_sample_rate: int,
expected_duration_sec: float | None,
logger,
) -> float:
"""Mix multiple audio tracks using PyAV filter graph.
Args:
track_paths: List of local file paths to audio tracks
output_path: Local path for output MP3 file
target_sample_rate: Sample rate for output (Hz)
expected_duration_sec: Optional fallback duration if container metadata unavailable
logger: Logger instance for progress tracking
Returns:
Duration in milliseconds
"""
logger.info(f"Starting mixdown of {len(track_paths)} tracks")
# Build PyAV filter graph: N abuffer -> amix -> aformat -> sink
graph = av.filter.Graph()
inputs = []
for idx in range(len(track_paths)):
args = (
f"time_base=1/{target_sample_rate}:"
f"sample_rate={target_sample_rate}:"
f"sample_fmt=s32:"
f"channel_layout=stereo"
)
in_ctx = graph.add("abuffer", args=args, name=f"in{idx}")
inputs.append(in_ctx)
mixer = graph.add("amix", args=f"inputs={len(inputs)}:normalize=0", name="mix")
fmt = graph.add(
"aformat",
args=f"sample_fmts=s32:channel_layouts=stereo:sample_rates={target_sample_rate}",
name="fmt",
)
sink = graph.add("abuffersink", name="out")
# Connect inputs to mixer (no delays for Modal implementation)
for idx, in_ctx in enumerate(inputs):
in_ctx.link_to(mixer, 0, idx)
mixer.link_to(fmt)
fmt.link_to(sink)
graph.configure()
# Open all containers
containers = []
try:
for i, path in enumerate(track_paths):
try:
c = av.open(path)
containers.append(c)
except Exception as e:
logger.warning(
f"Failed to open container {i}: {e}",
)
if not containers:
raise ValueError("Could not open any track containers")
# Calculate total duration for progress reporting
max_duration_sec = 0.0
for c in containers:
if c.duration is not None:
dur_sec = c.duration / av.time_base
max_duration_sec = max(max_duration_sec, dur_sec)
if max_duration_sec == 0.0 and expected_duration_sec:
max_duration_sec = expected_duration_sec
# Setup output container
out_container = av.open(output_path, "w", format="mp3")
out_stream = out_container.add_stream("libmp3lame", rate=target_sample_rate)
decoders = [c.decode(audio=0) for c in containers]
active = [True] * len(decoders)
resamplers = [
AudioResampler(format="s32", layout="stereo", rate=target_sample_rate)
for _ in decoders
]
current_max_time = 0.0
last_log_time = time.monotonic()
start_time = time.monotonic()
total_duration = 0
while any(active):
for i, (dec, is_active) in enumerate(zip(decoders, active)):
if not is_active:
continue
try:
frame = next(dec)
except StopIteration:
active[i] = False
inputs[i].push(None) # Signal end of stream
continue
if frame.sample_rate != target_sample_rate:
continue
# Progress logging (every 5 seconds)
if frame.time is not None:
current_max_time = max(current_max_time, frame.time)
now = time.monotonic()
if now - last_log_time >= 5.0:
elapsed = now - start_time
if max_duration_sec > 0:
progress_pct = min(
100.0, (current_max_time / max_duration_sec) * 100
)
logger.info(
f"Mixdown progress: {progress_pct:.1f}% @ {current_max_time:.1f}s (elapsed: {elapsed:.1f}s)"
)
else:
logger.info(
f"Mixdown progress: @ {current_max_time:.1f}s (elapsed: {elapsed:.1f}s)"
)
last_log_time = now
out_frames = resamplers[i].resample(frame) or []
for rf in out_frames:
rf.sample_rate = target_sample_rate
rf.time_base = Fraction(1, target_sample_rate)
inputs[i].push(rf)
# Pull mixed frames from sink and encode
while True:
try:
mixed = sink.pull()
except Exception:
break
mixed.sample_rate = target_sample_rate
mixed.time_base = Fraction(1, target_sample_rate)
# Encode and mux
for packet in out_stream.encode(mixed):
out_container.mux(packet)
total_duration += packet.duration
# Flush remaining frames from filter graph
while True:
try:
mixed = sink.pull()
except Exception:
break
mixed.sample_rate = target_sample_rate
mixed.time_base = Fraction(1, target_sample_rate)
for packet in out_stream.encode(mixed):
out_container.mux(packet)
total_duration += packet.duration
# Flush encoder
for packet in out_stream.encode():
out_container.mux(packet)
total_duration += packet.duration
# Calculate duration in milliseconds
if total_duration > 0:
# Use the same calculation as AudioFileWriterProcessor
duration_ms = round(
float(total_duration * out_stream.time_base * 1000), 2
)
else:
duration_ms = 0.0
out_container.close()
logger.info(f"Mixdown complete: duration={duration_ms}ms")
finally:
# Cleanup all containers
for c in containers:
if c is not None:
try:
c.close()
except Exception:
pass
return duration_ms
@app.post("/v1/audio/mixdown", dependencies=[Depends(apikey_auth)])
def mixdown(request: MixdownRequest) -> MixdownResponse:
"""Mix multiple audio tracks into a single MP3 file.
Tracks are downloaded from presigned S3 URLs, mixed using PyAV,
and uploaded to a presigned S3 PUT URL.
"""
if not request.track_urls:
raise HTTPException(status_code=400, detail="No track URLs provided")
logger.info(f"Mixdown request: {len(request.track_urls)} tracks")
temp_dir = tempfile.mkdtemp()
temp_files = []
output_mp3_path = None
try:
# Download all tracks
for i, url in enumerate(request.track_urls):
temp_path = download_track(url, temp_dir, i)
temp_files.append(temp_path)
# Mix tracks
output_mp3_path = os.path.join(temp_dir, "mixed.mp3")
duration_ms = mixdown_tracks_modal(
temp_files,
output_mp3_path,
request.target_sample_rate,
request.expected_duration_sec,
logger,
)
# Upload result to S3
logger.info("Uploading result to S3")
file_size = os.path.getsize(output_mp3_path)
with open(output_mp3_path, "rb") as f:
upload_response = requests.put(
request.output_url, data=f, timeout=300
)
if upload_response.status_code == 403:
raise HTTPException(
status_code=403, detail="Output presigned URL expired"
)
upload_response.raise_for_status()
logger.info(f"Upload complete: {file_size} bytes")
return MixdownResponse(
duration_ms=duration_ms,
tracks_mixed=len(request.track_urls),
audio_uploaded=True,
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Mixdown failed: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Mixdown failed: {str(e)}")
finally:
# Cleanup temp files
for temp_path in temp_files:
try:
os.unlink(temp_path)
except Exception as e:
logger.warning(f"Failed to cleanup temp file {temp_path}: {e}")
if output_mp3_path and os.path.exists(output_mp3_path):
try:
os.unlink(output_mp3_path)
except Exception as e:
logger.warning(f"Failed to cleanup output file {output_mp3_path}: {e}")
try:
shutil.rmtree(temp_dir)
except Exception as e:
logger.warning(f"Failed to cleanup temp directory {temp_dir}: {e}")
return app

View File

@@ -18,6 +18,7 @@ from reflector.views.rooms import router as rooms_router
 from reflector.views.rtc_offer import router as rtc_offer_router
 from reflector.views.transcripts import router as transcripts_router
 from reflector.views.transcripts_audio import router as transcripts_audio_router
+from reflector.views.transcripts_chat import router as transcripts_chat_router
 from reflector.views.transcripts_participants import (
     router as transcripts_participants_router,
 )
@@ -90,6 +91,7 @@ app.include_router(transcripts_participants_router, prefix="/v1")
 app.include_router(transcripts_speaker_router, prefix="/v1")
 app.include_router(transcripts_upload_router, prefix="/v1")
 app.include_router(transcripts_websocket_router, prefix="/v1")
+app.include_router(transcripts_chat_router, prefix="/v1")
 app.include_router(transcripts_webrtc_router, prefix="/v1")
 app.include_router(transcripts_process_router, prefix="/v1")
 app.include_router(user_router, prefix="/v1")

View File

@@ -0,0 +1,77 @@
"""
Run Hatchet workers for the multitrack pipeline.
Runs as a separate process, just like Celery workers.
Usage:
uv run -m reflector.hatchet.run_workers
# Or via docker:
docker compose exec server uv run -m reflector.hatchet.run_workers
"""
import signal
import sys
from hatchet_sdk.rate_limit import RateLimitDuration
from reflector.hatchet.constants import LLM_RATE_LIMIT_KEY, LLM_RATE_LIMIT_PER_SECOND
from reflector.logger import logger
from reflector.settings import settings
def main() -> None:
"""Start Hatchet worker polling."""
if not settings.HATCHET_ENABLED:
logger.error("HATCHET_ENABLED is False, not starting workers")
sys.exit(1)
if not settings.HATCHET_CLIENT_TOKEN:
logger.error("HATCHET_CLIENT_TOKEN is not set")
sys.exit(1)
logger.info(
"Starting Hatchet workers",
debug=settings.HATCHET_DEBUG,
)
# Import here (not top-level) - workflow modules call HatchetClientManager.get_client()
# at module level because Hatchet SDK decorators (@workflow.task) bind at import time.
# Can't use lazy init: decorators need the client object when function is defined.
from reflector.hatchet.client import HatchetClientManager # noqa: PLC0415
from reflector.hatchet.workflows import ( # noqa: PLC0415
daily_multitrack_pipeline,
subject_workflow,
topic_chunk_workflow,
track_workflow,
)
hatchet = HatchetClientManager.get_client()
hatchet.rate_limits.put(
LLM_RATE_LIMIT_KEY, LLM_RATE_LIMIT_PER_SECOND, RateLimitDuration.SECOND
)
worker = hatchet.worker(
"reflector-pipeline-worker",
workflows=[
daily_multitrack_pipeline,
subject_workflow,
topic_chunk_workflow,
track_workflow,
],
)
def shutdown_handler(signum: int, frame) -> None:
logger.info("Received shutdown signal, stopping workers...")
# Worker cleanup happens automatically on exit
sys.exit(0)
signal.signal(signal.SIGINT, shutdown_handler)
signal.signal(signal.SIGTERM, shutdown_handler)
logger.info("Starting Hatchet worker polling...")
worker.start()
if __name__ == "__main__":
main()

View File

@@ -1,48 +0,0 @@
"""
CPU-heavy worker pool for audio processing tasks.
Handles ONLY: mixdown_tracks
Configuration:
- slots=1: Only mixdown (already serialized globally with max_runs=1)
- Worker affinity: pool=cpu-heavy
"""
from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
daily_multitrack_pipeline,
)
from reflector.logger import logger
from reflector.settings import settings
def main():
if not settings.HATCHET_ENABLED:
logger.error("HATCHET_ENABLED is False, not starting CPU workers")
return
hatchet = HatchetClientManager.get_client()
logger.info(
"Starting Hatchet CPU worker pool (mixdown only)",
worker_name="cpu-worker-pool",
slots=1,
labels={"pool": "cpu-heavy"},
)
cpu_worker = hatchet.worker(
"cpu-worker-pool",
slots=1, # Only 1 mixdown at a time (already serialized globally)
labels={
"pool": "cpu-heavy",
},
workflows=[daily_multitrack_pipeline],
)
try:
cpu_worker.start()
except KeyboardInterrupt:
logger.info("Received shutdown signal, stopping CPU workers...")
if __name__ == "__main__":
main()

View File

@@ -1,56 +0,0 @@
"""
LLM/I/O worker pool for all non-CPU tasks.
Handles: all tasks except mixdown_tracks (transcription, LLM inference, orchestration)
"""
from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
daily_multitrack_pipeline,
)
from reflector.hatchet.workflows.subject_processing import subject_workflow
from reflector.hatchet.workflows.topic_chunk_processing import topic_chunk_workflow
from reflector.hatchet.workflows.track_processing import track_workflow
from reflector.logger import logger
from reflector.settings import settings
SLOTS = 10
WORKER_NAME = "llm-worker-pool"
POOL = "llm-io"
def main():
if not settings.HATCHET_ENABLED:
logger.error("HATCHET_ENABLED is False, not starting LLM workers")
return
hatchet = HatchetClientManager.get_client()
logger.info(
"Starting Hatchet LLM worker pool (all tasks except mixdown)",
worker_name=WORKER_NAME,
slots=SLOTS,
labels={"pool": POOL},
)
llm_worker = hatchet.worker(
WORKER_NAME,
slots=SLOTS, # not all slots are probably used
labels={
"pool": POOL,
},
workflows=[
daily_multitrack_pipeline,
topic_chunk_workflow,
subject_workflow,
track_workflow,
],
)
try:
llm_worker.start()
except KeyboardInterrupt:
logger.info("Received shutdown signal, stopping LLM workers...")
if __name__ == "__main__":
main()

View File

@@ -23,12 +23,7 @@ from pathlib import Path
 from typing import Any, Callable, Coroutine, Protocol, TypeVar
 import httpx
-from hatchet_sdk import (
-    ConcurrencyExpression,
-    ConcurrencyLimitStrategy,
-    Context,
-)
-from hatchet_sdk.labels import DesiredWorkerLabel
+from hatchet_sdk import Context
 from pydantic import BaseModel
 from reflector.dailyco_api.client import DailyApiClient
@@ -472,24 +467,10 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
     parents=[process_tracks],
     execution_timeout=timedelta(seconds=TIMEOUT_AUDIO),
     retries=3,
-    desired_worker_labels={
-        "pool": DesiredWorkerLabel(
-            value="cpu-heavy",
-            required=True,
-            weight=100,
-        ),
-    },
-    concurrency=[
-        ConcurrencyExpression(
-            expression="'mixdown-global'",
-            max_runs=1,  # serialize mixdown to prevent resource contention
-            limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,  # Queue
-        )
-    ],
 )
 @with_error_handling(TaskName.MIXDOWN_TRACKS)
 async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
-    """Mix all padded tracks into single audio file using PyAV or Modal backend."""
+    """Mix all padded tracks into single audio file using PyAV (same as Celery)."""
     ctx.log("mixdown_tracks: mixing padded tracks into single audio file")
     track_result = ctx.task_output(process_tracks)
@@ -513,7 +494,7 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
     storage = _spawn_storage()
-    # Presign URLs for padded tracks (same expiration for both backends)
+    # Presign URLs on demand (avoids stale URLs on workflow replay)
     padded_urls = []
     for track_info in padded_tracks:
         if track_info.key:
@@ -534,79 +515,13 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
         logger.error("Mixdown failed - no decodable audio frames found")
         raise ValueError("No decodable audio frames in any track")
-    output_key = f"{input.transcript_id}/audio.mp3"
-    # Conditional: Modal or local backend
-    if settings.MIXDOWN_BACKEND == "modal":
-        ctx.log("mixdown_tracks: using Modal backend")
-        # Presign PUT URL for output (Modal will upload directly)
-        output_url = await storage.get_file_url(
-            output_key,
-            operation="put_object",
-            expires_in=PRESIGNED_URL_EXPIRATION_SECONDS,
-        )
-        from reflector.processors.audio_mixdown_modal import (  # noqa: PLC0415
-            AudioMixdownModalProcessor,
-        )
-        try:
-            processor = AudioMixdownModalProcessor()
-            result = await processor.mixdown(
-                track_urls=valid_urls,
-                output_url=output_url,
-                target_sample_rate=target_sample_rate,
-                expected_duration_sec=recording_duration
-                if recording_duration > 0
-                else None,
-            )
-            duration_ms = result.duration_ms
-            tracks_mixed = result.tracks_mixed
-            ctx.log(
-                f"mixdown_tracks: Modal returned duration={duration_ms}ms, tracks={tracks_mixed}"
-            )
-        except httpx.HTTPStatusError as e:
-            error_detail = e.response.text if hasattr(e.response, "text") else str(e)
-            logger.error(
-                "[Hatchet] Modal mixdown HTTP error",
-                transcript_id=input.transcript_id,
-                status_code=e.response.status_code if hasattr(e, "response") else None,
-                error=error_detail,
-            )
-            raise RuntimeError(
-                f"Modal mixdown failed with HTTP {e.response.status_code}: {error_detail}"
-            )
-        except httpx.TimeoutException:
-            logger.error(
-                "[Hatchet] Modal mixdown timeout",
-                transcript_id=input.transcript_id,
-                timeout=settings.MIXDOWN_TIMEOUT,
-            )
-            raise RuntimeError(
-                f"Modal mixdown timeout after {settings.MIXDOWN_TIMEOUT}s"
-            )
-        except ValueError as e:
-            logger.error(
-                "[Hatchet] Modal mixdown validation error",
-                transcript_id=input.transcript_id,
-                error=str(e),
-            )
-            raise
-    else:
-        ctx.log("mixdown_tracks: using local backend")
-        # Existing local implementation
     output_path = tempfile.mktemp(suffix=".mp3")
     duration_ms_callback_capture_container = [0.0]
     async def capture_duration(d):
         duration_ms_callback_capture_container[0] = d
-    writer = AudioFileWriterProcessor(
-        path=output_path, on_duration=capture_duration
-    )
+    writer = AudioFileWriterProcessor(path=output_path, on_duration=capture_duration)
     await mixdown_tracks_pyav(
         valid_urls,
@@ -615,23 +530,18 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
         offsets_seconds=None,
         logger=logger,
         progress_callback=make_audio_progress_logger(ctx, TaskName.MIXDOWN_TRACKS),
-        expected_duration_sec=recording_duration
-        if recording_duration > 0
-        else None,
+        expected_duration_sec=recording_duration if recording_duration > 0 else None,
     )
     await writer.flush()
     file_size = Path(output_path).stat().st_size
+    storage_path = f"{input.transcript_id}/audio.mp3"
     with open(output_path, "rb") as mixed_file:
-        await storage.put_file(output_key, mixed_file)
+        await storage.put_file(storage_path, mixed_file)
     Path(output_path).unlink(missing_ok=True)
-    duration_ms = duration_ms_callback_capture_container[0]
-    tracks_mixed = len(valid_urls)
-    ctx.log(f"mixdown_tracks: local mixdown uploaded {file_size} bytes")
-    # Update DB (same for both backends)
     async with fresh_db_connection():
         from reflector.db.transcripts import transcripts_controller  # noqa: PLC0415
@@ -641,12 +551,12 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
             transcript, {"audio_location": "storage"}
         )
-    ctx.log(f"mixdown_tracks complete: uploaded to {output_key}")
+    ctx.log(f"mixdown_tracks complete: uploaded {file_size} bytes to {storage_path}")
     return MixdownResult(
-        audio_key=output_key,
-        duration=duration_ms,
-        tracks_mixed=tracks_mixed,
+        audio_key=storage_path,
+        duration=duration_ms_callback_capture_container[0],
+        tracks_mixed=len(valid_urls),
     )

View File

@@ -7,11 +7,7 @@ Spawned dynamically by detect_topics via aio_run_many() for parallel processing.
 from datetime import timedelta
-from hatchet_sdk import (
-    ConcurrencyExpression,
-    ConcurrencyLimitStrategy,
-    Context,
-)
+from hatchet_sdk import ConcurrencyExpression, ConcurrencyLimitStrategy, Context
 from hatchet_sdk.rate_limit import RateLimit
 from pydantic import BaseModel
@@ -38,13 +34,11 @@ hatchet = HatchetClientManager.get_client()
 topic_chunk_workflow = hatchet.workflow(
     name="TopicChunkProcessing",
     input_validator=TopicChunkInput,
-    concurrency=[
-        ConcurrencyExpression(
-            expression="'global'",  # constant string = global limit across all runs
-            max_runs=20,
-            limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,
-        )
-    ],
+    concurrency=ConcurrencyExpression(
+        expression="'global'",  # constant string = global limit across all runs
+        max_runs=20,
+        limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,
+    ),
 )

View File

@@ -1,89 +0,0 @@
"""
Modal.com backend for audio mixdown.
Uses Modal's CPU containers to offload audio mixing from Hatchet workers.
Communicates via presigned S3 URLs for both input and output.
"""
import httpx
from pydantic import BaseModel
from reflector.settings import settings
class MixdownResponse(BaseModel):
"""Response from Modal mixdown endpoint."""
duration_ms: float
tracks_mixed: int
audio_uploaded: bool
class AudioMixdownModalProcessor:
"""Audio mixdown processor using Modal.com CPU backend.
Sends track URLs (presigned GET) and output URL (presigned PUT) to Modal.
Modal handles download, mixdown via PyAV, and upload.
"""
def __init__(self, modal_api_key: str | None = None):
if not settings.MIXDOWN_URL:
raise ValueError("MIXDOWN_URL required to use AudioMixdownModalProcessor")
self.mixdown_url = settings.MIXDOWN_URL + "/v1"
self.timeout = settings.MIXDOWN_TIMEOUT
self.modal_api_key = modal_api_key or settings.MIXDOWN_MODAL_API_KEY
if not self.modal_api_key:
raise ValueError(
"MIXDOWN_MODAL_API_KEY required to use AudioMixdownModalProcessor"
)
async def mixdown(
self,
track_urls: list[str],
output_url: str,
target_sample_rate: int,
expected_duration_sec: float | None = None,
) -> MixdownResponse:
"""Mix multiple audio tracks via Modal backend.
Args:
track_urls: List of presigned GET URLs for audio tracks (non-empty)
output_url: Presigned PUT URL for output MP3
target_sample_rate: Sample rate for output (Hz, must be positive)
expected_duration_sec: Optional fallback duration if container metadata unavailable
Returns:
MixdownResponse with duration_ms, tracks_mixed, audio_uploaded
Raises:
ValueError: If track_urls is empty or target_sample_rate invalid
httpx.HTTPStatusError: On HTTP errors (404, 403, 500, etc.)
httpx.TimeoutException: On timeout
"""
# Validate inputs
if not track_urls:
raise ValueError("track_urls cannot be empty")
if target_sample_rate <= 0:
raise ValueError(
f"target_sample_rate must be positive, got {target_sample_rate}"
)
if expected_duration_sec is not None and expected_duration_sec < 0:
raise ValueError(
f"expected_duration_sec cannot be negative, got {expected_duration_sec}"
)
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(
f"{self.mixdown_url}/audio/mixdown",
headers={"Authorization": f"Bearer {self.modal_api_key}"},
json={
"track_urls": track_urls,
"output_url": output_url,
"target_sample_rate": target_sample_rate,
"expected_duration_sec": expected_duration_sec,
},
)
response.raise_for_status()
return MixdownResponse(**response.json())

View File

@@ -98,17 +98,6 @@ class Settings(BaseSettings):
     # Diarization: local pyannote.audio
     DIARIZATION_PYANNOTE_AUTH_TOKEN: str | None = None
-    # Audio Mixdown
-    # backends:
-    #   - local: in-process PyAV mixdown (runs in same process as Hatchet worker)
-    #   - modal: HTTP API client to Modal.com CPU container
-    MIXDOWN_BACKEND: str = "local"
-    MIXDOWN_URL: str | None = None
-    MIXDOWN_TIMEOUT: int = 900  # 15 minutes
-    # Mixdown: modal backend
-    MIXDOWN_MODAL_API_KEY: str | None = None
     # Sentry
     SENTRY_DSN: str | None = None

View File

@@ -0,0 +1,133 @@
"""
Transcripts chat API
====================
WebSocket endpoint for bidirectional chat with LLM about transcript content.
"""
from typing import Optional
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from llama_index.core import Settings
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from reflector.auth.auth_jwt import JWTAuth
from reflector.db.recordings import recordings_controller
from reflector.db.transcripts import transcripts_controller
from reflector.db.users import user_controller
from reflector.llm import LLM
from reflector.settings import settings
from reflector.utils.transcript_formats import topics_to_webvtt_named
router = APIRouter()
async def _get_is_multitrack(transcript) -> bool:
"""Detect if transcript is from multitrack recording."""
if not transcript.recording_id:
return False
recording = await recordings_controller.get_by_id(transcript.recording_id)
return recording is not None and recording.is_multitrack
@router.websocket("/transcripts/{transcript_id}/chat")
async def transcript_chat_websocket(
transcript_id: str,
websocket: WebSocket,
):
"""WebSocket endpoint for chatting with LLM about transcript content."""
# 1. Auth check (optional) - extract token from WebSocket subprotocol header
# Browser can't send Authorization header for WS; use subprotocol: ["bearer", token]
raw_subprotocol = websocket.headers.get("sec-websocket-protocol") or ""
parts = [p.strip() for p in raw_subprotocol.split(",") if p.strip()]
token: Optional[str] = None
negotiated_subprotocol: Optional[str] = None
if len(parts) >= 2 and parts[0].lower() == "bearer":
negotiated_subprotocol = "bearer"
token = parts[1]
user_id: Optional[str] = None
if token:
try:
payload = JWTAuth().verify_token(token)
authentik_uid = payload.get("sub")
if authentik_uid:
user = await user_controller.get_by_authentik_uid(authentik_uid)
if user:
user_id = user.id
except Exception:
# Auth failed - continue as anonymous
pass
# Get transcript (respects user_id for private transcripts)
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
if not transcript:
await websocket.close(code=1008) # Policy violation (not found/unauthorized)
return
# 2. Accept connection (with negotiated subprotocol if present)
await websocket.accept(subprotocol=negotiated_subprotocol)
# 3. Generate WebVTT context
is_multitrack = await _get_is_multitrack(transcript)
webvtt = topics_to_webvtt_named(
transcript.topics, transcript.participants, is_multitrack
)
# Truncate if needed (15k char limit for POC)
webvtt_truncated = webvtt[:15000] if len(webvtt) > 15000 else webvtt
# 4. Configure LLM
llm = LLM(settings=settings, temperature=0.7)
# 5. System message with transcript context
system_msg = f"""You are analyzing this meeting transcript (WebVTT):
{webvtt_truncated}
Answer questions about content, speakers, timeline. Include timestamps when relevant."""
# 6. Conversation history
conversation_history = [ChatMessage(role=MessageRole.SYSTEM, content=system_msg)]
try:
# 7. Message loop
while True:
data = await websocket.receive_json()
if data.get("type") == "get_context":
# Return WebVTT context (for debugging/testing)
await websocket.send_json({"type": "context", "webvtt": webvtt})
continue
if data.get("type") != "message":
# Echo unknown types for backward compatibility
await websocket.send_json({"type": "echo", "data": data})
continue
# Add user message to history
user_msg = ChatMessage(role=MessageRole.USER, content=data.get("text", ""))
conversation_history.append(user_msg)
# Stream LLM response
assistant_msg = ""
chat_stream = await Settings.llm.astream_chat(conversation_history)
async for chunk in chat_stream:
token = chunk.delta or ""
if token:
await websocket.send_json({"type": "token", "text": token})
assistant_msg += token
# Save assistant response to history
conversation_history.append(
ChatMessage(role=MessageRole.ASSISTANT, content=assistant_msg)
)
await websocket.send_json({"type": "done"})
except WebSocketDisconnect:
pass
except Exception as e:
await websocket.send_json({"type": "error", "message": str(e)})
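For reference, a hypothetical Python client for this endpoint. It passes the JWT the same way the browser does, as the second entry of the WebSocket subprotocol list; it assumes httpx_ws's subprotocols parameter, and the host and question are placeholders:

from httpx_ws import aconnect_ws

async def chat(token: str, transcript_id: str) -> None:
    # Browsers cannot set an Authorization header on a WebSocket, so the
    # endpoint above reads the token from Sec-WebSocket-Protocol instead.
    url = f"ws://localhost:8000/v1/transcripts/{transcript_id}/chat"
    async with aconnect_ws(url, subprotocols=["bearer", token]) as ws:
        await ws.send_json({"type": "message", "text": "Who spoke first?"})
        while True:
            msg = await ws.receive_json()
            if msg["type"] == "token":
                print(msg["text"], end="", flush=True)
            elif msg["type"] in ("done", "error"):
                break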

View File

@@ -7,10 +7,8 @@ elif [ "${ENTRYPOINT}" = "worker" ]; then
     uv run celery -A reflector.worker.app worker --loglevel=info
 elif [ "${ENTRYPOINT}" = "beat" ]; then
     uv run celery -A reflector.worker.app beat --loglevel=info
-elif [ "${ENTRYPOINT}" = "hatchet-worker-cpu" ]; then
-    uv run python -m reflector.hatchet.run_workers_cpu
-elif [ "${ENTRYPOINT}" = "hatchet-worker-llm" ]; then
-    uv run python -m reflector.hatchet.run_workers_llm
+elif [ "${ENTRYPOINT}" = "hatchet-worker" ]; then
+    uv run python -m reflector.hatchet.run_workers
 else
     echo "Unknown command"
 fi

View File

@@ -0,0 +1,234 @@
"""Tests for transcript chat WebSocket endpoint."""
import asyncio
import threading
import time
from pathlib import Path
import pytest
from httpx_ws import aconnect_ws
from uvicorn import Config, Server
from reflector.db.transcripts import (
SourceKind,
TranscriptParticipant,
TranscriptTopic,
transcripts_controller,
)
from reflector.processors.types import Word
@pytest.fixture
def chat_appserver(tmpdir, setup_database):
"""Start a real HTTP server for WebSocket testing."""
from reflector.app import app
from reflector.db import get_database
from reflector.settings import settings
DATA_DIR = settings.DATA_DIR
settings.DATA_DIR = Path(tmpdir)
# Start server in separate thread with its own event loop
host = "127.0.0.1"
port = 1256 # Different port from rtc tests
server_started = threading.Event()
server_exception = None
server_instance = None
def run_server():
nonlocal server_exception, server_instance
try:
# Create new event loop for this thread
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
config = Config(app=app, host=host, port=port, loop=loop)
server_instance = Server(config)
async def start_server():
# Initialize database connection in this event loop
database = get_database()
await database.connect()
try:
await server_instance.serve()
finally:
await database.disconnect()
# Signal that server is starting
server_started.set()
loop.run_until_complete(start_server())
except Exception as e:
server_exception = e
server_started.set()
finally:
loop.close()
server_thread = threading.Thread(target=run_server, daemon=True)
server_thread.start()
# Wait for server to start
server_started.wait(timeout=30)
if server_exception:
raise server_exception
# Wait for server to be fully ready
time.sleep(1)
yield server_instance, host, port
# Stop server
if server_instance:
server_instance.should_exit = True
server_thread.join(timeout=30)
settings.DATA_DIR = DATA_DIR
@pytest.fixture
async def test_transcript(setup_database):
"""Create a test transcript for WebSocket tests."""
transcript = await transcripts_controller.add(
name="Test Transcript for Chat", source_kind=SourceKind.FILE
)
return transcript
@pytest.fixture
async def test_transcript_with_content(setup_database):
"""Create a test transcript with actual content for WebVTT generation."""
transcript = await transcripts_controller.add(
name="Test Transcript with Content", source_kind=SourceKind.FILE
)
# Add participants
await transcripts_controller.update(
transcript,
{
"participants": [
TranscriptParticipant(id="1", speaker=0, name="Alice").model_dump(),
TranscriptParticipant(id="2", speaker=1, name="Bob").model_dump(),
]
},
)
# Add topic with words
await transcripts_controller.upsert_topic(
transcript,
TranscriptTopic(
title="Introduction",
summary="Opening remarks",
timestamp=0.0,
words=[
Word(text="Hello ", start=0.0, end=1.0, speaker=0),
Word(text="everyone.", start=1.0, end=2.0, speaker=0),
Word(text="Hi ", start=2.0, end=3.0, speaker=1),
Word(text="there!", start=3.0, end=4.0, speaker=1),
],
),
)
return transcript
@pytest.mark.asyncio
async def test_chat_websocket_connection_success(test_transcript, chat_appserver):
"""Test successful WebSocket connection to chat endpoint."""
server, host, port = chat_appserver
base_url = f"ws://{host}:{port}/v1"
async with aconnect_ws(f"{base_url}/transcripts/{test_transcript.id}/chat") as ws:
# Send unknown message type to test echo behavior
await ws.send_json({"type": "test", "text": "Hello"})
# Should receive echo for unknown types
response = await ws.receive_json()
assert response["type"] == "echo"
assert response["data"]["type"] == "test"
@pytest.mark.asyncio
async def test_chat_websocket_nonexistent_transcript(chat_appserver):
"""Test WebSocket connection fails for nonexistent transcript."""
server, host, port = chat_appserver
base_url = f"ws://{host}:{port}/v1"
# Connection should fail or disconnect immediately for non-existent transcript
# Different behavior from successful connection
with pytest.raises(Exception): # Will raise on connection or first operation
async with aconnect_ws(f"{base_url}/transcripts/nonexistent-id/chat") as ws:
await ws.send_json({"type": "message", "text": "Hello"})
await ws.receive_json()
@pytest.mark.asyncio
async def test_chat_websocket_multiple_messages(test_transcript, chat_appserver):
"""Test sending multiple messages through WebSocket."""
server, host, port = chat_appserver
base_url = f"ws://{host}:{port}/v1"
async with aconnect_ws(f"{base_url}/transcripts/{test_transcript.id}/chat") as ws:
# Send multiple unknown message types (testing echo behavior)
messages = ["First message", "Second message", "Third message"]
for i, msg in enumerate(messages):
await ws.send_json({"type": f"test{i}", "text": msg})
response = await ws.receive_json()
assert response["type"] == "echo"
assert response["data"]["type"] == f"test{i}"
assert response["data"]["text"] == msg
@pytest.mark.asyncio
async def test_chat_websocket_disconnect_graceful(test_transcript, chat_appserver):
"""Test WebSocket disconnects gracefully."""
server, host, port = chat_appserver
base_url = f"ws://{host}:{port}/v1"
async with aconnect_ws(f"{base_url}/transcripts/{test_transcript.id}/chat") as ws:
await ws.send_json({"type": "message", "text": "Hello"})
await ws.receive_json()
# Close handled by context manager - should not raise
@pytest.mark.asyncio
async def test_chat_websocket_context_generation(
test_transcript_with_content, chat_appserver
):
"""Test WebVTT context is generated on connection."""
server, host, port = chat_appserver
base_url = f"ws://{host}:{port}/v1"
async with aconnect_ws(
f"{base_url}/transcripts/{test_transcript_with_content.id}/chat"
) as ws:
# Request context
await ws.send_json({"type": "get_context"})
# Receive context response
response = await ws.receive_json()
assert response["type"] == "context"
assert "webvtt" in response
# Verify WebVTT format
webvtt = response["webvtt"]
assert webvtt.startswith("WEBVTT")
assert "<v Alice>" in webvtt
assert "<v Bob>" in webvtt
assert "Hello everyone." in webvtt
assert "Hi there!" in webvtt
@pytest.mark.asyncio
async def test_chat_websocket_unknown_message_type(test_transcript, chat_appserver):
"""Test unknown message types are echoed back."""
server, host, port = chat_appserver
base_url = f"ws://{host}:{port}/v1"
async with aconnect_ws(f"{base_url}/transcripts/{test_transcript.id}/chat") as ws:
# Send unknown message type
await ws.send_json({"type": "unknown", "data": "test"})
# Should receive echo
response = await ws.receive_json()
assert response["type"] == "echo"
assert response["data"]["type"] == "unknown"

View File

@@ -0,0 +1,103 @@
"use client";
import { useState } from "react";
import { Box, Dialog, Input, IconButton } from "@chakra-ui/react";
import { MessageCircle } from "lucide-react";
import Markdown from "react-markdown";
import "../../styles/markdown.css";
import type { Message } from "./useTranscriptChat";
interface TranscriptChatModalProps {
open: boolean;
onClose: () => void;
messages: Message[];
sendMessage: (text: string) => void;
isStreaming: boolean;
currentStreamingText: string;
}
export function TranscriptChatModal({
open,
onClose,
messages,
sendMessage,
isStreaming,
currentStreamingText,
}: TranscriptChatModalProps) {
const [input, setInput] = useState("");
const handleSend = () => {
if (!input.trim()) return;
sendMessage(input);
setInput("");
};
return (
<Dialog.Root open={open} onOpenChange={(e) => !e.open && onClose()}>
<Dialog.Backdrop />
<Dialog.Positioner>
<Dialog.Content maxW="500px" h="600px">
<Dialog.Header>Transcript Chat</Dialog.Header>
<Dialog.Body overflowY="auto">
{messages.map((msg) => (
<Box
key={msg.id}
p={3}
mb={2}
bg={msg.role === "user" ? "blue.50" : "gray.50"}
borderRadius="md"
>
{msg.role === "user" ? (
msg.text
) : (
<div className="markdown">
<Markdown>{msg.text}</Markdown>
</div>
)}
</Box>
))}
{isStreaming && (
<Box p={3} bg="gray.50" borderRadius="md">
<div className="markdown">
<Markdown>{currentStreamingText}</Markdown>
</div>
<Box as="span" className="animate-pulse">
</Box>
</Box>
)}
</Dialog.Body>
<Dialog.Footer>
<Input
value={input}
onChange={(e) => setInput(e.target.value)}
onKeyDown={(e) => e.key === "Enter" && handleSend()}
placeholder="Ask about transcript..."
disabled={isStreaming}
/>
</Dialog.Footer>
</Dialog.Content>
</Dialog.Positioner>
</Dialog.Root>
);
}
export function TranscriptChatButton({ onClick }: { onClick: () => void }) {
return (
<IconButton
position="fixed"
bottom="24px"
right="24px"
onClick={onClick}
size="lg"
colorPalette="blue"
borderRadius="full"
aria-label="Open chat"
>
<MessageCircle />
</IconButton>
);
}

View File

@@ -18,9 +18,15 @@ import {
   Skeleton,
   Text,
   Spinner,
+  useDisclosure,
 } from "@chakra-ui/react";
 import { useTranscriptGet } from "../../../lib/apiHooks";
 import { TranscriptStatus } from "../../../lib/transcript";
+import {
+  TranscriptChatModal,
+  TranscriptChatButton,
+} from "../TranscriptChatModal";
+import { useTranscriptChat } from "../useTranscriptChat";
 type TranscriptDetails = {
   params: Promise<{
@@ -53,6 +59,9 @@ export default function TranscriptDetails(details: TranscriptDetails) {
   const [finalSummaryElement, setFinalSummaryElement] =
     useState<HTMLDivElement | null>(null);
+  const { open, onOpen, onClose } = useDisclosure();
+  const chat = useTranscriptChat(transcriptId);
 useEffect(() => {
     if (!waiting || !transcript.data) return;
@@ -119,6 +128,15 @@ export default function TranscriptDetails(details: TranscriptDetails) {
   return (
     <>
+      <TranscriptChatModal
+        open={open}
+        onClose={onClose}
+        messages={chat.messages}
+        sendMessage={chat.sendMessage}
+        isStreaming={chat.isStreaming}
+        currentStreamingText={chat.currentStreamingText}
+      />
+      <TranscriptChatButton onClick={onOpen} />
       <Grid
         templateColumns="1fr"
         templateRows="auto minmax(0, 1fr)"

View File

@@ -0,0 +1,130 @@
"use client";
import { useEffect, useState, useRef } from "react";
import { getSession } from "next-auth/react";
import { WEBSOCKET_URL } from "../../lib/apiClient";
import { assertExtendedToken } from "../../lib/types";
export type Message = {
id: string;
role: "user" | "assistant";
text: string;
timestamp: Date;
};
export type UseTranscriptChat = {
messages: Message[];
sendMessage: (text: string) => void;
isStreaming: boolean;
currentStreamingText: string;
};
export const useTranscriptChat = (transcriptId: string): UseTranscriptChat => {
const [messages, setMessages] = useState<Message[]>([]);
const [isStreaming, setIsStreaming] = useState(false);
const [currentStreamingText, setCurrentStreamingText] = useState("");
const wsRef = useRef<WebSocket | null>(null);
const streamingTextRef = useRef<string>("");
const isMountedRef = useRef<boolean>(true);
useEffect(() => {
isMountedRef.current = true;
const connectWebSocket = async () => {
const url = `${WEBSOCKET_URL}/v1/transcripts/${transcriptId}/chat`;
// Get auth token for WebSocket subprotocol
let protocols: string[] | undefined;
try {
const session = await getSession();
if (session) {
const token = assertExtendedToken(session).accessToken;
// Pass token via subprotocol: ["bearer", token]
protocols = ["bearer", token];
}
} catch (error) {
console.warn("Failed to get auth token for WebSocket:", error);
}
const ws = new WebSocket(url, protocols);
wsRef.current = ws;
ws.onopen = () => {
console.log("Chat WebSocket connected");
};
ws.onmessage = (event) => {
if (!isMountedRef.current) return;
const msg = JSON.parse(event.data);
switch (msg.type) {
case "token":
setIsStreaming(true);
streamingTextRef.current += msg.text;
setCurrentStreamingText(streamingTextRef.current);
break;
case "done":
// CRITICAL: Save the text BEFORE resetting the ref
// The setMessages callback may execute later, after ref is reset
const finalText = streamingTextRef.current;
setMessages((prev) => [
...prev,
{
id: Date.now().toString(),
role: "assistant",
text: finalText,
timestamp: new Date(),
},
]);
streamingTextRef.current = "";
setCurrentStreamingText("");
setIsStreaming(false);
break;
case "error":
console.error("Chat error:", msg.message);
setIsStreaming(false);
break;
}
};
ws.onerror = (error) => {
console.error("WebSocket error:", error);
};
ws.onclose = () => {
console.log("Chat WebSocket closed");
};
};
connectWebSocket();
return () => {
isMountedRef.current = false;
if (wsRef.current) {
wsRef.current.close();
}
};
}, [transcriptId]);
const sendMessage = (text: string) => {
if (!wsRef.current) return;
setMessages((prev) => [
...prev,
{
id: Date.now().toString(),
role: "user",
text,
timestamp: new Date(),
},
]);
wsRef.current.send(JSON.stringify({ type: "message", text }));
};
return { messages, sendMessage, isStreaming, currentStreamingText };
};