mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-02-05 18:36:45 +00:00
Compare commits
5 Commits
feat/trans
...
feature/mo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e1b790c5a8 | ||
| c8743fdf1c | |||
| 8a293882ad | |||
| d83c4a30b4 | |||
| 3b6540eae5 |
@@ -1,5 +1,12 @@
|
|||||||
# Changelog
|
# Changelog
|
||||||
|
|
||||||
|
## [0.28.0](https://github.com/Monadical-SAS/reflector/compare/v0.27.0...v0.28.0) (2026-01-20)
|
||||||
|
|
||||||
|
|
||||||
|
### Features
|
||||||
|
|
||||||
|
* worker affinity ([#819](https://github.com/Monadical-SAS/reflector/issues/819)) ([3b6540e](https://github.com/Monadical-SAS/reflector/commit/3b6540eae5b597449f98661bdf15483b77be3268))
|
||||||
|
|
||||||
## [0.27.0](https://github.com/Monadical-SAS/reflector/compare/v0.26.0...v0.27.0) (2025-12-26)
|
## [0.27.0](https://github.com/Monadical-SAS/reflector/compare/v0.26.0...v0.27.0) (2025-12-26)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ services:
|
|||||||
environment:
|
environment:
|
||||||
ENTRYPOINT: beat
|
ENTRYPOINT: beat
|
||||||
|
|
||||||
hatchet-worker:
|
hatchet-worker-cpu:
|
||||||
build:
|
build:
|
||||||
context: server
|
context: server
|
||||||
volumes:
|
volumes:
|
||||||
@@ -43,7 +43,20 @@ services:
|
|||||||
env_file:
|
env_file:
|
||||||
- ./server/.env
|
- ./server/.env
|
||||||
environment:
|
environment:
|
||||||
ENTRYPOINT: hatchet-worker
|
ENTRYPOINT: hatchet-worker-cpu
|
||||||
|
depends_on:
|
||||||
|
hatchet:
|
||||||
|
condition: service_healthy
|
||||||
|
hatchet-worker-llm:
|
||||||
|
build:
|
||||||
|
context: server
|
||||||
|
volumes:
|
||||||
|
- ./server/:/app/
|
||||||
|
- /app/.venv
|
||||||
|
env_file:
|
||||||
|
- ./server/.env
|
||||||
|
environment:
|
||||||
|
ENTRYPOINT: hatchet-worker-llm
|
||||||
depends_on:
|
depends_on:
|
||||||
hatchet:
|
hatchet:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
|
|||||||
@@ -131,6 +131,15 @@ if [ -z "$DIARIZER_URL" ]; then
|
|||||||
fi
|
fi
|
||||||
echo " -> $DIARIZER_URL"
|
echo " -> $DIARIZER_URL"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Deploying mixdown (CPU audio processing)..."
|
||||||
|
MIXDOWN_URL=$(modal deploy reflector_mixdown.py 2>&1 | grep -o 'https://[^ ]*web.modal.run' | head -1)
|
||||||
|
if [ -z "$MIXDOWN_URL" ]; then
|
||||||
|
echo "Error: Failed to deploy mixdown. Check Modal dashboard for details."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo " -> $MIXDOWN_URL"
|
||||||
|
|
||||||
# --- Output Configuration ---
|
# --- Output Configuration ---
|
||||||
echo ""
|
echo ""
|
||||||
echo "=========================================="
|
echo "=========================================="
|
||||||
@@ -147,4 +156,8 @@ echo ""
|
|||||||
echo "DIARIZATION_BACKEND=modal"
|
echo "DIARIZATION_BACKEND=modal"
|
||||||
echo "DIARIZATION_URL=$DIARIZER_URL"
|
echo "DIARIZATION_URL=$DIARIZER_URL"
|
||||||
echo "DIARIZATION_MODAL_API_KEY=$API_KEY"
|
echo "DIARIZATION_MODAL_API_KEY=$API_KEY"
|
||||||
|
echo ""
|
||||||
|
echo "MIXDOWN_BACKEND=modal"
|
||||||
|
echo "MIXDOWN_URL=$MIXDOWN_URL"
|
||||||
|
echo "MIXDOWN_MODAL_API_KEY=$API_KEY"
|
||||||
echo "# --- End Modal Configuration ---"
|
echo "# --- End Modal Configuration ---"
|
||||||
|
|||||||
379
gpu/modal_deployments/reflector_mixdown.py
Normal file
379
gpu/modal_deployments/reflector_mixdown.py
Normal file
@@ -0,0 +1,379 @@
|
|||||||
|
"""
|
||||||
|
Reflector GPU backend - audio mixdown
|
||||||
|
======================================
|
||||||
|
|
||||||
|
CPU-intensive audio mixdown service for combining multiple audio tracks.
|
||||||
|
Uses PyAV filter graph (amix) for high-quality audio mixing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
from fractions import Fraction
|
||||||
|
|
||||||
|
import modal
|
||||||
|
|
||||||
|
MIXDOWN_TIMEOUT = 900 # 15 minutes
|
||||||
|
SCALEDOWN_WINDOW = 60 # 1 minute idle before shutdown
|
||||||
|
|
||||||
|
app = modal.App("reflector-mixdown")
|
||||||
|
|
||||||
|
# CPU-based image (no GPU needed for audio processing)
|
||||||
|
image = (
|
||||||
|
modal.Image.debian_slim(python_version="3.12")
|
||||||
|
.apt_install("ffmpeg") # Required by PyAV
|
||||||
|
.pip_install(
|
||||||
|
"av==13.1.0", # PyAV for audio processing
|
||||||
|
"requests==2.32.3", # HTTP for presigned URL downloads/uploads
|
||||||
|
"fastapi==0.115.12", # API framework
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.function(
|
||||||
|
cpu=4.0, # 4 CPU cores for audio processing
|
||||||
|
timeout=MIXDOWN_TIMEOUT,
|
||||||
|
scaledown_window=SCALEDOWN_WINDOW,
|
||||||
|
secrets=[modal.Secret.from_name("reflector-gpu")],
|
||||||
|
image=image,
|
||||||
|
)
|
||||||
|
@modal.concurrent(max_inputs=10)
|
||||||
|
@modal.asgi_app()
|
||||||
|
def web():
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
import av
|
||||||
|
import requests
|
||||||
|
from av.audio.resampler import AudioResampler
|
||||||
|
from fastapi import Depends, FastAPI, HTTPException, status
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||||
|
|
||||||
|
# Validate API key exists at startup
|
||||||
|
API_KEY = os.environ.get("REFLECTOR_GPU_APIKEY")
|
||||||
|
if not API_KEY:
|
||||||
|
raise RuntimeError("REFLECTOR_GPU_APIKEY not configured in Modal secrets")
|
||||||
|
|
||||||
|
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
|
||||||
|
# Use constant-time comparison to prevent timing attacks
|
||||||
|
if secrets.compare_digest(apikey, API_KEY):
|
||||||
|
return
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid API key",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
class MixdownRequest(BaseModel):
|
||||||
|
track_urls: list[str]
|
||||||
|
output_url: str
|
||||||
|
target_sample_rate: int = 48000
|
||||||
|
expected_duration_sec: float | None = None
|
||||||
|
|
||||||
|
class MixdownResponse(BaseModel):
|
||||||
|
duration_ms: float
|
||||||
|
tracks_mixed: int
|
||||||
|
audio_uploaded: bool
|
||||||
|
|
||||||
|
def download_track(url: str, temp_dir: str, index: int) -> str:
|
||||||
|
"""Download track from presigned URL to temp file using streaming."""
|
||||||
|
logger.info(f"Downloading track {index + 1}")
|
||||||
|
response = requests.get(url, stream=True, timeout=300)
|
||||||
|
|
||||||
|
if response.status_code == 404:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Track {index} not found")
|
||||||
|
if response.status_code == 403:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403, detail=f"Track {index} presigned URL expired"
|
||||||
|
)
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
temp_path = os.path.join(temp_dir, f"track_{index}.webm")
|
||||||
|
total_bytes = 0
|
||||||
|
with open(temp_path, "wb") as f:
|
||||||
|
for chunk in response.iter_content(chunk_size=8192):
|
||||||
|
if chunk:
|
||||||
|
f.write(chunk)
|
||||||
|
total_bytes += len(chunk)
|
||||||
|
|
||||||
|
logger.info(f"Track {index + 1} downloaded: {total_bytes} bytes")
|
||||||
|
return temp_path
|
||||||
|
|
||||||
|
def mixdown_tracks_modal(
|
||||||
|
track_paths: list[str],
|
||||||
|
output_path: str,
|
||||||
|
target_sample_rate: int,
|
||||||
|
expected_duration_sec: float | None,
|
||||||
|
logger,
|
||||||
|
) -> float:
|
||||||
|
"""Mix multiple audio tracks using PyAV filter graph.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
track_paths: List of local file paths to audio tracks
|
||||||
|
output_path: Local path for output MP3 file
|
||||||
|
target_sample_rate: Sample rate for output (Hz)
|
||||||
|
expected_duration_sec: Optional fallback duration if container metadata unavailable
|
||||||
|
logger: Logger instance for progress tracking
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Duration in milliseconds
|
||||||
|
"""
|
||||||
|
logger.info(f"Starting mixdown of {len(track_paths)} tracks")
|
||||||
|
|
||||||
|
# Build PyAV filter graph: N abuffer -> amix -> aformat -> sink
|
||||||
|
graph = av.filter.Graph()
|
||||||
|
inputs = []
|
||||||
|
|
||||||
|
for idx in range(len(track_paths)):
|
||||||
|
args = (
|
||||||
|
f"time_base=1/{target_sample_rate}:"
|
||||||
|
f"sample_rate={target_sample_rate}:"
|
||||||
|
f"sample_fmt=s32:"
|
||||||
|
f"channel_layout=stereo"
|
||||||
|
)
|
||||||
|
in_ctx = graph.add("abuffer", args=args, name=f"in{idx}")
|
||||||
|
inputs.append(in_ctx)
|
||||||
|
|
||||||
|
mixer = graph.add("amix", args=f"inputs={len(inputs)}:normalize=0", name="mix")
|
||||||
|
fmt = graph.add(
|
||||||
|
"aformat",
|
||||||
|
args=f"sample_fmts=s32:channel_layouts=stereo:sample_rates={target_sample_rate}",
|
||||||
|
name="fmt",
|
||||||
|
)
|
||||||
|
sink = graph.add("abuffersink", name="out")
|
||||||
|
|
||||||
|
# Connect inputs to mixer (no delays for Modal implementation)
|
||||||
|
for idx, in_ctx in enumerate(inputs):
|
||||||
|
in_ctx.link_to(mixer, 0, idx)
|
||||||
|
|
||||||
|
mixer.link_to(fmt)
|
||||||
|
fmt.link_to(sink)
|
||||||
|
graph.configure()
|
||||||
|
|
||||||
|
# Open all containers
|
||||||
|
containers = []
|
||||||
|
try:
|
||||||
|
for i, path in enumerate(track_paths):
|
||||||
|
try:
|
||||||
|
c = av.open(path)
|
||||||
|
containers.append(c)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to open container {i}: {e}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not containers:
|
||||||
|
raise ValueError("Could not open any track containers")
|
||||||
|
|
||||||
|
# Calculate total duration for progress reporting
|
||||||
|
max_duration_sec = 0.0
|
||||||
|
for c in containers:
|
||||||
|
if c.duration is not None:
|
||||||
|
dur_sec = c.duration / av.time_base
|
||||||
|
max_duration_sec = max(max_duration_sec, dur_sec)
|
||||||
|
if max_duration_sec == 0.0 and expected_duration_sec:
|
||||||
|
max_duration_sec = expected_duration_sec
|
||||||
|
|
||||||
|
# Setup output container
|
||||||
|
out_container = av.open(output_path, "w", format="mp3")
|
||||||
|
out_stream = out_container.add_stream("libmp3lame", rate=target_sample_rate)
|
||||||
|
|
||||||
|
decoders = [c.decode(audio=0) for c in containers]
|
||||||
|
active = [True] * len(decoders)
|
||||||
|
resamplers = [
|
||||||
|
AudioResampler(format="s32", layout="stereo", rate=target_sample_rate)
|
||||||
|
for _ in decoders
|
||||||
|
]
|
||||||
|
|
||||||
|
current_max_time = 0.0
|
||||||
|
last_log_time = time.monotonic()
|
||||||
|
start_time = time.monotonic()
|
||||||
|
|
||||||
|
total_duration = 0
|
||||||
|
|
||||||
|
while any(active):
|
||||||
|
for i, (dec, is_active) in enumerate(zip(decoders, active)):
|
||||||
|
if not is_active:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
frame = next(dec)
|
||||||
|
except StopIteration:
|
||||||
|
active[i] = False
|
||||||
|
inputs[i].push(None) # Signal end of stream
|
||||||
|
continue
|
||||||
|
|
||||||
|
if frame.sample_rate != target_sample_rate:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Progress logging (every 5 seconds)
|
||||||
|
if frame.time is not None:
|
||||||
|
current_max_time = max(current_max_time, frame.time)
|
||||||
|
now = time.monotonic()
|
||||||
|
if now - last_log_time >= 5.0:
|
||||||
|
elapsed = now - start_time
|
||||||
|
if max_duration_sec > 0:
|
||||||
|
progress_pct = min(
|
||||||
|
100.0, (current_max_time / max_duration_sec) * 100
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Mixdown progress: {progress_pct:.1f}% @ {current_max_time:.1f}s (elapsed: {elapsed:.1f}s)"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
f"Mixdown progress: @ {current_max_time:.1f}s (elapsed: {elapsed:.1f}s)"
|
||||||
|
)
|
||||||
|
last_log_time = now
|
||||||
|
|
||||||
|
out_frames = resamplers[i].resample(frame) or []
|
||||||
|
for rf in out_frames:
|
||||||
|
rf.sample_rate = target_sample_rate
|
||||||
|
rf.time_base = Fraction(1, target_sample_rate)
|
||||||
|
inputs[i].push(rf)
|
||||||
|
|
||||||
|
# Pull mixed frames from sink and encode
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
mixed = sink.pull()
|
||||||
|
except Exception:
|
||||||
|
break
|
||||||
|
mixed.sample_rate = target_sample_rate
|
||||||
|
mixed.time_base = Fraction(1, target_sample_rate)
|
||||||
|
|
||||||
|
# Encode and mux
|
||||||
|
for packet in out_stream.encode(mixed):
|
||||||
|
out_container.mux(packet)
|
||||||
|
total_duration += packet.duration
|
||||||
|
|
||||||
|
# Flush remaining frames from filter graph
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
mixed = sink.pull()
|
||||||
|
except Exception:
|
||||||
|
break
|
||||||
|
mixed.sample_rate = target_sample_rate
|
||||||
|
mixed.time_base = Fraction(1, target_sample_rate)
|
||||||
|
|
||||||
|
for packet in out_stream.encode(mixed):
|
||||||
|
out_container.mux(packet)
|
||||||
|
total_duration += packet.duration
|
||||||
|
|
||||||
|
# Flush encoder
|
||||||
|
for packet in out_stream.encode():
|
||||||
|
out_container.mux(packet)
|
||||||
|
total_duration += packet.duration
|
||||||
|
|
||||||
|
# Calculate duration in milliseconds
|
||||||
|
if total_duration > 0:
|
||||||
|
# Use the same calculation as AudioFileWriterProcessor
|
||||||
|
duration_ms = round(
|
||||||
|
float(total_duration * out_stream.time_base * 1000), 2
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
duration_ms = 0.0
|
||||||
|
|
||||||
|
out_container.close()
|
||||||
|
logger.info(f"Mixdown complete: duration={duration_ms}ms")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Cleanup all containers
|
||||||
|
for c in containers:
|
||||||
|
if c is not None:
|
||||||
|
try:
|
||||||
|
c.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return duration_ms
|
||||||
|
|
||||||
|
@app.post("/v1/audio/mixdown", dependencies=[Depends(apikey_auth)])
|
||||||
|
def mixdown(request: MixdownRequest) -> MixdownResponse:
|
||||||
|
"""Mix multiple audio tracks into a single MP3 file.
|
||||||
|
|
||||||
|
Tracks are downloaded from presigned S3 URLs, mixed using PyAV,
|
||||||
|
and uploaded to a presigned S3 PUT URL.
|
||||||
|
"""
|
||||||
|
if not request.track_urls:
|
||||||
|
raise HTTPException(status_code=400, detail="No track URLs provided")
|
||||||
|
|
||||||
|
logger.info(f"Mixdown request: {len(request.track_urls)} tracks")
|
||||||
|
|
||||||
|
temp_dir = tempfile.mkdtemp()
|
||||||
|
temp_files = []
|
||||||
|
output_mp3_path = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Download all tracks
|
||||||
|
for i, url in enumerate(request.track_urls):
|
||||||
|
temp_path = download_track(url, temp_dir, i)
|
||||||
|
temp_files.append(temp_path)
|
||||||
|
|
||||||
|
# Mix tracks
|
||||||
|
output_mp3_path = os.path.join(temp_dir, "mixed.mp3")
|
||||||
|
duration_ms = mixdown_tracks_modal(
|
||||||
|
temp_files,
|
||||||
|
output_mp3_path,
|
||||||
|
request.target_sample_rate,
|
||||||
|
request.expected_duration_sec,
|
||||||
|
logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Upload result to S3
|
||||||
|
logger.info("Uploading result to S3")
|
||||||
|
file_size = os.path.getsize(output_mp3_path)
|
||||||
|
with open(output_mp3_path, "rb") as f:
|
||||||
|
upload_response = requests.put(
|
||||||
|
request.output_url, data=f, timeout=300
|
||||||
|
)
|
||||||
|
|
||||||
|
if upload_response.status_code == 403:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403, detail="Output presigned URL expired"
|
||||||
|
)
|
||||||
|
|
||||||
|
upload_response.raise_for_status()
|
||||||
|
logger.info(f"Upload complete: {file_size} bytes")
|
||||||
|
|
||||||
|
return MixdownResponse(
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
tracks_mixed=len(request.track_urls),
|
||||||
|
audio_uploaded=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Mixdown failed: {e}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail=f"Mixdown failed: {str(e)}")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Cleanup temp files
|
||||||
|
for temp_path in temp_files:
|
||||||
|
try:
|
||||||
|
os.unlink(temp_path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to cleanup temp file {temp_path}: {e}")
|
||||||
|
|
||||||
|
if output_mp3_path and os.path.exists(output_mp3_path):
|
||||||
|
try:
|
||||||
|
os.unlink(output_mp3_path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to cleanup output file {output_mp3_path}: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
shutil.rmtree(temp_dir)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to cleanup temp directory {temp_dir}: {e}")
|
||||||
|
|
||||||
|
return app
|
||||||
@@ -18,7 +18,6 @@ from reflector.views.rooms import router as rooms_router
|
|||||||
from reflector.views.rtc_offer import router as rtc_offer_router
|
from reflector.views.rtc_offer import router as rtc_offer_router
|
||||||
from reflector.views.transcripts import router as transcripts_router
|
from reflector.views.transcripts import router as transcripts_router
|
||||||
from reflector.views.transcripts_audio import router as transcripts_audio_router
|
from reflector.views.transcripts_audio import router as transcripts_audio_router
|
||||||
from reflector.views.transcripts_chat import router as transcripts_chat_router
|
|
||||||
from reflector.views.transcripts_participants import (
|
from reflector.views.transcripts_participants import (
|
||||||
router as transcripts_participants_router,
|
router as transcripts_participants_router,
|
||||||
)
|
)
|
||||||
@@ -91,7 +90,6 @@ app.include_router(transcripts_participants_router, prefix="/v1")
|
|||||||
app.include_router(transcripts_speaker_router, prefix="/v1")
|
app.include_router(transcripts_speaker_router, prefix="/v1")
|
||||||
app.include_router(transcripts_upload_router, prefix="/v1")
|
app.include_router(transcripts_upload_router, prefix="/v1")
|
||||||
app.include_router(transcripts_websocket_router, prefix="/v1")
|
app.include_router(transcripts_websocket_router, prefix="/v1")
|
||||||
app.include_router(transcripts_chat_router, prefix="/v1")
|
|
||||||
app.include_router(transcripts_webrtc_router, prefix="/v1")
|
app.include_router(transcripts_webrtc_router, prefix="/v1")
|
||||||
app.include_router(transcripts_process_router, prefix="/v1")
|
app.include_router(transcripts_process_router, prefix="/v1")
|
||||||
app.include_router(user_router, prefix="/v1")
|
app.include_router(user_router, prefix="/v1")
|
||||||
|
|||||||
@@ -1,77 +0,0 @@
|
|||||||
"""
|
|
||||||
Run Hatchet workers for the multitrack pipeline.
|
|
||||||
Runs as a separate process, just like Celery workers.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
uv run -m reflector.hatchet.run_workers
|
|
||||||
|
|
||||||
# Or via docker:
|
|
||||||
docker compose exec server uv run -m reflector.hatchet.run_workers
|
|
||||||
"""
|
|
||||||
|
|
||||||
import signal
|
|
||||||
import sys
|
|
||||||
|
|
||||||
from hatchet_sdk.rate_limit import RateLimitDuration
|
|
||||||
|
|
||||||
from reflector.hatchet.constants import LLM_RATE_LIMIT_KEY, LLM_RATE_LIMIT_PER_SECOND
|
|
||||||
from reflector.logger import logger
|
|
||||||
from reflector.settings import settings
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
"""Start Hatchet worker polling."""
|
|
||||||
if not settings.HATCHET_ENABLED:
|
|
||||||
logger.error("HATCHET_ENABLED is False, not starting workers")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
if not settings.HATCHET_CLIENT_TOKEN:
|
|
||||||
logger.error("HATCHET_CLIENT_TOKEN is not set")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"Starting Hatchet workers",
|
|
||||||
debug=settings.HATCHET_DEBUG,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Import here (not top-level) - workflow modules call HatchetClientManager.get_client()
|
|
||||||
# at module level because Hatchet SDK decorators (@workflow.task) bind at import time.
|
|
||||||
# Can't use lazy init: decorators need the client object when function is defined.
|
|
||||||
from reflector.hatchet.client import HatchetClientManager # noqa: PLC0415
|
|
||||||
from reflector.hatchet.workflows import ( # noqa: PLC0415
|
|
||||||
daily_multitrack_pipeline,
|
|
||||||
subject_workflow,
|
|
||||||
topic_chunk_workflow,
|
|
||||||
track_workflow,
|
|
||||||
)
|
|
||||||
|
|
||||||
hatchet = HatchetClientManager.get_client()
|
|
||||||
|
|
||||||
hatchet.rate_limits.put(
|
|
||||||
LLM_RATE_LIMIT_KEY, LLM_RATE_LIMIT_PER_SECOND, RateLimitDuration.SECOND
|
|
||||||
)
|
|
||||||
|
|
||||||
worker = hatchet.worker(
|
|
||||||
"reflector-pipeline-worker",
|
|
||||||
workflows=[
|
|
||||||
daily_multitrack_pipeline,
|
|
||||||
subject_workflow,
|
|
||||||
topic_chunk_workflow,
|
|
||||||
track_workflow,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
def shutdown_handler(signum: int, frame) -> None:
|
|
||||||
logger.info("Received shutdown signal, stopping workers...")
|
|
||||||
# Worker cleanup happens automatically on exit
|
|
||||||
sys.exit(0)
|
|
||||||
|
|
||||||
signal.signal(signal.SIGINT, shutdown_handler)
|
|
||||||
signal.signal(signal.SIGTERM, shutdown_handler)
|
|
||||||
|
|
||||||
logger.info("Starting Hatchet worker polling...")
|
|
||||||
worker.start()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
48
server/reflector/hatchet/run_workers_cpu.py
Normal file
48
server/reflector/hatchet/run_workers_cpu.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
"""
|
||||||
|
CPU-heavy worker pool for audio processing tasks.
|
||||||
|
Handles ONLY: mixdown_tracks
|
||||||
|
|
||||||
|
Configuration:
|
||||||
|
- slots=1: Only mixdown (already serialized globally with max_runs=1)
|
||||||
|
- Worker affinity: pool=cpu-heavy
|
||||||
|
"""
|
||||||
|
|
||||||
|
from reflector.hatchet.client import HatchetClientManager
|
||||||
|
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||||
|
daily_multitrack_pipeline,
|
||||||
|
)
|
||||||
|
from reflector.logger import logger
|
||||||
|
from reflector.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
if not settings.HATCHET_ENABLED:
|
||||||
|
logger.error("HATCHET_ENABLED is False, not starting CPU workers")
|
||||||
|
return
|
||||||
|
|
||||||
|
hatchet = HatchetClientManager.get_client()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Starting Hatchet CPU worker pool (mixdown only)",
|
||||||
|
worker_name="cpu-worker-pool",
|
||||||
|
slots=1,
|
||||||
|
labels={"pool": "cpu-heavy"},
|
||||||
|
)
|
||||||
|
|
||||||
|
cpu_worker = hatchet.worker(
|
||||||
|
"cpu-worker-pool",
|
||||||
|
slots=1, # Only 1 mixdown at a time (already serialized globally)
|
||||||
|
labels={
|
||||||
|
"pool": "cpu-heavy",
|
||||||
|
},
|
||||||
|
workflows=[daily_multitrack_pipeline],
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
cpu_worker.start()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("Received shutdown signal, stopping CPU workers...")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
56
server/reflector/hatchet/run_workers_llm.py
Normal file
56
server/reflector/hatchet/run_workers_llm.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
"""
|
||||||
|
LLM/I/O worker pool for all non-CPU tasks.
|
||||||
|
Handles: all tasks except mixdown_tracks (transcription, LLM inference, orchestration)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from reflector.hatchet.client import HatchetClientManager
|
||||||
|
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||||
|
daily_multitrack_pipeline,
|
||||||
|
)
|
||||||
|
from reflector.hatchet.workflows.subject_processing import subject_workflow
|
||||||
|
from reflector.hatchet.workflows.topic_chunk_processing import topic_chunk_workflow
|
||||||
|
from reflector.hatchet.workflows.track_processing import track_workflow
|
||||||
|
from reflector.logger import logger
|
||||||
|
from reflector.settings import settings
|
||||||
|
|
||||||
|
SLOTS = 10
|
||||||
|
WORKER_NAME = "llm-worker-pool"
|
||||||
|
POOL = "llm-io"
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
if not settings.HATCHET_ENABLED:
|
||||||
|
logger.error("HATCHET_ENABLED is False, not starting LLM workers")
|
||||||
|
return
|
||||||
|
|
||||||
|
hatchet = HatchetClientManager.get_client()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Starting Hatchet LLM worker pool (all tasks except mixdown)",
|
||||||
|
worker_name=WORKER_NAME,
|
||||||
|
slots=SLOTS,
|
||||||
|
labels={"pool": POOL},
|
||||||
|
)
|
||||||
|
|
||||||
|
llm_worker = hatchet.worker(
|
||||||
|
WORKER_NAME,
|
||||||
|
slots=SLOTS, # not all slots are probably used
|
||||||
|
labels={
|
||||||
|
"pool": POOL,
|
||||||
|
},
|
||||||
|
workflows=[
|
||||||
|
daily_multitrack_pipeline,
|
||||||
|
topic_chunk_workflow,
|
||||||
|
subject_workflow,
|
||||||
|
track_workflow,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
llm_worker.start()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("Received shutdown signal, stopping LLM workers...")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -23,7 +23,12 @@ from pathlib import Path
|
|||||||
from typing import Any, Callable, Coroutine, Protocol, TypeVar
|
from typing import Any, Callable, Coroutine, Protocol, TypeVar
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from hatchet_sdk import Context
|
from hatchet_sdk import (
|
||||||
|
ConcurrencyExpression,
|
||||||
|
ConcurrencyLimitStrategy,
|
||||||
|
Context,
|
||||||
|
)
|
||||||
|
from hatchet_sdk.labels import DesiredWorkerLabel
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from reflector.dailyco_api.client import DailyApiClient
|
from reflector.dailyco_api.client import DailyApiClient
|
||||||
@@ -467,10 +472,24 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
|
|||||||
parents=[process_tracks],
|
parents=[process_tracks],
|
||||||
execution_timeout=timedelta(seconds=TIMEOUT_AUDIO),
|
execution_timeout=timedelta(seconds=TIMEOUT_AUDIO),
|
||||||
retries=3,
|
retries=3,
|
||||||
|
desired_worker_labels={
|
||||||
|
"pool": DesiredWorkerLabel(
|
||||||
|
value="cpu-heavy",
|
||||||
|
required=True,
|
||||||
|
weight=100,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
concurrency=[
|
||||||
|
ConcurrencyExpression(
|
||||||
|
expression="'mixdown-global'",
|
||||||
|
max_runs=1, # serialize mixdown to prevent resource contention
|
||||||
|
limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, # Queue
|
||||||
|
)
|
||||||
|
],
|
||||||
)
|
)
|
||||||
@with_error_handling(TaskName.MIXDOWN_TRACKS)
|
@with_error_handling(TaskName.MIXDOWN_TRACKS)
|
||||||
async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
|
async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
|
||||||
"""Mix all padded tracks into single audio file using PyAV (same as Celery)."""
|
"""Mix all padded tracks into single audio file using PyAV or Modal backend."""
|
||||||
ctx.log("mixdown_tracks: mixing padded tracks into single audio file")
|
ctx.log("mixdown_tracks: mixing padded tracks into single audio file")
|
||||||
|
|
||||||
track_result = ctx.task_output(process_tracks)
|
track_result = ctx.task_output(process_tracks)
|
||||||
@@ -494,7 +513,7 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
|
|||||||
|
|
||||||
storage = _spawn_storage()
|
storage = _spawn_storage()
|
||||||
|
|
||||||
# Presign URLs on demand (avoids stale URLs on workflow replay)
|
# Presign URLs for padded tracks (same expiration for both backends)
|
||||||
padded_urls = []
|
padded_urls = []
|
||||||
for track_info in padded_tracks:
|
for track_info in padded_tracks:
|
||||||
if track_info.key:
|
if track_info.key:
|
||||||
@@ -515,33 +534,104 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
|
|||||||
logger.error("Mixdown failed - no decodable audio frames found")
|
logger.error("Mixdown failed - no decodable audio frames found")
|
||||||
raise ValueError("No decodable audio frames in any track")
|
raise ValueError("No decodable audio frames in any track")
|
||||||
|
|
||||||
output_path = tempfile.mktemp(suffix=".mp3")
|
output_key = f"{input.transcript_id}/audio.mp3"
|
||||||
duration_ms_callback_capture_container = [0.0]
|
|
||||||
|
|
||||||
async def capture_duration(d):
|
# Conditional: Modal or local backend
|
||||||
duration_ms_callback_capture_container[0] = d
|
if settings.MIXDOWN_BACKEND == "modal":
|
||||||
|
ctx.log("mixdown_tracks: using Modal backend")
|
||||||
|
|
||||||
writer = AudioFileWriterProcessor(path=output_path, on_duration=capture_duration)
|
# Presign PUT URL for output (Modal will upload directly)
|
||||||
|
output_url = await storage.get_file_url(
|
||||||
|
output_key,
|
||||||
|
operation="put_object",
|
||||||
|
expires_in=PRESIGNED_URL_EXPIRATION_SECONDS,
|
||||||
|
)
|
||||||
|
|
||||||
await mixdown_tracks_pyav(
|
from reflector.processors.audio_mixdown_modal import ( # noqa: PLC0415
|
||||||
valid_urls,
|
AudioMixdownModalProcessor,
|
||||||
writer,
|
)
|
||||||
target_sample_rate,
|
|
||||||
offsets_seconds=None,
|
|
||||||
logger=logger,
|
|
||||||
progress_callback=make_audio_progress_logger(ctx, TaskName.MIXDOWN_TRACKS),
|
|
||||||
expected_duration_sec=recording_duration if recording_duration > 0 else None,
|
|
||||||
)
|
|
||||||
await writer.flush()
|
|
||||||
|
|
||||||
file_size = Path(output_path).stat().st_size
|
try:
|
||||||
storage_path = f"{input.transcript_id}/audio.mp3"
|
processor = AudioMixdownModalProcessor()
|
||||||
|
result = await processor.mixdown(
|
||||||
|
track_urls=valid_urls,
|
||||||
|
output_url=output_url,
|
||||||
|
target_sample_rate=target_sample_rate,
|
||||||
|
expected_duration_sec=recording_duration
|
||||||
|
if recording_duration > 0
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
duration_ms = result.duration_ms
|
||||||
|
tracks_mixed = result.tracks_mixed
|
||||||
|
|
||||||
with open(output_path, "rb") as mixed_file:
|
ctx.log(
|
||||||
await storage.put_file(storage_path, mixed_file)
|
f"mixdown_tracks: Modal returned duration={duration_ms}ms, tracks={tracks_mixed}"
|
||||||
|
)
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
error_detail = e.response.text if hasattr(e.response, "text") else str(e)
|
||||||
|
logger.error(
|
||||||
|
"[Hatchet] Modal mixdown HTTP error",
|
||||||
|
transcript_id=input.transcript_id,
|
||||||
|
status_code=e.response.status_code if hasattr(e, "response") else None,
|
||||||
|
error=error_detail,
|
||||||
|
)
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Modal mixdown failed with HTTP {e.response.status_code}: {error_detail}"
|
||||||
|
)
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
logger.error(
|
||||||
|
"[Hatchet] Modal mixdown timeout",
|
||||||
|
transcript_id=input.transcript_id,
|
||||||
|
timeout=settings.MIXDOWN_TIMEOUT,
|
||||||
|
)
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Modal mixdown timeout after {settings.MIXDOWN_TIMEOUT}s"
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(
|
||||||
|
"[Hatchet] Modal mixdown validation error",
|
||||||
|
transcript_id=input.transcript_id,
|
||||||
|
error=str(e),
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
ctx.log("mixdown_tracks: using local backend")
|
||||||
|
|
||||||
Path(output_path).unlink(missing_ok=True)
|
# Existing local implementation
|
||||||
|
output_path = tempfile.mktemp(suffix=".mp3")
|
||||||
|
duration_ms_callback_capture_container = [0.0]
|
||||||
|
|
||||||
|
async def capture_duration(d):
|
||||||
|
duration_ms_callback_capture_container[0] = d
|
||||||
|
|
||||||
|
writer = AudioFileWriterProcessor(
|
||||||
|
path=output_path, on_duration=capture_duration
|
||||||
|
)
|
||||||
|
|
||||||
|
await mixdown_tracks_pyav(
|
||||||
|
valid_urls,
|
||||||
|
writer,
|
||||||
|
target_sample_rate,
|
||||||
|
offsets_seconds=None,
|
||||||
|
logger=logger,
|
||||||
|
progress_callback=make_audio_progress_logger(ctx, TaskName.MIXDOWN_TRACKS),
|
||||||
|
expected_duration_sec=recording_duration
|
||||||
|
if recording_duration > 0
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
await writer.flush()
|
||||||
|
|
||||||
|
file_size = Path(output_path).stat().st_size
|
||||||
|
with open(output_path, "rb") as mixed_file:
|
||||||
|
await storage.put_file(output_key, mixed_file)
|
||||||
|
|
||||||
|
Path(output_path).unlink(missing_ok=True)
|
||||||
|
duration_ms = duration_ms_callback_capture_container[0]
|
||||||
|
tracks_mixed = len(valid_urls)
|
||||||
|
|
||||||
|
ctx.log(f"mixdown_tracks: local mixdown uploaded {file_size} bytes")
|
||||||
|
|
||||||
|
# Update DB (same for both backends)
|
||||||
async with fresh_db_connection():
|
async with fresh_db_connection():
|
||||||
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
|
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
|
||||||
|
|
||||||
@@ -551,12 +641,12 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
|
|||||||
transcript, {"audio_location": "storage"}
|
transcript, {"audio_location": "storage"}
|
||||||
)
|
)
|
||||||
|
|
||||||
ctx.log(f"mixdown_tracks complete: uploaded {file_size} bytes to {storage_path}")
|
ctx.log(f"mixdown_tracks complete: uploaded to {output_key}")
|
||||||
|
|
||||||
return MixdownResult(
|
return MixdownResult(
|
||||||
audio_key=storage_path,
|
audio_key=output_key,
|
||||||
duration=duration_ms_callback_capture_container[0],
|
duration=duration_ms,
|
||||||
tracks_mixed=len(valid_urls),
|
tracks_mixed=tracks_mixed,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,11 @@ Spawned dynamically by detect_topics via aio_run_many() for parallel processing.
|
|||||||
|
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
|
||||||
from hatchet_sdk import ConcurrencyExpression, ConcurrencyLimitStrategy, Context
|
from hatchet_sdk import (
|
||||||
|
ConcurrencyExpression,
|
||||||
|
ConcurrencyLimitStrategy,
|
||||||
|
Context,
|
||||||
|
)
|
||||||
from hatchet_sdk.rate_limit import RateLimit
|
from hatchet_sdk.rate_limit import RateLimit
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@@ -34,11 +38,13 @@ hatchet = HatchetClientManager.get_client()
|
|||||||
topic_chunk_workflow = hatchet.workflow(
|
topic_chunk_workflow = hatchet.workflow(
|
||||||
name="TopicChunkProcessing",
|
name="TopicChunkProcessing",
|
||||||
input_validator=TopicChunkInput,
|
input_validator=TopicChunkInput,
|
||||||
concurrency=ConcurrencyExpression(
|
concurrency=[
|
||||||
expression="'global'", # constant string = global limit across all runs
|
ConcurrencyExpression(
|
||||||
max_runs=20,
|
expression="'global'", # constant string = global limit across all runs
|
||||||
limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,
|
max_runs=20,
|
||||||
),
|
limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,
|
||||||
|
)
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
89
server/reflector/processors/audio_mixdown_modal.py
Normal file
89
server/reflector/processors/audio_mixdown_modal.py
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
"""
|
||||||
|
Modal.com backend for audio mixdown.
|
||||||
|
|
||||||
|
Uses Modal's CPU containers to offload audio mixing from Hatchet workers.
|
||||||
|
Communicates via presigned S3 URLs for both input and output.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from reflector.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
class MixdownResponse(BaseModel):
|
||||||
|
"""Response from Modal mixdown endpoint."""
|
||||||
|
|
||||||
|
duration_ms: float
|
||||||
|
tracks_mixed: int
|
||||||
|
audio_uploaded: bool
|
||||||
|
|
||||||
|
|
||||||
|
class AudioMixdownModalProcessor:
|
||||||
|
"""Audio mixdown processor using Modal.com CPU backend.
|
||||||
|
|
||||||
|
Sends track URLs (presigned GET) and output URL (presigned PUT) to Modal.
|
||||||
|
Modal handles download, mixdown via PyAV, and upload.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, modal_api_key: str | None = None):
|
||||||
|
if not settings.MIXDOWN_URL:
|
||||||
|
raise ValueError("MIXDOWN_URL required to use AudioMixdownModalProcessor")
|
||||||
|
|
||||||
|
self.mixdown_url = settings.MIXDOWN_URL + "/v1"
|
||||||
|
self.timeout = settings.MIXDOWN_TIMEOUT
|
||||||
|
self.modal_api_key = modal_api_key or settings.MIXDOWN_MODAL_API_KEY
|
||||||
|
|
||||||
|
if not self.modal_api_key:
|
||||||
|
raise ValueError(
|
||||||
|
"MIXDOWN_MODAL_API_KEY required to use AudioMixdownModalProcessor"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mixdown(
|
||||||
|
self,
|
||||||
|
track_urls: list[str],
|
||||||
|
output_url: str,
|
||||||
|
target_sample_rate: int,
|
||||||
|
expected_duration_sec: float | None = None,
|
||||||
|
) -> MixdownResponse:
|
||||||
|
"""Mix multiple audio tracks via Modal backend.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
track_urls: List of presigned GET URLs for audio tracks (non-empty)
|
||||||
|
output_url: Presigned PUT URL for output MP3
|
||||||
|
target_sample_rate: Sample rate for output (Hz, must be positive)
|
||||||
|
expected_duration_sec: Optional fallback duration if container metadata unavailable
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MixdownResponse with duration_ms, tracks_mixed, audio_uploaded
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If track_urls is empty or target_sample_rate invalid
|
||||||
|
httpx.HTTPStatusError: On HTTP errors (404, 403, 500, etc.)
|
||||||
|
httpx.TimeoutException: On timeout
|
||||||
|
"""
|
||||||
|
# Validate inputs
|
||||||
|
if not track_urls:
|
||||||
|
raise ValueError("track_urls cannot be empty")
|
||||||
|
if target_sample_rate <= 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"target_sample_rate must be positive, got {target_sample_rate}"
|
||||||
|
)
|
||||||
|
if expected_duration_sec is not None and expected_duration_sec < 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"expected_duration_sec cannot be negative, got {expected_duration_sec}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||||
|
response = await client.post(
|
||||||
|
f"{self.mixdown_url}/audio/mixdown",
|
||||||
|
headers={"Authorization": f"Bearer {self.modal_api_key}"},
|
||||||
|
json={
|
||||||
|
"track_urls": track_urls,
|
||||||
|
"output_url": output_url,
|
||||||
|
"target_sample_rate": target_sample_rate,
|
||||||
|
"expected_duration_sec": expected_duration_sec,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return MixdownResponse(**response.json())
|
||||||
@@ -98,6 +98,17 @@ class Settings(BaseSettings):
|
|||||||
# Diarization: local pyannote.audio
|
# Diarization: local pyannote.audio
|
||||||
DIARIZATION_PYANNOTE_AUTH_TOKEN: str | None = None
|
DIARIZATION_PYANNOTE_AUTH_TOKEN: str | None = None
|
||||||
|
|
||||||
|
# Audio Mixdown
|
||||||
|
# backends:
|
||||||
|
# - local: in-process PyAV mixdown (runs in same process as Hatchet worker)
|
||||||
|
# - modal: HTTP API client to Modal.com CPU container
|
||||||
|
MIXDOWN_BACKEND: str = "local"
|
||||||
|
MIXDOWN_URL: str | None = None
|
||||||
|
MIXDOWN_TIMEOUT: int = 900 # 15 minutes
|
||||||
|
|
||||||
|
# Mixdown: modal backend
|
||||||
|
MIXDOWN_MODAL_API_KEY: str | None = None
|
||||||
|
|
||||||
# Sentry
|
# Sentry
|
||||||
SENTRY_DSN: str | None = None
|
SENTRY_DSN: str | None = None
|
||||||
|
|
||||||
|
|||||||
@@ -1,133 +0,0 @@
|
|||||||
"""
|
|
||||||
Transcripts chat API
|
|
||||||
====================
|
|
||||||
|
|
||||||
WebSocket endpoint for bidirectional chat with LLM about transcript content.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
|
||||||
from llama_index.core import Settings
|
|
||||||
from llama_index.core.base.llms.types import ChatMessage, MessageRole
|
|
||||||
|
|
||||||
from reflector.auth.auth_jwt import JWTAuth
|
|
||||||
from reflector.db.recordings import recordings_controller
|
|
||||||
from reflector.db.transcripts import transcripts_controller
|
|
||||||
from reflector.db.users import user_controller
|
|
||||||
from reflector.llm import LLM
|
|
||||||
from reflector.settings import settings
|
|
||||||
from reflector.utils.transcript_formats import topics_to_webvtt_named
|
|
||||||
|
|
||||||
router = APIRouter()
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_is_multitrack(transcript) -> bool:
|
|
||||||
"""Detect if transcript is from multitrack recording."""
|
|
||||||
if not transcript.recording_id:
|
|
||||||
return False
|
|
||||||
recording = await recordings_controller.get_by_id(transcript.recording_id)
|
|
||||||
return recording is not None and recording.is_multitrack
|
|
||||||
|
|
||||||
|
|
||||||
@router.websocket("/transcripts/{transcript_id}/chat")
|
|
||||||
async def transcript_chat_websocket(
|
|
||||||
transcript_id: str,
|
|
||||||
websocket: WebSocket,
|
|
||||||
):
|
|
||||||
"""WebSocket endpoint for chatting with LLM about transcript content."""
|
|
||||||
# 1. Auth check (optional) - extract token from WebSocket subprotocol header
|
|
||||||
# Browser can't send Authorization header for WS; use subprotocol: ["bearer", token]
|
|
||||||
raw_subprotocol = websocket.headers.get("sec-websocket-protocol") or ""
|
|
||||||
parts = [p.strip() for p in raw_subprotocol.split(",") if p.strip()]
|
|
||||||
token: Optional[str] = None
|
|
||||||
negotiated_subprotocol: Optional[str] = None
|
|
||||||
if len(parts) >= 2 and parts[0].lower() == "bearer":
|
|
||||||
negotiated_subprotocol = "bearer"
|
|
||||||
token = parts[1]
|
|
||||||
|
|
||||||
user_id: Optional[str] = None
|
|
||||||
if token:
|
|
||||||
try:
|
|
||||||
payload = JWTAuth().verify_token(token)
|
|
||||||
authentik_uid = payload.get("sub")
|
|
||||||
|
|
||||||
if authentik_uid:
|
|
||||||
user = await user_controller.get_by_authentik_uid(authentik_uid)
|
|
||||||
if user:
|
|
||||||
user_id = user.id
|
|
||||||
except Exception:
|
|
||||||
# Auth failed - continue as anonymous
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Get transcript (respects user_id for private transcripts)
|
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
|
||||||
transcript_id, user_id=user_id
|
|
||||||
)
|
|
||||||
if not transcript:
|
|
||||||
await websocket.close(code=1008) # Policy violation (not found/unauthorized)
|
|
||||||
return
|
|
||||||
|
|
||||||
# 2. Accept connection (with negotiated subprotocol if present)
|
|
||||||
await websocket.accept(subprotocol=negotiated_subprotocol)
|
|
||||||
|
|
||||||
# 3. Generate WebVTT context
|
|
||||||
is_multitrack = await _get_is_multitrack(transcript)
|
|
||||||
webvtt = topics_to_webvtt_named(
|
|
||||||
transcript.topics, transcript.participants, is_multitrack
|
|
||||||
)
|
|
||||||
|
|
||||||
# Truncate if needed (15k char limit for POC)
|
|
||||||
webvtt_truncated = webvtt[:15000] if len(webvtt) > 15000 else webvtt
|
|
||||||
|
|
||||||
# 4. Configure LLM
|
|
||||||
llm = LLM(settings=settings, temperature=0.7)
|
|
||||||
|
|
||||||
# 5. System message with transcript context
|
|
||||||
system_msg = f"""You are analyzing this meeting transcript (WebVTT):
|
|
||||||
|
|
||||||
{webvtt_truncated}
|
|
||||||
|
|
||||||
Answer questions about content, speakers, timeline. Include timestamps when relevant."""
|
|
||||||
|
|
||||||
# 6. Conversation history
|
|
||||||
conversation_history = [ChatMessage(role=MessageRole.SYSTEM, content=system_msg)]
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 7. Message loop
|
|
||||||
while True:
|
|
||||||
data = await websocket.receive_json()
|
|
||||||
|
|
||||||
if data.get("type") == "get_context":
|
|
||||||
# Return WebVTT context (for debugging/testing)
|
|
||||||
await websocket.send_json({"type": "context", "webvtt": webvtt})
|
|
||||||
continue
|
|
||||||
|
|
||||||
if data.get("type") != "message":
|
|
||||||
# Echo unknown types for backward compatibility
|
|
||||||
await websocket.send_json({"type": "echo", "data": data})
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Add user message to history
|
|
||||||
user_msg = ChatMessage(role=MessageRole.USER, content=data.get("text", ""))
|
|
||||||
conversation_history.append(user_msg)
|
|
||||||
|
|
||||||
# Stream LLM response
|
|
||||||
assistant_msg = ""
|
|
||||||
chat_stream = await Settings.llm.astream_chat(conversation_history)
|
|
||||||
async for chunk in chat_stream:
|
|
||||||
token = chunk.delta or ""
|
|
||||||
if token:
|
|
||||||
await websocket.send_json({"type": "token", "text": token})
|
|
||||||
assistant_msg += token
|
|
||||||
|
|
||||||
# Save assistant response to history
|
|
||||||
conversation_history.append(
|
|
||||||
ChatMessage(role=MessageRole.ASSISTANT, content=assistant_msg)
|
|
||||||
)
|
|
||||||
await websocket.send_json({"type": "done"})
|
|
||||||
|
|
||||||
except WebSocketDisconnect:
|
|
||||||
pass
|
|
||||||
except Exception as e:
|
|
||||||
await websocket.send_json({"type": "error", "message": str(e)})
|
|
||||||
@@ -7,8 +7,10 @@ elif [ "${ENTRYPOINT}" = "worker" ]; then
|
|||||||
uv run celery -A reflector.worker.app worker --loglevel=info
|
uv run celery -A reflector.worker.app worker --loglevel=info
|
||||||
elif [ "${ENTRYPOINT}" = "beat" ]; then
|
elif [ "${ENTRYPOINT}" = "beat" ]; then
|
||||||
uv run celery -A reflector.worker.app beat --loglevel=info
|
uv run celery -A reflector.worker.app beat --loglevel=info
|
||||||
elif [ "${ENTRYPOINT}" = "hatchet-worker" ]; then
|
elif [ "${ENTRYPOINT}" = "hatchet-worker-cpu" ]; then
|
||||||
uv run python -m reflector.hatchet.run_workers
|
uv run python -m reflector.hatchet.run_workers_cpu
|
||||||
|
elif [ "${ENTRYPOINT}" = "hatchet-worker-llm" ]; then
|
||||||
|
uv run python -m reflector.hatchet.run_workers_llm
|
||||||
else
|
else
|
||||||
echo "Unknown command"
|
echo "Unknown command"
|
||||||
fi
|
fi
|
||||||
|
|||||||
@@ -1,234 +0,0 @@
|
|||||||
"""Tests for transcript chat WebSocket endpoint."""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from httpx_ws import aconnect_ws
|
|
||||||
from uvicorn import Config, Server
|
|
||||||
|
|
||||||
from reflector.db.transcripts import (
|
|
||||||
SourceKind,
|
|
||||||
TranscriptParticipant,
|
|
||||||
TranscriptTopic,
|
|
||||||
transcripts_controller,
|
|
||||||
)
|
|
||||||
from reflector.processors.types import Word
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def chat_appserver(tmpdir, setup_database):
|
|
||||||
"""Start a real HTTP server for WebSocket testing."""
|
|
||||||
from reflector.app import app
|
|
||||||
from reflector.db import get_database
|
|
||||||
from reflector.settings import settings
|
|
||||||
|
|
||||||
DATA_DIR = settings.DATA_DIR
|
|
||||||
settings.DATA_DIR = Path(tmpdir)
|
|
||||||
|
|
||||||
# Start server in separate thread with its own event loop
|
|
||||||
host = "127.0.0.1"
|
|
||||||
port = 1256 # Different port from rtc tests
|
|
||||||
server_started = threading.Event()
|
|
||||||
server_exception = None
|
|
||||||
server_instance = None
|
|
||||||
|
|
||||||
def run_server():
|
|
||||||
nonlocal server_exception, server_instance
|
|
||||||
try:
|
|
||||||
# Create new event loop for this thread
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
config = Config(app=app, host=host, port=port, loop=loop)
|
|
||||||
server_instance = Server(config)
|
|
||||||
|
|
||||||
async def start_server():
|
|
||||||
# Initialize database connection in this event loop
|
|
||||||
database = get_database()
|
|
||||||
await database.connect()
|
|
||||||
try:
|
|
||||||
await server_instance.serve()
|
|
||||||
finally:
|
|
||||||
await database.disconnect()
|
|
||||||
|
|
||||||
# Signal that server is starting
|
|
||||||
server_started.set()
|
|
||||||
loop.run_until_complete(start_server())
|
|
||||||
except Exception as e:
|
|
||||||
server_exception = e
|
|
||||||
server_started.set()
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
server_thread = threading.Thread(target=run_server, daemon=True)
|
|
||||||
server_thread.start()
|
|
||||||
|
|
||||||
# Wait for server to start
|
|
||||||
server_started.wait(timeout=30)
|
|
||||||
if server_exception:
|
|
||||||
raise server_exception
|
|
||||||
|
|
||||||
# Wait for server to be fully ready
|
|
||||||
time.sleep(1)
|
|
||||||
|
|
||||||
yield server_instance, host, port
|
|
||||||
|
|
||||||
# Stop server
|
|
||||||
if server_instance:
|
|
||||||
server_instance.should_exit = True
|
|
||||||
server_thread.join(timeout=30)
|
|
||||||
|
|
||||||
settings.DATA_DIR = DATA_DIR
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def test_transcript(setup_database):
|
|
||||||
"""Create a test transcript for WebSocket tests."""
|
|
||||||
transcript = await transcripts_controller.add(
|
|
||||||
name="Test Transcript for Chat", source_kind=SourceKind.FILE
|
|
||||||
)
|
|
||||||
return transcript
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def test_transcript_with_content(setup_database):
|
|
||||||
"""Create a test transcript with actual content for WebVTT generation."""
|
|
||||||
transcript = await transcripts_controller.add(
|
|
||||||
name="Test Transcript with Content", source_kind=SourceKind.FILE
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add participants
|
|
||||||
await transcripts_controller.update(
|
|
||||||
transcript,
|
|
||||||
{
|
|
||||||
"participants": [
|
|
||||||
TranscriptParticipant(id="1", speaker=0, name="Alice").model_dump(),
|
|
||||||
TranscriptParticipant(id="2", speaker=1, name="Bob").model_dump(),
|
|
||||||
]
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add topic with words
|
|
||||||
await transcripts_controller.upsert_topic(
|
|
||||||
transcript,
|
|
||||||
TranscriptTopic(
|
|
||||||
title="Introduction",
|
|
||||||
summary="Opening remarks",
|
|
||||||
timestamp=0.0,
|
|
||||||
words=[
|
|
||||||
Word(text="Hello ", start=0.0, end=1.0, speaker=0),
|
|
||||||
Word(text="everyone.", start=1.0, end=2.0, speaker=0),
|
|
||||||
Word(text="Hi ", start=2.0, end=3.0, speaker=1),
|
|
||||||
Word(text="there!", start=3.0, end=4.0, speaker=1),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
return transcript
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_chat_websocket_connection_success(test_transcript, chat_appserver):
|
|
||||||
"""Test successful WebSocket connection to chat endpoint."""
|
|
||||||
server, host, port = chat_appserver
|
|
||||||
base_url = f"ws://{host}:{port}/v1"
|
|
||||||
|
|
||||||
async with aconnect_ws(f"{base_url}/transcripts/{test_transcript.id}/chat") as ws:
|
|
||||||
# Send unknown message type to test echo behavior
|
|
||||||
await ws.send_json({"type": "test", "text": "Hello"})
|
|
||||||
|
|
||||||
# Should receive echo for unknown types
|
|
||||||
response = await ws.receive_json()
|
|
||||||
assert response["type"] == "echo"
|
|
||||||
assert response["data"]["type"] == "test"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_chat_websocket_nonexistent_transcript(chat_appserver):
|
|
||||||
"""Test WebSocket connection fails for nonexistent transcript."""
|
|
||||||
server, host, port = chat_appserver
|
|
||||||
base_url = f"ws://{host}:{port}/v1"
|
|
||||||
|
|
||||||
# Connection should fail or disconnect immediately for non-existent transcript
|
|
||||||
# Different behavior from successful connection
|
|
||||||
with pytest.raises(Exception): # Will raise on connection or first operation
|
|
||||||
async with aconnect_ws(f"{base_url}/transcripts/nonexistent-id/chat") as ws:
|
|
||||||
await ws.send_json({"type": "message", "text": "Hello"})
|
|
||||||
await ws.receive_json()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_chat_websocket_multiple_messages(test_transcript, chat_appserver):
|
|
||||||
"""Test sending multiple messages through WebSocket."""
|
|
||||||
server, host, port = chat_appserver
|
|
||||||
base_url = f"ws://{host}:{port}/v1"
|
|
||||||
|
|
||||||
async with aconnect_ws(f"{base_url}/transcripts/{test_transcript.id}/chat") as ws:
|
|
||||||
# Send multiple unknown message types (testing echo behavior)
|
|
||||||
messages = ["First message", "Second message", "Third message"]
|
|
||||||
|
|
||||||
for i, msg in enumerate(messages):
|
|
||||||
await ws.send_json({"type": f"test{i}", "text": msg})
|
|
||||||
response = await ws.receive_json()
|
|
||||||
assert response["type"] == "echo"
|
|
||||||
assert response["data"]["type"] == f"test{i}"
|
|
||||||
assert response["data"]["text"] == msg
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_chat_websocket_disconnect_graceful(test_transcript, chat_appserver):
|
|
||||||
"""Test WebSocket disconnects gracefully."""
|
|
||||||
server, host, port = chat_appserver
|
|
||||||
base_url = f"ws://{host}:{port}/v1"
|
|
||||||
|
|
||||||
async with aconnect_ws(f"{base_url}/transcripts/{test_transcript.id}/chat") as ws:
|
|
||||||
await ws.send_json({"type": "message", "text": "Hello"})
|
|
||||||
await ws.receive_json()
|
|
||||||
# Close handled by context manager - should not raise
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_chat_websocket_context_generation(
|
|
||||||
test_transcript_with_content, chat_appserver
|
|
||||||
):
|
|
||||||
"""Test WebVTT context is generated on connection."""
|
|
||||||
server, host, port = chat_appserver
|
|
||||||
base_url = f"ws://{host}:{port}/v1"
|
|
||||||
|
|
||||||
async with aconnect_ws(
|
|
||||||
f"{base_url}/transcripts/{test_transcript_with_content.id}/chat"
|
|
||||||
) as ws:
|
|
||||||
# Request context
|
|
||||||
await ws.send_json({"type": "get_context"})
|
|
||||||
|
|
||||||
# Receive context response
|
|
||||||
response = await ws.receive_json()
|
|
||||||
assert response["type"] == "context"
|
|
||||||
assert "webvtt" in response
|
|
||||||
|
|
||||||
# Verify WebVTT format
|
|
||||||
webvtt = response["webvtt"]
|
|
||||||
assert webvtt.startswith("WEBVTT")
|
|
||||||
assert "<v Alice>" in webvtt
|
|
||||||
assert "<v Bob>" in webvtt
|
|
||||||
assert "Hello everyone." in webvtt
|
|
||||||
assert "Hi there!" in webvtt
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_chat_websocket_unknown_message_type(test_transcript, chat_appserver):
|
|
||||||
"""Test unknown message types are echoed back."""
|
|
||||||
server, host, port = chat_appserver
|
|
||||||
base_url = f"ws://{host}:{port}/v1"
|
|
||||||
|
|
||||||
async with aconnect_ws(f"{base_url}/transcripts/{test_transcript.id}/chat") as ws:
|
|
||||||
# Send unknown message type
|
|
||||||
await ws.send_json({"type": "unknown", "data": "test"})
|
|
||||||
|
|
||||||
# Should receive echo
|
|
||||||
response = await ws.receive_json()
|
|
||||||
assert response["type"] == "echo"
|
|
||||||
assert response["data"]["type"] == "unknown"
|
|
||||||
@@ -1,103 +0,0 @@
|
|||||||
"use client";
|
|
||||||
|
|
||||||
import { useState } from "react";
|
|
||||||
import { Box, Dialog, Input, IconButton } from "@chakra-ui/react";
|
|
||||||
import { MessageCircle } from "lucide-react";
|
|
||||||
import Markdown from "react-markdown";
|
|
||||||
import "../../styles/markdown.css";
|
|
||||||
import type { Message } from "./useTranscriptChat";
|
|
||||||
|
|
||||||
interface TranscriptChatModalProps {
|
|
||||||
open: boolean;
|
|
||||||
onClose: () => void;
|
|
||||||
messages: Message[];
|
|
||||||
sendMessage: (text: string) => void;
|
|
||||||
isStreaming: boolean;
|
|
||||||
currentStreamingText: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function TranscriptChatModal({
|
|
||||||
open,
|
|
||||||
onClose,
|
|
||||||
messages,
|
|
||||||
sendMessage,
|
|
||||||
isStreaming,
|
|
||||||
currentStreamingText,
|
|
||||||
}: TranscriptChatModalProps) {
|
|
||||||
const [input, setInput] = useState("");
|
|
||||||
|
|
||||||
const handleSend = () => {
|
|
||||||
if (!input.trim()) return;
|
|
||||||
sendMessage(input);
|
|
||||||
setInput("");
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Dialog.Root open={open} onOpenChange={(e) => !e.open && onClose()}>
|
|
||||||
<Dialog.Backdrop />
|
|
||||||
<Dialog.Positioner>
|
|
||||||
<Dialog.Content maxW="500px" h="600px">
|
|
||||||
<Dialog.Header>Transcript Chat</Dialog.Header>
|
|
||||||
|
|
||||||
<Dialog.Body overflowY="auto">
|
|
||||||
{messages.map((msg) => (
|
|
||||||
<Box
|
|
||||||
key={msg.id}
|
|
||||||
p={3}
|
|
||||||
mb={2}
|
|
||||||
bg={msg.role === "user" ? "blue.50" : "gray.50"}
|
|
||||||
borderRadius="md"
|
|
||||||
>
|
|
||||||
{msg.role === "user" ? (
|
|
||||||
msg.text
|
|
||||||
) : (
|
|
||||||
<div className="markdown">
|
|
||||||
<Markdown>{msg.text}</Markdown>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
</Box>
|
|
||||||
))}
|
|
||||||
|
|
||||||
{isStreaming && (
|
|
||||||
<Box p={3} bg="gray.50" borderRadius="md">
|
|
||||||
<div className="markdown">
|
|
||||||
<Markdown>{currentStreamingText}</Markdown>
|
|
||||||
</div>
|
|
||||||
<Box as="span" className="animate-pulse">
|
|
||||||
▊
|
|
||||||
</Box>
|
|
||||||
</Box>
|
|
||||||
)}
|
|
||||||
</Dialog.Body>
|
|
||||||
|
|
||||||
<Dialog.Footer>
|
|
||||||
<Input
|
|
||||||
value={input}
|
|
||||||
onChange={(e) => setInput(e.target.value)}
|
|
||||||
onKeyDown={(e) => e.key === "Enter" && handleSend()}
|
|
||||||
placeholder="Ask about transcript..."
|
|
||||||
disabled={isStreaming}
|
|
||||||
/>
|
|
||||||
</Dialog.Footer>
|
|
||||||
</Dialog.Content>
|
|
||||||
</Dialog.Positioner>
|
|
||||||
</Dialog.Root>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
export function TranscriptChatButton({ onClick }: { onClick: () => void }) {
|
|
||||||
return (
|
|
||||||
<IconButton
|
|
||||||
position="fixed"
|
|
||||||
bottom="24px"
|
|
||||||
right="24px"
|
|
||||||
onClick={onClick}
|
|
||||||
size="lg"
|
|
||||||
colorPalette="blue"
|
|
||||||
borderRadius="full"
|
|
||||||
aria-label="Open chat"
|
|
||||||
>
|
|
||||||
<MessageCircle />
|
|
||||||
</IconButton>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -18,15 +18,9 @@ import {
|
|||||||
Skeleton,
|
Skeleton,
|
||||||
Text,
|
Text,
|
||||||
Spinner,
|
Spinner,
|
||||||
useDisclosure,
|
|
||||||
} from "@chakra-ui/react";
|
} from "@chakra-ui/react";
|
||||||
import { useTranscriptGet } from "../../../lib/apiHooks";
|
import { useTranscriptGet } from "../../../lib/apiHooks";
|
||||||
import { TranscriptStatus } from "../../../lib/transcript";
|
import { TranscriptStatus } from "../../../lib/transcript";
|
||||||
import {
|
|
||||||
TranscriptChatModal,
|
|
||||||
TranscriptChatButton,
|
|
||||||
} from "../TranscriptChatModal";
|
|
||||||
import { useTranscriptChat } from "../useTranscriptChat";
|
|
||||||
|
|
||||||
type TranscriptDetails = {
|
type TranscriptDetails = {
|
||||||
params: Promise<{
|
params: Promise<{
|
||||||
@@ -59,9 +53,6 @@ export default function TranscriptDetails(details: TranscriptDetails) {
|
|||||||
const [finalSummaryElement, setFinalSummaryElement] =
|
const [finalSummaryElement, setFinalSummaryElement] =
|
||||||
useState<HTMLDivElement | null>(null);
|
useState<HTMLDivElement | null>(null);
|
||||||
|
|
||||||
const { open, onOpen, onClose } = useDisclosure();
|
|
||||||
const chat = useTranscriptChat(transcriptId);
|
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!waiting || !transcript.data) return;
|
if (!waiting || !transcript.data) return;
|
||||||
|
|
||||||
@@ -128,15 +119,6 @@ export default function TranscriptDetails(details: TranscriptDetails) {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<TranscriptChatModal
|
|
||||||
open={open}
|
|
||||||
onClose={onClose}
|
|
||||||
messages={chat.messages}
|
|
||||||
sendMessage={chat.sendMessage}
|
|
||||||
isStreaming={chat.isStreaming}
|
|
||||||
currentStreamingText={chat.currentStreamingText}
|
|
||||||
/>
|
|
||||||
<TranscriptChatButton onClick={onOpen} />
|
|
||||||
<Grid
|
<Grid
|
||||||
templateColumns="1fr"
|
templateColumns="1fr"
|
||||||
templateRows="auto minmax(0, 1fr)"
|
templateRows="auto minmax(0, 1fr)"
|
||||||
|
|||||||
@@ -1,130 +0,0 @@
|
|||||||
"use client";
|
|
||||||
|
|
||||||
import { useEffect, useState, useRef } from "react";
|
|
||||||
import { getSession } from "next-auth/react";
|
|
||||||
import { WEBSOCKET_URL } from "../../lib/apiClient";
|
|
||||||
import { assertExtendedToken } from "../../lib/types";
|
|
||||||
|
|
||||||
export type Message = {
|
|
||||||
id: string;
|
|
||||||
role: "user" | "assistant";
|
|
||||||
text: string;
|
|
||||||
timestamp: Date;
|
|
||||||
};
|
|
||||||
|
|
||||||
export type UseTranscriptChat = {
|
|
||||||
messages: Message[];
|
|
||||||
sendMessage: (text: string) => void;
|
|
||||||
isStreaming: boolean;
|
|
||||||
currentStreamingText: string;
|
|
||||||
};
|
|
||||||
|
|
||||||
export const useTranscriptChat = (transcriptId: string): UseTranscriptChat => {
|
|
||||||
const [messages, setMessages] = useState<Message[]>([]);
|
|
||||||
const [isStreaming, setIsStreaming] = useState(false);
|
|
||||||
const [currentStreamingText, setCurrentStreamingText] = useState("");
|
|
||||||
const wsRef = useRef<WebSocket | null>(null);
|
|
||||||
const streamingTextRef = useRef<string>("");
|
|
||||||
const isMountedRef = useRef<boolean>(true);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
isMountedRef.current = true;
|
|
||||||
|
|
||||||
const connectWebSocket = async () => {
|
|
||||||
const url = `${WEBSOCKET_URL}/v1/transcripts/${transcriptId}/chat`;
|
|
||||||
|
|
||||||
// Get auth token for WebSocket subprotocol
|
|
||||||
let protocols: string[] | undefined;
|
|
||||||
try {
|
|
||||||
const session = await getSession();
|
|
||||||
if (session) {
|
|
||||||
const token = assertExtendedToken(session).accessToken;
|
|
||||||
// Pass token via subprotocol: ["bearer", token]
|
|
||||||
protocols = ["bearer", token];
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
console.warn("Failed to get auth token for WebSocket:", error);
|
|
||||||
}
|
|
||||||
|
|
||||||
const ws = new WebSocket(url, protocols);
|
|
||||||
wsRef.current = ws;
|
|
||||||
|
|
||||||
ws.onopen = () => {
|
|
||||||
console.log("Chat WebSocket connected");
|
|
||||||
};
|
|
||||||
|
|
||||||
ws.onmessage = (event) => {
|
|
||||||
if (!isMountedRef.current) return;
|
|
||||||
|
|
||||||
const msg = JSON.parse(event.data);
|
|
||||||
|
|
||||||
switch (msg.type) {
|
|
||||||
case "token":
|
|
||||||
setIsStreaming(true);
|
|
||||||
streamingTextRef.current += msg.text;
|
|
||||||
setCurrentStreamingText(streamingTextRef.current);
|
|
||||||
break;
|
|
||||||
|
|
||||||
case "done":
|
|
||||||
// CRITICAL: Save the text BEFORE resetting the ref
|
|
||||||
// The setMessages callback may execute later, after ref is reset
|
|
||||||
const finalText = streamingTextRef.current;
|
|
||||||
|
|
||||||
setMessages((prev) => [
|
|
||||||
...prev,
|
|
||||||
{
|
|
||||||
id: Date.now().toString(),
|
|
||||||
role: "assistant",
|
|
||||||
text: finalText,
|
|
||||||
timestamp: new Date(),
|
|
||||||
},
|
|
||||||
]);
|
|
||||||
streamingTextRef.current = "";
|
|
||||||
setCurrentStreamingText("");
|
|
||||||
setIsStreaming(false);
|
|
||||||
break;
|
|
||||||
|
|
||||||
case "error":
|
|
||||||
console.error("Chat error:", msg.message);
|
|
||||||
setIsStreaming(false);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
ws.onerror = (error) => {
|
|
||||||
console.error("WebSocket error:", error);
|
|
||||||
};
|
|
||||||
|
|
||||||
ws.onclose = () => {
|
|
||||||
console.log("Chat WebSocket closed");
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
connectWebSocket();
|
|
||||||
|
|
||||||
return () => {
|
|
||||||
isMountedRef.current = false;
|
|
||||||
if (wsRef.current) {
|
|
||||||
wsRef.current.close();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}, [transcriptId]);
|
|
||||||
|
|
||||||
const sendMessage = (text: string) => {
|
|
||||||
if (!wsRef.current) return;
|
|
||||||
|
|
||||||
setMessages((prev) => [
|
|
||||||
...prev,
|
|
||||||
{
|
|
||||||
id: Date.now().toString(),
|
|
||||||
role: "user",
|
|
||||||
text,
|
|
||||||
timestamp: new Date(),
|
|
||||||
},
|
|
||||||
]);
|
|
||||||
|
|
||||||
wsRef.current.send(JSON.stringify({ type: "message", text }));
|
|
||||||
};
|
|
||||||
|
|
||||||
return { messages, sendMessage, isStreaming, currentStreamingText };
|
|
||||||
};
|
|
||||||
Reference in New Issue
Block a user