Add Modal backend for audio mixdown

This commit is contained in:
Igor Loskutov
2026-01-21 17:06:17 -05:00
parent c8743fdf1c
commit e1b790c5a8
5 changed files with 589 additions and 26 deletions

View File

@@ -131,6 +131,15 @@ if [ -z "$DIARIZER_URL" ]; then
fi fi
echo " -> $DIARIZER_URL" echo " -> $DIARIZER_URL"
echo ""
echo "Deploying mixdown (CPU audio processing)..."
MIXDOWN_URL=$(modal deploy reflector_mixdown.py 2>&1 | grep -o 'https://[^ ]*web.modal.run' | head -1)
if [ -z "$MIXDOWN_URL" ]; then
echo "Error: Failed to deploy mixdown. Check Modal dashboard for details."
exit 1
fi
echo " -> $MIXDOWN_URL"
# --- Output Configuration --- # --- Output Configuration ---
echo "" echo ""
echo "==========================================" echo "=========================================="
@@ -147,4 +156,8 @@ echo ""
echo "DIARIZATION_BACKEND=modal" echo "DIARIZATION_BACKEND=modal"
echo "DIARIZATION_URL=$DIARIZER_URL" echo "DIARIZATION_URL=$DIARIZER_URL"
echo "DIARIZATION_MODAL_API_KEY=$API_KEY" echo "DIARIZATION_MODAL_API_KEY=$API_KEY"
echo ""
echo "MIXDOWN_BACKEND=modal"
echo "MIXDOWN_URL=$MIXDOWN_URL"
echo "MIXDOWN_MODAL_API_KEY=$API_KEY"
echo "# --- End Modal Configuration ---" echo "# --- End Modal Configuration ---"

View File

@@ -0,0 +1,379 @@
"""
Reflector Modal backend - audio mixdown
=======================================
CPU-intensive audio mixdown service for combining multiple audio tracks.
Uses PyAV filter graph (amix) for high-quality audio mixing.
"""
import os
import tempfile
import time
from fractions import Fraction
import modal
# Upper bound, in seconds, for one invocation; passed to @app.function as `timeout`.
MIXDOWN_TIMEOUT = 900 # 15 minutes
# Idle period, in seconds; passed to @app.function as `scaledown_window`.
SCALEDOWN_WINDOW = 60 # 1 minute idle before shutdown
# Modal application object that the web endpoint below is registered on.
app = modal.App("reflector-mixdown")
# CPU-based image (no GPU needed for audio processing)
# NOTE(review): web() also imports pydantic, which is not pinned here —
# presumably it is installed as a FastAPI dependency; confirm.
image = (
modal.Image.debian_slim(python_version="3.12")
.apt_install("ffmpeg") # Required by PyAV
.pip_install(
"av==13.1.0", # PyAV for audio processing
"requests==2.32.3", # HTTP for presigned URL downloads/uploads
"fastapi==0.115.12", # API framework
)
)
@app.function(
    cpu=4.0,  # 4 CPU cores for audio processing
    timeout=MIXDOWN_TIMEOUT,
    scaledown_window=SCALEDOWN_WINDOW,
    secrets=[modal.Secret.from_name("reflector-gpu")],
    image=image,
)
@modal.concurrent(max_inputs=10)
@modal.asgi_app()
def web():
    """Set up the FastAPI application that Modal serves as an ASGI app.

    Third-party imports are done inside the function body rather than at
    module level. Logging is configured first, then the API key is read from
    the environment so a missing secret aborts start-up immediately.
    """
    import logging
    import secrets
    import shutil

    import av
    import requests
    from av.audio.resampler import AudioResampler
    from fastapi import Depends, FastAPI, HTTPException, status
    from fastapi.security import OAuth2PasswordBearer
    from pydantic import BaseModel

    # Logging: timestamped INFO-level records.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(message)s",
        level=logging.INFO,
    )
    logger = logging.getLogger(__name__)

    # Refuse to start without an API key: every request is checked against it.
    API_KEY = os.environ.get("REFLECTOR_GPU_APIKEY")
    if not API_KEY:
        raise RuntimeError("REFLECTOR_GPU_APIKEY not configured in Modal secrets")

    # NOTE: this local `app` (FastAPI) shadows the module-level Modal app
    # inside this function only.
    app = FastAPI()
    oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
    """FastAPI dependency that rejects requests without the correct API key.

    Args:
        apikey: Bearer token extracted from the request by ``oauth2_scheme``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when the
            token does not match the configured key.
    """
    # Constant-time comparison prevents timing attacks. Compare as bytes:
    # compare_digest() raises TypeError for str arguments containing
    # non-ASCII characters, which would turn a malformed token into a 500
    # instead of a clean 401.
    if secrets.compare_digest(apikey.encode("utf-8"), API_KEY.encode("utf-8")):
        return
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid API key",
        headers={"WWW-Authenticate": "Bearer"},
    )
class MixdownRequest(BaseModel):
    """Request body accepted by the mixdown endpoint."""

    # URLs the service downloads the input tracks from (HTTP GET).
    track_urls: list[str]
    # URL the mixed MP3 is uploaded to (HTTP PUT).
    output_url: str
    # Sample rate in Hz used for the filter graph and the encoded output.
    target_sample_rate: int = 48000
    # Fallback total duration (seconds) used for progress reporting when the
    # input containers do not report one.
    expected_duration_sec: float | None = None
class MixdownResponse(BaseModel):
    """Response body returned by the mixdown endpoint."""

    # Duration of the encoded output, in milliseconds.
    duration_ms: float
    # Number of tracks reported as mixed.
    tracks_mixed: int
    # Whether the result was uploaded to the output URL.
    audio_uploaded: bool
def download_track(url: str, temp_dir: str, index: int) -> str:
    """Download one track from a presigned URL into ``temp_dir``.

    The body is streamed to disk in chunks so a large track is never held
    fully in memory.

    Args:
        url: Presigned GET URL of the track.
        temp_dir: Existing directory the file is written into.
        index: Zero-based position of the track; used in the file name and
            in error details.

    Returns:
        Path of the downloaded file (``track_<index>.webm`` inside ``temp_dir``).

    Raises:
        HTTPException: 404 if the track does not exist, 403 if the presigned
            URL was rejected.
        requests.HTTPError: For any other non-success status.
    """
    logger.info(f"Downloading track {index + 1}")
    # Use the response as a context manager: with stream=True the connection
    # stays checked out until the body is consumed or the response is closed,
    # so an early raise would otherwise leak it.
    with requests.get(url, stream=True, timeout=300) as response:
        if response.status_code == 404:
            raise HTTPException(status_code=404, detail=f"Track {index} not found")
        if response.status_code == 403:
            raise HTTPException(
                status_code=403, detail=f"Track {index} presigned URL expired"
            )
        response.raise_for_status()
        temp_path = os.path.join(temp_dir, f"track_{index}.webm")
        total_bytes = 0
        with open(temp_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=65536):
                if chunk:
                    f.write(chunk)
                    total_bytes += len(chunk)
    logger.info(f"Track {index + 1} downloaded: {total_bytes} bytes")
    return temp_path
def mixdown_tracks_modal(
    track_paths: list[str],
    output_path: str,
    target_sample_rate: int,
    expected_duration_sec: float | None,
    logger,
) -> float:
    """Mix multiple audio tracks into one MP3 using a PyAV filter graph.

    Graph layout: N ``abuffer`` inputs -> ``amix`` -> ``aformat`` -> ``abuffersink``.
    Tracks that cannot be opened are skipped with a warning; the graph is
    sized to the tracks that actually opened so every ``amix`` input is fed
    and eventually receives end-of-stream.

    Args:
        track_paths: List of local file paths to audio tracks
        output_path: Local path for output MP3 file
        target_sample_rate: Sample rate for output (Hz)
        expected_duration_sec: Optional fallback duration if container metadata unavailable
        logger: Logger instance for progress tracking

    Returns:
        Duration in milliseconds, derived from the summed durations of the
        muxed packets and the output stream time base.

    Raises:
        ValueError: If none of the tracks could be opened.
    """
    logger.info(f"Starting mixdown of {len(track_paths)} tracks")

    containers = []
    out_container = None
    try:
        # Open the inputs first: the number of graph inputs must match the
        # number of tracks that will really be decoded, otherwise amix waits
        # forever on an input that never gets data or EOF.
        for i, path in enumerate(track_paths):
            try:
                containers.append(av.open(path))
            except Exception as e:
                logger.warning(
                    f"Failed to open container {i}: {e}",
                )
        if not containers:
            raise ValueError("Could not open any track containers")

        # Build PyAV filter graph: N abuffer -> amix -> aformat -> sink
        graph = av.filter.Graph()
        abuffer_args = (
            f"time_base=1/{target_sample_rate}:"
            f"sample_rate={target_sample_rate}:"
            f"sample_fmt=s32:"
            f"channel_layout=stereo"
        )
        inputs = [
            graph.add("abuffer", args=abuffer_args, name=f"in{idx}")
            for idx in range(len(containers))
        ]
        mixer = graph.add(
            "amix", args=f"inputs={len(inputs)}:normalize=0", name="mix"
        )
        fmt = graph.add(
            "aformat",
            args=f"sample_fmts=s32:channel_layouts=stereo:sample_rates={target_sample_rate}",
            name="fmt",
        )
        sink = graph.add("abuffersink", name="out")
        # Connect inputs to mixer (no delays for Modal implementation)
        for idx, in_ctx in enumerate(inputs):
            in_ctx.link_to(mixer, 0, idx)
        mixer.link_to(fmt)
        fmt.link_to(sink)
        graph.configure()

        # Calculate total duration for progress reporting
        max_duration_sec = 0.0
        for c in containers:
            if c.duration is not None:
                max_duration_sec = max(max_duration_sec, c.duration / av.time_base)
        if max_duration_sec == 0.0 and expected_duration_sec:
            max_duration_sec = expected_duration_sec

        # Setup output container
        out_container = av.open(output_path, "w", format="mp3")
        out_stream = out_container.add_stream("libmp3lame", rate=target_sample_rate)

        def encode_and_mux(frame=None) -> int:
            """Encode ``frame`` (None flushes the encoder), mux, return ticks written."""
            ticks = 0
            for packet in out_stream.encode(frame):
                out_container.mux(packet)
                ticks += packet.duration or 0
            return ticks

        def drain_sink() -> int:
            """Encode every frame currently available from the sink; return ticks written."""
            ticks = 0
            while True:
                try:
                    mixed = sink.pull()
                except (BlockingIOError, EOFError):
                    # PyAV maps FFmpeg EAGAIN / EOF onto these builtins:
                    # the graph needs more input, or is fully drained.
                    # Anything else is a real failure and must propagate.
                    break
                mixed.sample_rate = target_sample_rate
                mixed.time_base = Fraction(1, target_sample_rate)
                ticks += encode_and_mux(mixed)
            return ticks

        decoders = [c.decode(audio=0) for c in containers]
        active = [True] * len(decoders)
        rate_mismatch_warned = [False] * len(decoders)
        resamplers = [
            AudioResampler(format="s32", layout="stereo", rate=target_sample_rate)
            for _ in decoders
        ]
        current_max_time = 0.0
        start_time = time.monotonic()
        last_log_time = start_time
        total_duration = 0

        while any(active):
            for i, dec in enumerate(decoders):
                if not active[i]:
                    continue
                try:
                    frame = next(dec)
                except StopIteration:
                    active[i] = False
                    inputs[i].push(None)  # Signal end of stream
                    continue
                if frame.sample_rate != target_sample_rate:
                    # Frames at a different native rate are dropped; say so
                    # once per track instead of losing audio silently.
                    if not rate_mismatch_warned[i]:
                        rate_mismatch_warned[i] = True
                        logger.warning(
                            f"Track {i}: skipping frames with sample rate "
                            f"{frame.sample_rate} (target {target_sample_rate})"
                        )
                    continue
                # Progress logging (every 5 seconds)
                if frame.time is not None:
                    current_max_time = max(current_max_time, frame.time)
                    now = time.monotonic()
                    if now - last_log_time >= 5.0:
                        elapsed = now - start_time
                        if max_duration_sec > 0:
                            progress_pct = min(
                                100.0, (current_max_time / max_duration_sec) * 100
                            )
                            logger.info(
                                f"Mixdown progress: {progress_pct:.1f}% @ {current_max_time:.1f}s (elapsed: {elapsed:.1f}s)"
                            )
                        else:
                            logger.info(
                                f"Mixdown progress: @ {current_max_time:.1f}s (elapsed: {elapsed:.1f}s)"
                            )
                        last_log_time = now
                for rf in resamplers[i].resample(frame) or []:
                    rf.sample_rate = target_sample_rate
                    rf.time_base = Fraction(1, target_sample_rate)
                    inputs[i].push(rf)
                # Pull mixed frames from sink and encode
                total_duration += drain_sink()

        # Flush remaining frames from filter graph, then the encoder
        total_duration += drain_sink()
        total_duration += encode_and_mux(None)

        # Calculate duration in milliseconds
        if total_duration > 0:
            # Use the same calculation as AudioFileWriterProcessor
            duration_ms = round(
                float(total_duration * out_stream.time_base * 1000), 2
            )
        else:
            duration_ms = 0.0

        # Close explicitly on the success path so finalisation errors surface.
        out_container.close()
        out_container = None
        logger.info(f"Mixdown complete: duration={duration_ms}ms")
    finally:
        # Best-effort cleanup; never mask the original exception.
        if out_container is not None:
            try:
                out_container.close()
            except Exception:
                pass
        for c in containers:
            try:
                c.close()
            except Exception:
                pass
    return duration_ms
@app.post("/v1/audio/mixdown", dependencies=[Depends(apikey_auth)])
def mixdown(request: MixdownRequest) -> MixdownResponse:
    """Mix multiple audio tracks into a single MP3 file.

    Tracks are downloaded from presigned S3 URLs, mixed using PyAV,
    and uploaded to a presigned S3 PUT URL.

    Raises:
        HTTPException: 400 for an empty track list or a non-positive sample
            rate; 403/404 propagated from download or upload; 500 for any
            other failure (the original exception is chained as the cause).
    """
    if not request.track_urls:
        raise HTTPException(status_code=400, detail="No track URLs provided")
    if request.target_sample_rate <= 0:
        # Reject up front: this value is interpolated into the filter graph
        # arguments and would otherwise fail later as an opaque 500.
        raise HTTPException(
            status_code=400, detail="target_sample_rate must be positive"
        )
    logger.info(f"Mixdown request: {len(request.track_urls)} tracks")

    # Every intermediate file (inputs and output) lives in this directory,
    # so removing it in `finally` is the only cleanup needed.
    temp_dir = tempfile.mkdtemp()
    try:
        # Download all tracks
        temp_files = [
            download_track(url, temp_dir, i)
            for i, url in enumerate(request.track_urls)
        ]

        # Mix tracks
        output_mp3_path = os.path.join(temp_dir, "mixed.mp3")
        duration_ms = mixdown_tracks_modal(
            temp_files,
            output_mp3_path,
            request.target_sample_rate,
            request.expected_duration_sec,
            logger,
        )

        # Upload result to S3
        logger.info("Uploading result to S3")
        file_size = os.path.getsize(output_mp3_path)
        with open(output_mp3_path, "rb") as f:
            upload_response = requests.put(
                request.output_url, data=f, timeout=300
            )
        if upload_response.status_code == 403:
            raise HTTPException(
                status_code=403, detail="Output presigned URL expired"
            )
        upload_response.raise_for_status()
        logger.info(f"Upload complete: {file_size} bytes")

        # NOTE(review): tracks_mixed counts the requested tracks, not the
        # tracks the mixer managed to open — confirm this is intended.
        return MixdownResponse(
            duration_ms=duration_ms,
            tracks_mixed=len(request.track_urls),
            audio_uploaded=True,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Mixdown failed: {e}", exc_info=True)
        raise HTTPException(
            status_code=500, detail=f"Mixdown failed: {str(e)}"
        ) from e
    finally:
        try:
            shutil.rmtree(temp_dir)
        except Exception as e:
            logger.warning(f"Failed to cleanup temp directory {temp_dir}: {e}")
return app

View File

@@ -489,7 +489,7 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
) )
@with_error_handling(TaskName.MIXDOWN_TRACKS) @with_error_handling(TaskName.MIXDOWN_TRACKS)
async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult: async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
"""Mix all padded tracks into single audio file using PyAV (same as Celery).""" """Mix all padded tracks into single audio file using PyAV or Modal backend."""
ctx.log("mixdown_tracks: mixing padded tracks into single audio file") ctx.log("mixdown_tracks: mixing padded tracks into single audio file")
track_result = ctx.task_output(process_tracks) track_result = ctx.task_output(process_tracks)
@@ -513,7 +513,7 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
storage = _spawn_storage() storage = _spawn_storage()
# Presign URLs on demand (avoids stale URLs on workflow replay) # Presign URLs for padded tracks (same expiration for both backends)
padded_urls = [] padded_urls = []
for track_info in padded_tracks: for track_info in padded_tracks:
if track_info.key: if track_info.key:
@@ -534,13 +534,79 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
logger.error("Mixdown failed - no decodable audio frames found") logger.error("Mixdown failed - no decodable audio frames found")
raise ValueError("No decodable audio frames in any track") raise ValueError("No decodable audio frames in any track")
output_key = f"{input.transcript_id}/audio.mp3"
# Conditional: Modal or local backend
if settings.MIXDOWN_BACKEND == "modal":
ctx.log("mixdown_tracks: using Modal backend")
# Presign PUT URL for output (Modal will upload directly)
output_url = await storage.get_file_url(
output_key,
operation="put_object",
expires_in=PRESIGNED_URL_EXPIRATION_SECONDS,
)
from reflector.processors.audio_mixdown_modal import ( # noqa: PLC0415
AudioMixdownModalProcessor,
)
try:
processor = AudioMixdownModalProcessor()
result = await processor.mixdown(
track_urls=valid_urls,
output_url=output_url,
target_sample_rate=target_sample_rate,
expected_duration_sec=recording_duration
if recording_duration > 0
else None,
)
duration_ms = result.duration_ms
tracks_mixed = result.tracks_mixed
ctx.log(
f"mixdown_tracks: Modal returned duration={duration_ms}ms, tracks={tracks_mixed}"
)
except httpx.HTTPStatusError as e:
error_detail = e.response.text if hasattr(e.response, "text") else str(e)
logger.error(
"[Hatchet] Modal mixdown HTTP error",
transcript_id=input.transcript_id,
status_code=e.response.status_code if hasattr(e, "response") else None,
error=error_detail,
)
raise RuntimeError(
f"Modal mixdown failed with HTTP {e.response.status_code}: {error_detail}"
)
except httpx.TimeoutException:
logger.error(
"[Hatchet] Modal mixdown timeout",
transcript_id=input.transcript_id,
timeout=settings.MIXDOWN_TIMEOUT,
)
raise RuntimeError(
f"Modal mixdown timeout after {settings.MIXDOWN_TIMEOUT}s"
)
except ValueError as e:
logger.error(
"[Hatchet] Modal mixdown validation error",
transcript_id=input.transcript_id,
error=str(e),
)
raise
else:
ctx.log("mixdown_tracks: using local backend")
# Existing local implementation
output_path = tempfile.mktemp(suffix=".mp3") output_path = tempfile.mktemp(suffix=".mp3")
duration_ms_callback_capture_container = [0.0] duration_ms_callback_capture_container = [0.0]
async def capture_duration(d): async def capture_duration(d):
duration_ms_callback_capture_container[0] = d duration_ms_callback_capture_container[0] = d
writer = AudioFileWriterProcessor(path=output_path, on_duration=capture_duration) writer = AudioFileWriterProcessor(
path=output_path, on_duration=capture_duration
)
await mixdown_tracks_pyav( await mixdown_tracks_pyav(
valid_urls, valid_urls,
@@ -549,18 +615,23 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
offsets_seconds=None, offsets_seconds=None,
logger=logger, logger=logger,
progress_callback=make_audio_progress_logger(ctx, TaskName.MIXDOWN_TRACKS), progress_callback=make_audio_progress_logger(ctx, TaskName.MIXDOWN_TRACKS),
expected_duration_sec=recording_duration if recording_duration > 0 else None, expected_duration_sec=recording_duration
if recording_duration > 0
else None,
) )
await writer.flush() await writer.flush()
file_size = Path(output_path).stat().st_size file_size = Path(output_path).stat().st_size
storage_path = f"{input.transcript_id}/audio.mp3"
with open(output_path, "rb") as mixed_file: with open(output_path, "rb") as mixed_file:
await storage.put_file(storage_path, mixed_file) await storage.put_file(output_key, mixed_file)
Path(output_path).unlink(missing_ok=True) Path(output_path).unlink(missing_ok=True)
duration_ms = duration_ms_callback_capture_container[0]
tracks_mixed = len(valid_urls)
ctx.log(f"mixdown_tracks: local mixdown uploaded {file_size} bytes")
# Update DB (same for both backends)
async with fresh_db_connection(): async with fresh_db_connection():
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415 from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
@@ -570,12 +641,12 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
transcript, {"audio_location": "storage"} transcript, {"audio_location": "storage"}
) )
ctx.log(f"mixdown_tracks complete: uploaded {file_size} bytes to {storage_path}") ctx.log(f"mixdown_tracks complete: uploaded to {output_key}")
return MixdownResult( return MixdownResult(
audio_key=storage_path, audio_key=output_key,
duration=duration_ms_callback_capture_container[0], duration=duration_ms,
tracks_mixed=len(valid_urls), tracks_mixed=tracks_mixed,
) )

View File

@@ -0,0 +1,89 @@
"""
Modal.com backend for audio mixdown.
Uses Modal's CPU containers to offload audio mixing from Hatchet workers.
Communicates via presigned S3 URLs for both input and output.
"""
import httpx
from pydantic import BaseModel
from reflector.settings import settings
class MixdownResponse(BaseModel):
    """Parsed JSON body returned by the Modal mixdown endpoint."""

    # Duration of the mixed audio, in milliseconds.
    duration_ms: float
    # Number of tracks the backend reports as mixed.
    tracks_mixed: int
    # Whether the backend reports the result as uploaded.
    audio_uploaded: bool
class AudioMixdownModalProcessor:
    """Audio mixdown processor using Modal.com CPU backend.

    Sends track URLs (presigned GET) and output URL (presigned PUT) to Modal.
    Modal handles download, mixdown via PyAV, and upload.
    """

    def __init__(self, modal_api_key: str | None = None):
        """Read endpoint, timeout and API key from settings.

        Args:
            modal_api_key: Overrides ``settings.MIXDOWN_MODAL_API_KEY`` when given.

        Raises:
            ValueError: If ``MIXDOWN_URL`` or the API key is not configured.
        """
        if not settings.MIXDOWN_URL:
            raise ValueError("MIXDOWN_URL required to use AudioMixdownModalProcessor")
        # Strip trailing slashes so a configured base URL with or without one
        # yields the same endpoint instead of ".../​/v1".
        self.mixdown_url = settings.MIXDOWN_URL.rstrip("/") + "/v1"
        self.timeout = settings.MIXDOWN_TIMEOUT
        self.modal_api_key = modal_api_key or settings.MIXDOWN_MODAL_API_KEY
        if not self.modal_api_key:
            raise ValueError(
                "MIXDOWN_MODAL_API_KEY required to use AudioMixdownModalProcessor"
            )

    async def mixdown(
        self,
        track_urls: list[str],
        output_url: str,
        target_sample_rate: int,
        expected_duration_sec: float | None = None,
    ) -> MixdownResponse:
        """Mix multiple audio tracks via Modal backend.

        Args:
            track_urls: List of presigned GET URLs for audio tracks (non-empty)
            output_url: Presigned PUT URL for output MP3 (non-empty)
            target_sample_rate: Sample rate for output (Hz, must be positive)
            expected_duration_sec: Optional fallback duration if container metadata unavailable

        Returns:
            MixdownResponse with duration_ms, tracks_mixed, audio_uploaded

        Raises:
            ValueError: If track_urls or output_url is empty, target_sample_rate
                is not positive, or expected_duration_sec is negative
            httpx.HTTPStatusError: On HTTP errors (404, 403, 500, etc.)
            httpx.TimeoutException: On timeout
        """
        # Validate inputs locally so obviously bad requests never hit the network.
        if not track_urls:
            raise ValueError("track_urls cannot be empty")
        if not output_url:
            raise ValueError("output_url cannot be empty")
        if target_sample_rate <= 0:
            raise ValueError(
                f"target_sample_rate must be positive, got {target_sample_rate}"
            )
        if expected_duration_sec is not None and expected_duration_sec < 0:
            raise ValueError(
                f"expected_duration_sec cannot be negative, got {expected_duration_sec}"
            )

        payload = {
            "track_urls": track_urls,
            "output_url": output_url,
            "target_sample_rate": target_sample_rate,
            "expected_duration_sec": expected_duration_sec,
        }
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            response = await client.post(
                f"{self.mixdown_url}/audio/mixdown",
                headers={"Authorization": f"Bearer {self.modal_api_key}"},
                json=payload,
            )
            response.raise_for_status()
            return MixdownResponse(**response.json())

View File

@@ -98,6 +98,17 @@ class Settings(BaseSettings):
# Diarization: local pyannote.audio # Diarization: local pyannote.audio
DIARIZATION_PYANNOTE_AUTH_TOKEN: str | None = None DIARIZATION_PYANNOTE_AUTH_TOKEN: str | None = None
# Audio Mixdown
# backends:
# - local: in-process PyAV mixdown (runs in same process as Hatchet worker)
# - modal: HTTP API client to Modal.com CPU container
MIXDOWN_BACKEND: str = "local"
MIXDOWN_URL: str | None = None
MIXDOWN_TIMEOUT: int = 900 # 15 minutes
# Mixdown: modal backend
MIXDOWN_MODAL_API_KEY: str | None = None
# Sentry # Sentry
SENTRY_DSN: str | None = None SENTRY_DSN: str | None = None