Compare commits

1 commit

Author: Igor Loskutov
Commit: e1b790c5a8
Message: Add Modal backend for audio mixdown
Date: 2026-01-21 17:06:17 -05:00
16 changed files with 769 additions and 276 deletions

View File

@@ -1,19 +1,5 @@
# Changelog
-## [0.29.0](https://github.com/Monadical-SAS/reflector/compare/v0.28.1...v0.29.0) (2026-01-21)
-### Features
-* set hatchet as default for multitracks ([#822](https://github.com/Monadical-SAS/reflector/issues/822)) ([c723752](https://github.com/Monadical-SAS/reflector/commit/c723752b7e15aa48a41ad22856f147a5517d3f46))
-## [0.28.1](https://github.com/Monadical-SAS/reflector/compare/v0.28.0...v0.28.1) (2026-01-21)
-### Bug Fixes
-* ics non-sync bugfix ([#823](https://github.com/Monadical-SAS/reflector/issues/823)) ([23d2bc2](https://github.com/Monadical-SAS/reflector/commit/23d2bc283d4d02187b250d2055103e0374ee93d6))
## [0.28.0](https://github.com/Monadical-SAS/reflector/compare/v0.27.0...v0.28.0) (2026-01-20)

View File

@@ -131,6 +131,15 @@ if [ -z "$DIARIZER_URL" ]; then
fi
echo " -> $DIARIZER_URL"
echo ""
echo "Deploying mixdown (CPU audio processing)..."
MIXDOWN_URL=$(modal deploy reflector_mixdown.py 2>&1 | grep -o 'https://[^ ]*web.modal.run' | head -1)
if [ -z "$MIXDOWN_URL" ]; then
echo "Error: Failed to deploy mixdown. Check Modal dashboard for details."
exit 1
fi
echo " -> $MIXDOWN_URL"
# --- Output Configuration ---
echo ""
echo "=========================================="
@@ -147,4 +156,8 @@ echo ""
echo "DIARIZATION_BACKEND=modal" echo "DIARIZATION_BACKEND=modal"
echo "DIARIZATION_URL=$DIARIZER_URL" echo "DIARIZATION_URL=$DIARIZER_URL"
echo "DIARIZATION_MODAL_API_KEY=$API_KEY" echo "DIARIZATION_MODAL_API_KEY=$API_KEY"
echo ""
echo "MIXDOWN_BACKEND=modal"
echo "MIXDOWN_URL=$MIXDOWN_URL"
echo "MIXDOWN_MODAL_API_KEY=$API_KEY"
echo "# --- End Modal Configuration ---" echo "# --- End Modal Configuration ---"

View File

@@ -0,0 +1,379 @@
"""
Reflector GPU backend - audio mixdown
======================================
CPU-intensive audio mixdown service for combining multiple audio tracks.
Uses PyAV filter graph (amix) for high-quality audio mixing.
"""
import os
import tempfile
import time
from fractions import Fraction
import modal
MIXDOWN_TIMEOUT = 900 # 15 minutes
SCALEDOWN_WINDOW = 60 # 1 minute idle before shutdown
app = modal.App("reflector-mixdown")
# CPU-based image (no GPU needed for audio processing)
image = (
modal.Image.debian_slim(python_version="3.12")
.apt_install("ffmpeg") # Required by PyAV
.pip_install(
"av==13.1.0", # PyAV for audio processing
"requests==2.32.3", # HTTP for presigned URL downloads/uploads
"fastapi==0.115.12", # API framework
)
)
@app.function(
cpu=4.0, # 4 CPU cores for audio processing
timeout=MIXDOWN_TIMEOUT,
scaledown_window=SCALEDOWN_WINDOW,
secrets=[modal.Secret.from_name("reflector-gpu")],
image=image,
)
@modal.concurrent(max_inputs=10)
@modal.asgi_app()
def web():
import logging
import secrets
import shutil
import av
import requests
from av.audio.resampler import AudioResampler
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel
# Setup logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
app = FastAPI()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# Validate API key exists at startup
API_KEY = os.environ.get("REFLECTOR_GPU_APIKEY")
if not API_KEY:
raise RuntimeError("REFLECTOR_GPU_APIKEY not configured in Modal secrets")
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
# Use constant-time comparison to prevent timing attacks
if secrets.compare_digest(apikey, API_KEY):
return
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
headers={"WWW-Authenticate": "Bearer"},
)
class MixdownRequest(BaseModel):
track_urls: list[str]
output_url: str
target_sample_rate: int = 48000
expected_duration_sec: float | None = None
class MixdownResponse(BaseModel):
duration_ms: float
tracks_mixed: int
audio_uploaded: bool
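# Example MixdownResponse payload (illustrative values only):
#   {"duration_ms": 30000.0, "tracks_mixed": 3, "audio_uploaded": true}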
def download_track(url: str, temp_dir: str, index: int) -> str:
"""Download track from presigned URL to temp file using streaming."""
logger.info(f"Downloading track {index + 1}")
response = requests.get(url, stream=True, timeout=300)
if response.status_code == 404:
raise HTTPException(status_code=404, detail=f"Track {index} not found")
if response.status_code == 403:
raise HTTPException(
status_code=403, detail=f"Track {index} presigned URL expired"
)
response.raise_for_status()
temp_path = os.path.join(temp_dir, f"track_{index}.webm")
total_bytes = 0
with open(temp_path, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
total_bytes += len(chunk)
logger.info(f"Track {index + 1} downloaded: {total_bytes} bytes")
return temp_path
def mixdown_tracks_modal(
track_paths: list[str],
output_path: str,
target_sample_rate: int,
expected_duration_sec: float | None,
logger,
) -> float:
"""Mix multiple audio tracks using PyAV filter graph.
Args:
track_paths: List of local file paths to audio tracks
output_path: Local path for output MP3 file
target_sample_rate: Sample rate for output (Hz)
expected_duration_sec: Optional fallback duration if container metadata unavailable
logger: Logger instance for progress tracking
Returns:
Duration in milliseconds
"""
logger.info(f"Starting mixdown of {len(track_paths)} tracks")
# Build PyAV filter graph: N abuffer -> amix -> aformat -> sink
graph = av.filter.Graph()
inputs = []
for idx in range(len(track_paths)):
args = (
f"time_base=1/{target_sample_rate}:"
f"sample_rate={target_sample_rate}:"
f"sample_fmt=s32:"
f"channel_layout=stereo"
)
in_ctx = graph.add("abuffer", args=args, name=f"in{idx}")
inputs.append(in_ctx)
mixer = graph.add("amix", args=f"inputs={len(inputs)}:normalize=0", name="mix")
fmt = graph.add(
"aformat",
args=f"sample_fmts=s32:channel_layouts=stereo:sample_rates={target_sample_rate}",
name="fmt",
)
sink = graph.add("abuffersink", name="out")
# Connect inputs to mixer (no delays for Modal implementation)
for idx, in_ctx in enumerate(inputs):
in_ctx.link_to(mixer, 0, idx)
mixer.link_to(fmt)
fmt.link_to(sink)
graph.configure()
# Open all containers
containers = []
try:
for i, path in enumerate(track_paths):
try:
c = av.open(path)
containers.append(c)
except Exception as e:
logger.warning(
f"Failed to open container {i}: {e}",
)
if not containers:
raise ValueError("Could not open any track containers")
# Calculate total duration for progress reporting
max_duration_sec = 0.0
for c in containers:
if c.duration is not None:
dur_sec = c.duration / av.time_base
max_duration_sec = max(max_duration_sec, dur_sec)
if max_duration_sec == 0.0 and expected_duration_sec:
max_duration_sec = expected_duration_sec
# Setup output container
out_container = av.open(output_path, "w", format="mp3")
out_stream = out_container.add_stream("libmp3lame", rate=target_sample_rate)
decoders = [c.decode(audio=0) for c in containers]
active = [True] * len(decoders)
resamplers = [
AudioResampler(format="s32", layout="stereo", rate=target_sample_rate)
for _ in decoders
]
current_max_time = 0.0
last_log_time = time.monotonic()
start_time = time.monotonic()
total_duration = 0
while any(active):
for i, (dec, is_active) in enumerate(zip(decoders, active)):
if not is_active:
continue
try:
frame = next(dec)
except StopIteration:
active[i] = False
inputs[i].push(None) # Signal end of stream
continue
if frame.sample_rate != target_sample_rate:
continue
# Progress logging (every 5 seconds)
if frame.time is not None:
current_max_time = max(current_max_time, frame.time)
now = time.monotonic()
if now - last_log_time >= 5.0:
elapsed = now - start_time
if max_duration_sec > 0:
progress_pct = min(
100.0, (current_max_time / max_duration_sec) * 100
)
logger.info(
f"Mixdown progress: {progress_pct:.1f}% @ {current_max_time:.1f}s (elapsed: {elapsed:.1f}s)"
)
else:
logger.info(
f"Mixdown progress: @ {current_max_time:.1f}s (elapsed: {elapsed:.1f}s)"
)
last_log_time = now
out_frames = resamplers[i].resample(frame) or []
for rf in out_frames:
rf.sample_rate = target_sample_rate
rf.time_base = Fraction(1, target_sample_rate)
inputs[i].push(rf)
# Pull mixed frames from sink and encode
while True:
try:
mixed = sink.pull()
except Exception:
break
mixed.sample_rate = target_sample_rate
mixed.time_base = Fraction(1, target_sample_rate)
# Encode and mux
for packet in out_stream.encode(mixed):
out_container.mux(packet)
total_duration += packet.duration
# Flush remaining frames from filter graph
while True:
try:
mixed = sink.pull()
except Exception:
break
mixed.sample_rate = target_sample_rate
mixed.time_base = Fraction(1, target_sample_rate)
for packet in out_stream.encode(mixed):
out_container.mux(packet)
total_duration += packet.duration
# Flush encoder
for packet in out_stream.encode():
out_container.mux(packet)
total_duration += packet.duration
# Calculate duration in milliseconds
if total_duration > 0:
# Use the same calculation as AudioFileWriterProcessor
duration_ms = round(
float(total_duration * out_stream.time_base * 1000), 2
)
else:
duration_ms = 0.0
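# Note on the duration math above: total_duration accumulates packet.duration in
# out_stream.time_base units; e.g. with a 1/48000 time_base, 1_440_000 accumulated
# units -> 1_440_000 * (1/48000) * 1000 = 30000.0 ms.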
out_container.close()
logger.info(f"Mixdown complete: duration={duration_ms}ms")
finally:
# Cleanup all containers
for c in containers:
if c is not None:
try:
c.close()
except Exception:
pass
return duration_ms
@app.post("/v1/audio/mixdown", dependencies=[Depends(apikey_auth)])
def mixdown(request: MixdownRequest) -> MixdownResponse:
"""Mix multiple audio tracks into a single MP3 file.
Tracks are downloaded from presigned S3 URLs, mixed using PyAV,
and uploaded to a presigned S3 PUT URL.
"""
if not request.track_urls:
raise HTTPException(status_code=400, detail="No track URLs provided")
logger.info(f"Mixdown request: {len(request.track_urls)} tracks")
temp_dir = tempfile.mkdtemp()
temp_files = []
output_mp3_path = None
try:
# Download all tracks
for i, url in enumerate(request.track_urls):
temp_path = download_track(url, temp_dir, i)
temp_files.append(temp_path)
# Mix tracks
output_mp3_path = os.path.join(temp_dir, "mixed.mp3")
duration_ms = mixdown_tracks_modal(
temp_files,
output_mp3_path,
request.target_sample_rate,
request.expected_duration_sec,
logger,
)
# Upload result to S3
logger.info("Uploading result to S3")
file_size = os.path.getsize(output_mp3_path)
with open(output_mp3_path, "rb") as f:
upload_response = requests.put(
request.output_url, data=f, timeout=300
)
if upload_response.status_code == 403:
raise HTTPException(
status_code=403, detail="Output presigned URL expired"
)
upload_response.raise_for_status()
logger.info(f"Upload complete: {file_size} bytes")
return MixdownResponse(
duration_ms=duration_ms,
tracks_mixed=len(request.track_urls),
audio_uploaded=True,
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Mixdown failed: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Mixdown failed: {str(e)}")
finally:
# Cleanup temp files
for temp_path in temp_files:
try:
os.unlink(temp_path)
except Exception as e:
logger.warning(f"Failed to cleanup temp file {temp_path}: {e}")
if output_mp3_path and os.path.exists(output_mp3_path):
try:
os.unlink(output_mp3_path)
except Exception as e:
logger.warning(f"Failed to cleanup output file {output_mp3_path}: {e}")
try:
shutil.rmtree(temp_dir)
except Exception as e:
logger.warning(f"Failed to cleanup temp directory {temp_dir}: {e}")
return app
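Note: a minimal client sketch for smoke-testing the endpoint above; the deployment URL, API key, and presigned URLs are placeholders, not values from this commit.

# Hypothetical smoke test for the deployed Modal web endpoint.
import httpx

MIXDOWN_URL = "https://<workspace>--reflector-mixdown-web.modal.run"  # placeholder deployment URL
API_KEY = "<REFLECTOR_GPU_APIKEY>"  # placeholder API key

response = httpx.post(
    f"{MIXDOWN_URL}/v1/audio/mixdown",
    headers={"Authorization": f"Bearer {API_KEY}"},
    json={
        "track_urls": ["https://example-bucket.s3.amazonaws.com/track_0.webm?..."],  # presigned GET URLs
        "output_url": "https://example-bucket.s3.amazonaws.com/audio.mp3?...",  # presigned PUT URL
        "target_sample_rate": 48000,
        "expected_duration_sec": None,
    },
    timeout=900,
)
response.raise_for_status()
print(response.json())  # e.g. {"duration_ms": ..., "tracks_mixed": 1, "audio_uploaded": true}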

View File

@@ -1,44 +0,0 @@
"""replace_use_hatchet_with_use_celery
Revision ID: 80beb1ea3269
Revises: bd3a729bb379
Create Date: 2026-01-20 16:26:25.555869
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "80beb1ea3269"
down_revision: Union[str, None] = "bd3a729bb379"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
with op.batch_alter_table("room", schema=None) as batch_op:
batch_op.add_column(
sa.Column(
"use_celery",
sa.Boolean(),
server_default=sa.text("false"),
nullable=False,
)
)
batch_op.drop_column("use_hatchet")
def downgrade() -> None:
with op.batch_alter_table("room", schema=None) as batch_op:
batch_op.add_column(
sa.Column(
"use_hatchet",
sa.Boolean(),
server_default=sa.text("false"),
nullable=False,
)
)
batch_op.drop_column("use_celery")

View File

@@ -58,7 +58,7 @@ rooms = sqlalchemy.Table(
nullable=False,
),
sqlalchemy.Column(
-"use_celery",
+"use_hatchet",
sqlalchemy.Boolean,
nullable=False,
server_default=false(),
@@ -97,7 +97,7 @@ class Room(BaseModel):
ics_last_sync: datetime | None = None
ics_last_etag: str | None = None
platform: Platform = Field(default_factory=lambda: settings.DEFAULT_VIDEO_PLATFORM)
-use_celery: bool = False
+use_hatchet: bool = False
skip_consent: bool = False

View File

@@ -12,9 +12,14 @@ from reflector.hatchet.workflows.daily_multitrack_pipeline import (
daily_multitrack_pipeline,
)
from reflector.logger import logger
from reflector.settings import settings
def main():
if not settings.HATCHET_ENABLED:
logger.error("HATCHET_ENABLED is False, not starting CPU workers")
return
hatchet = HatchetClientManager.get_client()
logger.info(

View File

@@ -11,6 +11,7 @@ from reflector.hatchet.workflows.subject_processing import subject_workflow
from reflector.hatchet.workflows.topic_chunk_processing import topic_chunk_workflow
from reflector.hatchet.workflows.track_processing import track_workflow
from reflector.logger import logger
from reflector.settings import settings
SLOTS = 10
WORKER_NAME = "llm-worker-pool"
@@ -18,6 +19,10 @@ POOL = "llm-io"
def main():
if not settings.HATCHET_ENABLED:
logger.error("HATCHET_ENABLED is False, not starting LLM workers")
return
hatchet = HatchetClientManager.get_client()
logger.info(

View File

@@ -489,7 +489,7 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
)
@with_error_handling(TaskName.MIXDOWN_TRACKS)
async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
-"""Mix all padded tracks into single audio file using PyAV (same as Celery)."""
+"""Mix all padded tracks into single audio file using PyAV or Modal backend."""
ctx.log("mixdown_tracks: mixing padded tracks into single audio file")
track_result = ctx.task_output(process_tracks)
@@ -513,7 +513,7 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
storage = _spawn_storage()
-# Presign URLs on demand (avoids stale URLs on workflow replay)
+# Presign URLs for padded tracks (same expiration for both backends)
padded_urls = []
for track_info in padded_tracks:
if track_info.key:
@@ -534,33 +534,104 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
logger.error("Mixdown failed - no decodable audio frames found") logger.error("Mixdown failed - no decodable audio frames found")
raise ValueError("No decodable audio frames in any track") raise ValueError("No decodable audio frames in any track")
output_path = tempfile.mktemp(suffix=".mp3") output_key = f"{input.transcript_id}/audio.mp3"
duration_ms_callback_capture_container = [0.0]
async def capture_duration(d): # Conditional: Modal or local backend
duration_ms_callback_capture_container[0] = d if settings.MIXDOWN_BACKEND == "modal":
ctx.log("mixdown_tracks: using Modal backend")
writer = AudioFileWriterProcessor(path=output_path, on_duration=capture_duration) # Presign PUT URL for output (Modal will upload directly)
output_url = await storage.get_file_url(
output_key,
operation="put_object",
expires_in=PRESIGNED_URL_EXPIRATION_SECONDS,
)
await mixdown_tracks_pyav( from reflector.processors.audio_mixdown_modal import ( # noqa: PLC0415
valid_urls, AudioMixdownModalProcessor,
writer, )
target_sample_rate,
offsets_seconds=None,
logger=logger,
progress_callback=make_audio_progress_logger(ctx, TaskName.MIXDOWN_TRACKS),
expected_duration_sec=recording_duration if recording_duration > 0 else None,
)
await writer.flush()
file_size = Path(output_path).stat().st_size try:
storage_path = f"{input.transcript_id}/audio.mp3" processor = AudioMixdownModalProcessor()
result = await processor.mixdown(
track_urls=valid_urls,
output_url=output_url,
target_sample_rate=target_sample_rate,
expected_duration_sec=recording_duration
if recording_duration > 0
else None,
)
duration_ms = result.duration_ms
tracks_mixed = result.tracks_mixed
with open(output_path, "rb") as mixed_file: ctx.log(
await storage.put_file(storage_path, mixed_file) f"mixdown_tracks: Modal returned duration={duration_ms}ms, tracks={tracks_mixed}"
)
except httpx.HTTPStatusError as e:
error_detail = e.response.text if hasattr(e.response, "text") else str(e)
logger.error(
"[Hatchet] Modal mixdown HTTP error",
transcript_id=input.transcript_id,
status_code=e.response.status_code if hasattr(e, "response") else None,
error=error_detail,
)
raise RuntimeError(
f"Modal mixdown failed with HTTP {e.response.status_code}: {error_detail}"
)
except httpx.TimeoutException:
logger.error(
"[Hatchet] Modal mixdown timeout",
transcript_id=input.transcript_id,
timeout=settings.MIXDOWN_TIMEOUT,
)
raise RuntimeError(
f"Modal mixdown timeout after {settings.MIXDOWN_TIMEOUT}s"
)
except ValueError as e:
logger.error(
"[Hatchet] Modal mixdown validation error",
transcript_id=input.transcript_id,
error=str(e),
)
raise
else:
ctx.log("mixdown_tracks: using local backend")
Path(output_path).unlink(missing_ok=True) # Existing local implementation
output_path = tempfile.mktemp(suffix=".mp3")
duration_ms_callback_capture_container = [0.0]
async def capture_duration(d):
duration_ms_callback_capture_container[0] = d
writer = AudioFileWriterProcessor(
path=output_path, on_duration=capture_duration
)
await mixdown_tracks_pyav(
valid_urls,
writer,
target_sample_rate,
offsets_seconds=None,
logger=logger,
progress_callback=make_audio_progress_logger(ctx, TaskName.MIXDOWN_TRACKS),
expected_duration_sec=recording_duration
if recording_duration > 0
else None,
)
await writer.flush()
file_size = Path(output_path).stat().st_size
with open(output_path, "rb") as mixed_file:
await storage.put_file(output_key, mixed_file)
Path(output_path).unlink(missing_ok=True)
duration_ms = duration_ms_callback_capture_container[0]
tracks_mixed = len(valid_urls)
ctx.log(f"mixdown_tracks: local mixdown uploaded {file_size} bytes")
# Update DB (same for both backends)
async with fresh_db_connection(): async with fresh_db_connection():
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415 from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
@@ -570,12 +641,12 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
transcript, {"audio_location": "storage"} transcript, {"audio_location": "storage"}
) )
ctx.log(f"mixdown_tracks complete: uploaded {file_size} bytes to {storage_path}") ctx.log(f"mixdown_tracks complete: uploaded to {output_key}")
return MixdownResult( return MixdownResult(
audio_key=storage_path, audio_key=output_key,
duration=duration_ms_callback_capture_container[0], duration=duration_ms,
tracks_mixed=len(valid_urls), tracks_mixed=tracks_mixed,
) )

View File

@@ -0,0 +1,89 @@
"""
Modal.com backend for audio mixdown.
Uses Modal's CPU containers to offload audio mixing from Hatchet workers.
Communicates via presigned S3 URLs for both input and output.
"""
import httpx
from pydantic import BaseModel
from reflector.settings import settings
class MixdownResponse(BaseModel):
"""Response from Modal mixdown endpoint."""
duration_ms: float
tracks_mixed: int
audio_uploaded: bool
class AudioMixdownModalProcessor:
"""Audio mixdown processor using Modal.com CPU backend.
Sends track URLs (presigned GET) and output URL (presigned PUT) to Modal.
Modal handles download, mixdown via PyAV, and upload.
"""
def __init__(self, modal_api_key: str | None = None):
if not settings.MIXDOWN_URL:
raise ValueError("MIXDOWN_URL required to use AudioMixdownModalProcessor")
self.mixdown_url = settings.MIXDOWN_URL + "/v1"
self.timeout = settings.MIXDOWN_TIMEOUT
self.modal_api_key = modal_api_key or settings.MIXDOWN_MODAL_API_KEY
if not self.modal_api_key:
raise ValueError(
"MIXDOWN_MODAL_API_KEY required to use AudioMixdownModalProcessor"
)
async def mixdown(
self,
track_urls: list[str],
output_url: str,
target_sample_rate: int,
expected_duration_sec: float | None = None,
) -> MixdownResponse:
"""Mix multiple audio tracks via Modal backend.
Args:
track_urls: List of presigned GET URLs for audio tracks (non-empty)
output_url: Presigned PUT URL for output MP3
target_sample_rate: Sample rate for output (Hz, must be positive)
expected_duration_sec: Optional fallback duration if container metadata unavailable
Returns:
MixdownResponse with duration_ms, tracks_mixed, audio_uploaded
Raises:
ValueError: If track_urls is empty or target_sample_rate invalid
httpx.HTTPStatusError: On HTTP errors (404, 403, 500, etc.)
httpx.TimeoutException: On timeout
"""
# Validate inputs
if not track_urls:
raise ValueError("track_urls cannot be empty")
if target_sample_rate <= 0:
raise ValueError(
f"target_sample_rate must be positive, got {target_sample_rate}"
)
if expected_duration_sec is not None and expected_duration_sec < 0:
raise ValueError(
f"expected_duration_sec cannot be negative, got {expected_duration_sec}"
)
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(
f"{self.mixdown_url}/audio/mixdown",
headers={"Authorization": f"Bearer {self.modal_api_key}"},
json={
"track_urls": track_urls,
"output_url": output_url,
"target_sample_rate": target_sample_rate,
"expected_duration_sec": expected_duration_sec,
},
)
response.raise_for_status()
return MixdownResponse(**response.json())
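Note: a hedged usage sketch of the processor above; it assumes MIXDOWN_URL and MIXDOWN_MODAL_API_KEY are configured, and the presigned URLs shown are placeholders.

# Hypothetical call site for AudioMixdownModalProcessor.
import asyncio

from reflector.processors.audio_mixdown_modal import AudioMixdownModalProcessor


async def example():
    processor = AudioMixdownModalProcessor()
    result = await processor.mixdown(
        track_urls=["https://example-bucket.s3.amazonaws.com/padded_0.webm?..."],  # presigned GET (placeholder)
        output_url="https://example-bucket.s3.amazonaws.com/audio.mp3?...",  # presigned PUT (placeholder)
        target_sample_rate=48000,
        expected_duration_sec=None,
    )
    print(result.duration_ms, result.tracks_mixed, result.audio_uploaded)


asyncio.run(example())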

View File

@@ -319,6 +319,21 @@ class ICSSyncService:
calendar = self.fetch_service.parse_ics(ics_content)
content_hash = hashlib.md5(ics_content.encode()).hexdigest()
if room.ics_last_etag == content_hash:
logger.info("No changes in ICS for room", room_id=room.id)
room_url = f"{settings.UI_BASE_URL}/{room.name}"
events, total_events = self.fetch_service.extract_room_events(
calendar, room.name, room_url
)
return {
"status": SyncStatus.UNCHANGED,
"hash": content_hash,
"events_found": len(events),
"total_events": total_events,
"events_created": 0,
"events_updated": 0,
"events_deleted": 0,
}
# Extract matching events
room_url = f"{settings.UI_BASE_URL}/{room.name}"
@@ -356,44 +371,6 @@ class ICSSyncService:
time_since_sync = datetime.now(timezone.utc) - room.ics_last_sync
return time_since_sync.total_seconds() >= room.ics_fetch_interval
def _event_data_changed(self, existing: CalendarEvent, new_data: EventData) -> bool:
"""Check if event data has changed by comparing relevant fields.
IMPORTANT: When adding fields to CalendarEvent/EventData, update this method
and the _COMPARED_FIELDS set below for runtime validation.
"""
# Fields that come from ICS and should trigger updates when changed
_COMPARED_FIELDS = {
"title",
"description",
"start_time",
"end_time",
"location",
"attendees",
"ics_raw_data",
}
# Runtime exhaustiveness check: ensure we're comparing all EventData fields
event_data_fields = set(EventData.__annotations__.keys()) - {"ics_uid"}
if event_data_fields != _COMPARED_FIELDS:
missing = event_data_fields - _COMPARED_FIELDS
extra = _COMPARED_FIELDS - event_data_fields
raise RuntimeError(
f"_event_data_changed() field mismatch: "
f"missing={missing}, extra={extra}. "
f"Update the comparison logic when adding/removing fields."
)
return (
existing.title != new_data["title"]
or existing.description != new_data["description"]
or existing.start_time != new_data["start_time"]
or existing.end_time != new_data["end_time"]
or existing.location != new_data["location"]
or existing.attendees != new_data["attendees"]
or existing.ics_raw_data != new_data["ics_raw_data"]
)
async def _sync_events_to_database(
self, room_id: str, events: list[EventData]
) -> SyncStats:
@@ -409,14 +386,11 @@ class ICSSyncService:
)
if existing:
-# Only count as updated if data actually changed
-if self._event_data_changed(existing, event_data):
-updated += 1
-await calendar_events_controller.upsert(calendar_event)
+updated += 1
else:
created += 1
-await calendar_events_controller.upsert(calendar_event)
+await calendar_events_controller.upsert(calendar_event)
current_ics_uids.append(event_data["ics_uid"])
# Soft delete events that are no longer in calendar

View File

@@ -23,6 +23,7 @@ from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
from reflector.pipelines.main_multitrack_pipeline import (
task_pipeline_multitrack_process,
)
from reflector.settings import settings
from reflector.utils.string import NonEmptyString
@@ -101,8 +102,8 @@ async def validate_transcript_for_processing(
if transcript.locked:
return ValidationLocked(detail="Recording is locked")
-# Check if recording is ready for processing
-if transcript.status == "idle" and not transcript.workflow_run_id:
+# hatchet is idempotent anyways + if it wasn't dispatched successfully
+if transcript.status == "idle" and not settings.HATCHET_ENABLED:
return ValidationNotReady(detail="Recording is not ready for processing")
# Check Celery tasks
@@ -115,8 +116,7 @@ async def validate_transcript_for_processing(
):
return ValidationAlreadyScheduled(detail="already running")
-# Check Hatchet workflow status if workflow_run_id exists
-if transcript.workflow_run_id:
+if settings.HATCHET_ENABLED and transcript.workflow_run_id:
try:
status = await HatchetClientManager.get_workflow_run_status(
transcript.workflow_run_id
@@ -181,16 +181,19 @@ async def dispatch_transcript_processing(
Returns AsyncResult for Celery tasks, None for Hatchet workflows.
"""
if isinstance(config, MultitrackProcessingConfig):
-use_celery = False
+# Check if room has use_hatchet=True (overrides env vars)
+room_forces_hatchet = False
if config.room_id:
room = await rooms_controller.get_by_id(config.room_id)
-use_celery = room.use_celery if room else False
-use_hatchet = not use_celery
+room_forces_hatchet = room.use_hatchet if room else False
+# Start durable workflow if enabled (Hatchet)
+# and if room has use_hatchet=True
+use_hatchet = settings.HATCHET_ENABLED and room_forces_hatchet
-if use_celery:
+if room_forces_hatchet:
logger.info(
-"Room uses legacy Celery processing",
+"Room forces Hatchet workflow",
room_id=config.room_id,
transcript_id=config.transcript_id,
)

View File

@@ -98,6 +98,17 @@ class Settings(BaseSettings):
# Diarization: local pyannote.audio
DIARIZATION_PYANNOTE_AUTH_TOKEN: str | None = None
# Audio Mixdown
# backends:
# - local: in-process PyAV mixdown (runs in same process as Hatchet worker)
# - modal: HTTP API client to Modal.com CPU container
MIXDOWN_BACKEND: str = "local"
MIXDOWN_URL: str | None = None
MIXDOWN_TIMEOUT: int = 900 # 15 minutes
# Mixdown: modal backend
MIXDOWN_MODAL_API_KEY: str | None = None
# Sentry
SENTRY_DSN: str | None = None
@@ -158,10 +169,19 @@ class Settings(BaseSettings):
ZULIP_API_KEY: str | None = None
ZULIP_BOT_EMAIL: str | None = None
-# Hatchet workflow orchestration (always enabled for multitrack processing)
+# Durable workflow orchestration
+# Provider: "hatchet" (or "none" to disable)
+DURABLE_WORKFLOW_PROVIDER: str = "none"
+# Hatchet workflow orchestration
HATCHET_CLIENT_TOKEN: str | None = None
HATCHET_CLIENT_TLS_STRATEGY: str = "none" # none, tls, mtls
HATCHET_DEBUG: bool = False
+@property
+def HATCHET_ENABLED(self) -> bool:
+"""True if Hatchet is the active provider."""
+return self.DURABLE_WORKFLOW_PROVIDER == "hatchet"
settings = Settings()
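Note: an illustrative configuration check for the new settings; the environment values below are placeholders and would normally come from the deploy script output above.

# Hypothetical environment, e.g. in .env (placeholder values):
#   DURABLE_WORKFLOW_PROVIDER=hatchet
#   MIXDOWN_BACKEND=modal
#   MIXDOWN_URL=https://<workspace>--reflector-mixdown-web.modal.run
#   MIXDOWN_MODAL_API_KEY=<REFLECTOR_GPU_APIKEY>
from reflector.settings import settings

assert settings.HATCHET_ENABLED is True  # derived from DURABLE_WORKFLOW_PROVIDER == "hatchet"
assert settings.MIXDOWN_BACKEND == "modal"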

View File

@@ -287,12 +287,11 @@ async def _process_multitrack_recording_inner(
room_id=room.id,
)
-use_celery = room and room.use_celery
-use_hatchet = not use_celery
+use_hatchet = settings.HATCHET_ENABLED and room and room.use_hatchet
-if use_celery:
+if room and room.use_hatchet and not settings.HATCHET_ENABLED:
logger.info(
-"Room uses legacy Celery processing",
+"Room forces Hatchet workflow",
room_id=room.id,
transcript_id=transcript.id,
)
@@ -811,6 +810,7 @@ async def reprocess_failed_daily_recordings():
)
continue
# Fetch room to check use_hatchet flag
room = None
if meeting.room_id:
room = await rooms_controller.get_by_id(meeting.room_id)
@@ -834,10 +834,10 @@ async def reprocess_failed_daily_recordings():
)
continue
-use_celery = room and room.use_celery
-use_hatchet = not use_celery
+use_hatchet = settings.HATCHET_ENABLED and room and room.use_hatchet
if use_hatchet:
+# Hatchet requires a transcript for workflow_run_id tracking
if not transcript:
logger.warning(
"No transcript for Hatchet reprocessing, skipping",

View File

@@ -2,9 +2,10 @@
Tests for Hatchet workflow dispatch and routing logic.
These tests verify:
-1. Hatchet workflow validation and replay logic
-2. Force flag to cancel and restart workflows
-3. Validation prevents concurrent workflows
+1. Routing to Hatchet when HATCHET_ENABLED=True
+2. Replay logic for failed workflows
+3. Force flag to cancel and restart
+4. Validation prevents concurrent workflows
"""
from unittest.mock import AsyncMock, patch
@@ -33,22 +34,25 @@ async def test_hatchet_validation_blocks_running_workflow():
workflow_run_id="running-workflow-123", workflow_run_id="running-workflow-123",
) )
-with patch(
-"reflector.services.transcript_process.HatchetClientManager"
-) as mock_hatchet:
-mock_hatchet.get_workflow_run_status = AsyncMock(
-return_value=V1TaskStatus.RUNNING
-)
-with patch(
-"reflector.services.transcript_process.task_is_scheduled_or_active"
-) as mock_celery_check:
-mock_celery_check.return_value = False
-result = await validate_transcript_for_processing(mock_transcript)
-assert isinstance(result, ValidationAlreadyScheduled)
-assert "running" in result.detail.lower()
+with patch("reflector.services.transcript_process.settings") as mock_settings:
+mock_settings.HATCHET_ENABLED = True
+with patch(
+"reflector.services.transcript_process.HatchetClientManager"
+) as mock_hatchet:
+mock_hatchet.get_workflow_run_status = AsyncMock(
+return_value=V1TaskStatus.RUNNING
+)
+with patch(
+"reflector.services.transcript_process.task_is_scheduled_or_active"
+) as mock_celery_check:
+mock_celery_check.return_value = False
+result = await validate_transcript_for_processing(mock_transcript)
+assert isinstance(result, ValidationAlreadyScheduled)
+assert "running" in result.detail.lower()
@pytest.mark.usefixtures("setup_database")
@@ -68,21 +72,24 @@ async def test_hatchet_validation_blocks_queued_workflow():
workflow_run_id="queued-workflow-123", workflow_run_id="queued-workflow-123",
) )
-with patch(
-"reflector.services.transcript_process.HatchetClientManager"
-) as mock_hatchet:
-mock_hatchet.get_workflow_run_status = AsyncMock(
-return_value=V1TaskStatus.QUEUED
-)
-with patch(
-"reflector.services.transcript_process.task_is_scheduled_or_active"
-) as mock_celery_check:
-mock_celery_check.return_value = False
-result = await validate_transcript_for_processing(mock_transcript)
-assert isinstance(result, ValidationAlreadyScheduled)
+with patch("reflector.services.transcript_process.settings") as mock_settings:
+mock_settings.HATCHET_ENABLED = True
+with patch(
+"reflector.services.transcript_process.HatchetClientManager"
+) as mock_hatchet:
+mock_hatchet.get_workflow_run_status = AsyncMock(
+return_value=V1TaskStatus.QUEUED
+)
+with patch(
+"reflector.services.transcript_process.task_is_scheduled_or_active"
+) as mock_celery_check:
+mock_celery_check.return_value = False
+result = await validate_transcript_for_processing(mock_transcript)
+assert isinstance(result, ValidationAlreadyScheduled)
@pytest.mark.usefixtures("setup_database")
@@ -103,22 +110,25 @@ async def test_hatchet_validation_allows_failed_workflow():
recording_id="test-recording-id",
)
-with patch(
-"reflector.services.transcript_process.HatchetClientManager"
-) as mock_hatchet:
-mock_hatchet.get_workflow_run_status = AsyncMock(
-return_value=V1TaskStatus.FAILED
-)
-with patch(
-"reflector.services.transcript_process.task_is_scheduled_or_active"
-) as mock_celery_check:
-mock_celery_check.return_value = False
-result = await validate_transcript_for_processing(mock_transcript)
-assert isinstance(result, ValidationOk)
-assert result.transcript_id == "test-transcript-id"
+with patch("reflector.services.transcript_process.settings") as mock_settings:
+mock_settings.HATCHET_ENABLED = True
+with patch(
+"reflector.services.transcript_process.HatchetClientManager"
+) as mock_hatchet:
+mock_hatchet.get_workflow_run_status = AsyncMock(
+return_value=V1TaskStatus.FAILED
+)
+with patch(
+"reflector.services.transcript_process.task_is_scheduled_or_active"
+) as mock_celery_check:
+mock_celery_check.return_value = False
+result = await validate_transcript_for_processing(mock_transcript)
+assert isinstance(result, ValidationOk)
+assert result.transcript_id == "test-transcript-id"
@pytest.mark.usefixtures("setup_database")
@@ -139,21 +149,24 @@ async def test_hatchet_validation_allows_completed_workflow():
recording_id="test-recording-id",
)
-with patch(
-"reflector.services.transcript_process.HatchetClientManager"
-) as mock_hatchet:
-mock_hatchet.get_workflow_run_status = AsyncMock(
-return_value=V1TaskStatus.COMPLETED
-)
-with patch(
-"reflector.services.transcript_process.task_is_scheduled_or_active"
-) as mock_celery_check:
-mock_celery_check.return_value = False
-result = await validate_transcript_for_processing(mock_transcript)
-assert isinstance(result, ValidationOk)
+with patch("reflector.services.transcript_process.settings") as mock_settings:
+mock_settings.HATCHET_ENABLED = True
+with patch(
+"reflector.services.transcript_process.HatchetClientManager"
+) as mock_hatchet:
+mock_hatchet.get_workflow_run_status = AsyncMock(
+return_value=V1TaskStatus.COMPLETED
+)
+with patch(
+"reflector.services.transcript_process.task_is_scheduled_or_active"
+) as mock_celery_check:
+mock_celery_check.return_value = False
+result = await validate_transcript_for_processing(mock_transcript)
+assert isinstance(result, ValidationOk)
@pytest.mark.usefixtures("setup_database")
@@ -174,23 +187,26 @@ async def test_hatchet_validation_allows_when_status_check_fails():
recording_id="test-recording-id",
)
-with patch(
-"reflector.services.transcript_process.HatchetClientManager"
-) as mock_hatchet:
-# Status check fails (workflow might be deleted)
-mock_hatchet.get_workflow_run_status = AsyncMock(
-side_effect=ApiException("Workflow not found")
-)
-with patch(
-"reflector.services.transcript_process.task_is_scheduled_or_active"
-) as mock_celery_check:
-mock_celery_check.return_value = False
-result = await validate_transcript_for_processing(mock_transcript)
-# Should allow processing when we can't get status
-assert isinstance(result, ValidationOk)
+with patch("reflector.services.transcript_process.settings") as mock_settings:
+mock_settings.HATCHET_ENABLED = True
+with patch(
+"reflector.services.transcript_process.HatchetClientManager"
+) as mock_hatchet:
+# Status check fails (workflow might be deleted)
+mock_hatchet.get_workflow_run_status = AsyncMock(
+side_effect=ApiException("Workflow not found")
+)
+with patch(
+"reflector.services.transcript_process.task_is_scheduled_or_active"
+) as mock_celery_check:
+mock_celery_check.return_value = False
+result = await validate_transcript_for_processing(mock_transcript)
+# Should allow processing when we can't get status
+assert isinstance(result, ValidationOk)
@pytest.mark.usefixtures("setup_database")
@@ -211,11 +227,47 @@ async def test_hatchet_validation_skipped_when_no_workflow_id():
recording_id="test-recording-id",
)
-with patch(
-"reflector.services.transcript_process.HatchetClientManager"
-) as mock_hatchet:
-# Should not be called
-mock_hatchet.get_workflow_run_status = AsyncMock()
+with patch("reflector.services.transcript_process.settings") as mock_settings:
+mock_settings.HATCHET_ENABLED = True
+with patch(
+"reflector.services.transcript_process.HatchetClientManager"
+) as mock_hatchet:
+# Should not be called
+mock_hatchet.get_workflow_run_status = AsyncMock()
+with patch(
+"reflector.services.transcript_process.task_is_scheduled_or_active"
+) as mock_celery_check:
+mock_celery_check.return_value = False
+result = await validate_transcript_for_processing(mock_transcript)
+# Should not check Hatchet status
+mock_hatchet.get_workflow_run_status.assert_not_called()
+assert isinstance(result, ValidationOk)
@pytest.mark.usefixtures("setup_database")
@pytest.mark.asyncio
async def test_hatchet_validation_skipped_when_disabled():
"""Test that Hatchet validation is skipped when HATCHET_ENABLED is False."""
from reflector.services.transcript_process import (
ValidationOk,
validate_transcript_for_processing,
)
mock_transcript = Transcript(
id="test-transcript-id",
name="Test",
status="uploaded",
source_kind="room",
workflow_run_id="some-workflow-123",
recording_id="test-recording-id",
)
with patch("reflector.services.transcript_process.settings") as mock_settings:
mock_settings.HATCHET_ENABLED = False # Hatchet disabled
with patch(
"reflector.services.transcript_process.task_is_scheduled_or_active"
@@ -224,8 +276,7 @@ async def test_hatchet_validation_skipped_when_no_workflow_id():
result = await validate_transcript_for_processing(mock_transcript)
-# Should not check Hatchet status
-mock_hatchet.get_workflow_run_status.assert_not_called()
+# Should not check Hatchet at all
assert isinstance(result, ValidationOk)

View File

@@ -189,17 +189,14 @@ async def test_ics_sync_service_sync_room_calendar():
assert events[0].ics_uid == "sync-event-1"
assert events[0].title == "Sync Test Meeting"
-# Second sync with same content (calendar unchanged, but sync always runs)
+# Second sync with same content (should be unchanged)
# Refresh room to get updated etag and force sync by setting old sync time
room = await rooms_controller.get_by_id(room.id)
await rooms_controller.update(
room, {"ics_last_sync": datetime.now(timezone.utc) - timedelta(minutes=10)}
)
result = await sync_service.sync_room_calendar(room)
-assert result["status"] == "success"
-assert result["events_created"] == 0
-assert result["events_updated"] == 0
-assert result["events_deleted"] == 0
+assert result["status"] == "unchanged"
# Third sync with updated event
event["summary"] = "Updated Meeting Title"
@@ -291,43 +288,3 @@ async def test_ics_sync_service_error_handling():
result = await sync_service.sync_room_calendar(room)
assert result["status"] == "error"
assert "Network error" in result["error"]
@pytest.mark.asyncio
async def test_event_data_changed_exhaustiveness():
"""Test that _event_data_changed compares all EventData fields (except ics_uid).
This test ensures programmers don't forget to update the comparison logic
when adding new fields to EventData/CalendarEvent.
"""
from reflector.services.ics_sync import EventData
sync_service = ICSSyncService()
from reflector.db.calendar_events import CalendarEvent
now = datetime.now(timezone.utc)
event_data: EventData = {
"ics_uid": "test-123",
"title": "Test",
"description": "Desc",
"location": "Loc",
"start_time": now,
"end_time": now + timedelta(hours=1),
"attendees": [],
"ics_raw_data": "raw",
}
existing = CalendarEvent(
room_id="room1",
**event_data,
)
# Will raise RuntimeError if fields are missing from comparison
result = sync_service._event_data_changed(existing, event_data)
assert result is False
modified_data = event_data.copy()
modified_data["title"] = "Changed Title"
result = sync_service._event_data_changed(existing, modified_data)
assert result is True

View File

@@ -162,24 +162,9 @@ async def test_dailyco_recording_uses_multitrack_pipeline(client):
from datetime import datetime, timezone
from reflector.db.recordings import Recording, recordings_controller
-from reflector.db.rooms import rooms_controller
from reflector.db.transcripts import transcripts_controller
-room = await rooms_controller.add(
-name="test-room",
-user_id="test-user",
-zulip_auto_post=False,
-zulip_stream="",
-zulip_topic="",
-is_locked=False,
-room_mode="normal",
-recording_type="cloud",
-recording_trigger="automatic-2nd-participant",
-is_shared=False,
-)
-# Force Celery backend for test
-await rooms_controller.update(room, {"use_celery": True})
+# Create transcript with Daily.co multitrack recording
transcript = await transcripts_controller.add(
"", "",
source_kind="room", source_kind="room",
@@ -187,7 +172,6 @@ async def test_dailyco_recording_uses_multitrack_pipeline(client):
target_language="en",
user_id="test-user",
share_mode="public",
-room_id=room.id,
)
track_keys = [