feat: mixdown modal services + processor pattern (#936)

* allow memory flags and per service config

* feat: mixdown modal services + processor pattern
This commit is contained in:
Juan Diego García
2026-03-30 17:38:23 -05:00
committed by GitHub
parent 12bf0c2d77
commit d164e486cc
15 changed files with 1353 additions and 104 deletions

View File

@@ -132,13 +132,22 @@ fi
echo " -> $DIARIZER_URL"
echo ""
echo "Deploying padding (CPU audio processing via Modal SDK)..."
modal deploy reflector_padding.py
if [ $? -ne 0 ]; then
echo "Deploying padding (CPU audio processing)..."
PADDING_URL=$(modal deploy reflector_padding.py 2>&1 | grep -o 'https://[^ ]*web.modal.run' | head -1)
if [ -z "$PADDING_URL" ]; then
echo "Error: Failed to deploy padding. Check Modal dashboard for details."
exit 1
fi
echo " -> reflector-padding.pad_track (Modal SDK function)"
echo " -> $PADDING_URL"
echo ""
echo "Deploying mixdown (CPU multi-track audio mixing)..."
MIXDOWN_URL=$(modal deploy reflector_mixdown.py 2>&1 | grep -o 'https://[^ ]*web.modal.run' | head -1)
if [ -z "$MIXDOWN_URL" ]; then
echo "Error: Failed to deploy mixdown. Check Modal dashboard for details."
exit 1
fi
echo " -> $MIXDOWN_URL"
# --- Output Configuration ---
echo ""
@@ -157,5 +166,11 @@ echo "DIARIZATION_BACKEND=modal"
echo "DIARIZATION_URL=$DIARIZER_URL"
echo "DIARIZATION_MODAL_API_KEY=$API_KEY"
echo ""
echo "# Padding uses Modal SDK (requires MODAL_TOKEN_ID/SECRET in worker containers)"
echo "PADDING_BACKEND=modal"
echo "PADDING_URL=$PADDING_URL"
echo "PADDING_MODAL_API_KEY=$API_KEY"
echo ""
echo "MIXDOWN_BACKEND=modal"
echo "MIXDOWN_URL=$MIXDOWN_URL"
echo "MIXDOWN_MODAL_API_KEY=$API_KEY"
echo "# --- End Modal Configuration ---"

View File

@@ -0,0 +1,385 @@
"""
Reflector GPU backend - audio mixdown
=====================================
CPU-intensive multi-track audio mixdown service.
Mixes N audio tracks into a single MP3 using PyAV amix filter graph.
IMPORTANT: This mixdown logic is duplicated from server/reflector/utils/audio_mixdown.py
for Modal deployment isolation (Modal can't import from server/reflector/). If you modify
the PyAV filter graph or mixdown algorithm, you MUST update both:
- gpu/modal_deployments/reflector_mixdown.py (this file)
- server/reflector/utils/audio_mixdown.py
Constants duplicated from server/reflector/utils/audio_constants.py for same reason.
"""
import os
import tempfile
from fractions import Fraction
import asyncio
import modal
S3_TIMEOUT = 120 # Higher than padding (60s) — multiple track downloads
MIXDOWN_TIMEOUT = 1200 + (S3_TIMEOUT * 2) # 1440s total
SCALEDOWN_WINDOW = 60
DISCONNECT_CHECK_INTERVAL = 2
app = modal.App("reflector-mixdown")
# CPU-based image (mixdown is CPU-bound, no GPU needed)
image = (
modal.Image.debian_slim(python_version="3.12")
.apt_install("ffmpeg") # Required by PyAV
.pip_install(
"av==13.1.0", # PyAV for audio processing
"requests==2.32.3", # HTTP for presigned URL downloads/uploads
"fastapi==0.115.12", # API framework
)
)
@app.function(
cpu=4.0, # Higher than padding (2.0) for multi-track mixing
timeout=MIXDOWN_TIMEOUT,
scaledown_window=SCALEDOWN_WINDOW,
image=image,
secrets=[modal.Secret.from_name("reflector-gpu")],
)
@modal.asgi_app()
def web():
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel
class MixdownRequest(BaseModel):
track_urls: list[str]
output_url: str
target_sample_rate: int | None = None
offsets_seconds: list[float] | None = None
class MixdownResponse(BaseModel):
size: int
duration_ms: float = 0.0
cancelled: bool = False
web_app = FastAPI()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
if apikey == os.environ["REFLECTOR_GPU_APIKEY"]:
return
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
headers={"WWW-Authenticate": "Bearer"},
)
@web_app.post("/mixdown", dependencies=[Depends(apikey_auth)])
async def mixdown_endpoint(request: Request, req: MixdownRequest) -> MixdownResponse:
"""Modal web endpoint for mixing audio tracks with disconnect detection."""
import logging
import threading
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
valid_urls = [u for u in req.track_urls if u]
if not valid_urls:
raise HTTPException(status_code=400, detail="No valid track URLs provided")
if req.offsets_seconds is not None:
if len(req.offsets_seconds) != len(req.track_urls):
raise HTTPException(
status_code=400,
detail=f"offsets_seconds length ({len(req.offsets_seconds)}) "
f"must match track_urls ({len(req.track_urls)})",
)
if any(o > 18000 for o in req.offsets_seconds):
raise HTTPException(status_code=400, detail="offsets_seconds exceeds maximum 18000s (5 hours)")
if not req.output_url:
raise HTTPException(status_code=400, detail="output_url cannot be empty")
logger.info(f"Mixdown request: {len(valid_urls)} tracks")
# Thread-safe cancellation flag
cancelled = threading.Event()
async def check_disconnect():
"""Background task to check for client disconnect."""
while not cancelled.is_set():
await asyncio.sleep(DISCONNECT_CHECK_INTERVAL)
if await request.is_disconnected():
logger.warning("Client disconnected, setting cancellation flag")
cancelled.set()
break
disconnect_task = asyncio.create_task(check_disconnect())
try:
result = await asyncio.get_event_loop().run_in_executor(
None, _mixdown_tracks_blocking, req, cancelled, logger
)
return MixdownResponse(**result)
finally:
cancelled.set()
disconnect_task.cancel()
try:
await disconnect_task
except asyncio.CancelledError:
pass
def _mixdown_tracks_blocking(req, cancelled, logger) -> dict:
"""Blocking CPU-bound mixdown work with periodic cancellation checks.
Downloads all tracks, builds PyAV amix filter graph, encodes to MP3,
and uploads the result to the presigned output URL.
"""
import av
import requests
from av.audio.resampler import AudioResampler
import time
temp_dir = tempfile.mkdtemp()
track_paths = []
output_path = None
last_check = time.time()
try:
# --- Download all tracks ---
valid_urls = [u for u in req.track_urls if u]
for i, url in enumerate(valid_urls):
if cancelled.is_set():
logger.info("Cancelled during download phase")
return {"size": 0, "duration_ms": 0.0, "cancelled": True}
logger.info(f"Downloading track {i}")
response = requests.get(url, stream=True, timeout=S3_TIMEOUT)
response.raise_for_status()
track_path = os.path.join(temp_dir, f"track_{i}.webm")
total_bytes = 0
chunk_count = 0
with open(track_path, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
total_bytes += len(chunk)
chunk_count += 1
if chunk_count % 12 == 0:
now = time.time()
if now - last_check >= DISCONNECT_CHECK_INTERVAL:
if cancelled.is_set():
logger.info(f"Cancelled during track {i} download")
return {"size": 0, "duration_ms": 0.0, "cancelled": True}
last_check = now
track_paths.append(track_path)
logger.info(f"Track {i} downloaded: {total_bytes} bytes")
if not track_paths:
raise ValueError("No tracks downloaded")
# --- Detect sample rate ---
target_sample_rate = req.target_sample_rate
if target_sample_rate is None:
for path in track_paths:
try:
container = av.open(path)
for frame in container.decode(audio=0):
target_sample_rate = frame.sample_rate
container.close()
break
else:
container.close()
continue
break
except Exception:
continue
if target_sample_rate is None:
raise ValueError("Could not detect sample rate from any track")
logger.info(f"Target sample rate: {target_sample_rate}")
# --- Calculate per-input delays ---
input_offsets_seconds = None
if req.offsets_seconds is not None:
input_offsets_seconds = [
req.offsets_seconds[i] for i, url in enumerate(req.track_urls) if url
]
delays_ms = []
if input_offsets_seconds is not None:
base = min(input_offsets_seconds) if input_offsets_seconds else 0.0
delays_ms = [max(0, int(round((o - base) * 1000))) for o in input_offsets_seconds]
else:
delays_ms = [0 for _ in track_paths]
# --- Build filter graph ---
# N abuffer -> optional adelay -> amix -> aformat -> abuffersink
graph = av.filter.Graph()
inputs = []
for idx in range(len(track_paths)):
args = (
f"time_base=1/{target_sample_rate}:"
f"sample_rate={target_sample_rate}:"
f"sample_fmt=s32:"
f"channel_layout=stereo"
)
in_ctx = graph.add("abuffer", args=args, name=f"in{idx}")
inputs.append(in_ctx)
mixer = graph.add("amix", args=f"inputs={len(inputs)}:normalize=0", name="mix")
fmt = graph.add(
"aformat",
args=f"sample_fmts=s32:channel_layouts=stereo:sample_rates={target_sample_rate}",
name="fmt",
)
sink = graph.add("abuffersink", name="out")
for idx, in_ctx in enumerate(inputs):
delay_ms = delays_ms[idx] if idx < len(delays_ms) else 0
if delay_ms > 0:
adelay = graph.add(
"adelay",
args=f"delays={delay_ms}|{delay_ms}:all=1",
name=f"delay{idx}",
)
in_ctx.link_to(adelay)
adelay.link_to(mixer, 0, idx)
else:
in_ctx.link_to(mixer, 0, idx)
mixer.link_to(fmt)
fmt.link_to(sink)
graph.configure()
# --- Open all containers and decode ---
containers = []
output_path = os.path.join(temp_dir, "mixed.mp3")
try:
for path in track_paths:
containers.append(av.open(path))
decoders = [c.decode(audio=0) for c in containers]
active = [True] * len(decoders)
resamplers = [
AudioResampler(format="s32", layout="stereo", rate=target_sample_rate)
for _ in decoders
]
# Open output MP3
out_container = av.open(output_path, "w", format="mp3")
out_stream = out_container.add_stream("libmp3lame", rate=target_sample_rate)
total_duration = 0
while any(active):
# Check cancellation periodically
now = time.time()
if now - last_check >= DISCONNECT_CHECK_INTERVAL:
if cancelled.is_set():
logger.info("Cancelled during mixing")
out_container.close()
return {"size": 0, "duration_ms": 0.0, "cancelled": True}
last_check = now
for i, (dec, is_active) in enumerate(zip(decoders, active)):
if not is_active:
continue
try:
frame = next(dec)
except StopIteration:
active[i] = False
inputs[i].push(None)
continue
if frame.sample_rate != target_sample_rate:
continue
out_frames = resamplers[i].resample(frame) or []
for rf in out_frames:
rf.sample_rate = target_sample_rate
rf.time_base = Fraction(1, target_sample_rate)
inputs[i].push(rf)
while True:
try:
mixed = sink.pull()
except Exception:
break
mixed.sample_rate = target_sample_rate
mixed.time_base = Fraction(1, target_sample_rate)
for packet in out_stream.encode(mixed):
out_container.mux(packet)
total_duration += packet.duration
# Flush filter graph
while True:
try:
mixed = sink.pull()
except Exception:
break
mixed.sample_rate = target_sample_rate
mixed.time_base = Fraction(1, target_sample_rate)
for packet in out_stream.encode(mixed):
out_container.mux(packet)
total_duration += packet.duration
# Flush encoder
for packet in out_stream.encode(None):
out_container.mux(packet)
total_duration += packet.duration
# Calculate duration in ms
last_tb = out_stream.time_base
duration_ms = 0.0
if last_tb and total_duration > 0:
duration_ms = round(float(total_duration * last_tb * 1000), 2)
out_container.close()
finally:
for c in containers:
try:
c.close()
except Exception:
pass
file_size = os.path.getsize(output_path)
logger.info(f"Mixdown complete: {file_size} bytes, {duration_ms}ms")
if cancelled.is_set():
logger.info("Cancelled after mixing, before upload")
return {"size": 0, "duration_ms": 0.0, "cancelled": True}
# --- Upload result ---
logger.info("Uploading mixed audio to S3")
with open(output_path, "rb") as f:
upload_response = requests.put(req.output_url, data=f, timeout=S3_TIMEOUT)
upload_response.raise_for_status()
logger.info(f"Upload complete: {file_size} bytes")
return {"size": file_size, "duration_ms": duration_ms}
finally:
# Cleanup all temp files
for path in track_paths:
if os.path.exists(path):
try:
os.unlink(path)
except Exception as e:
logger.warning(f"Failed to cleanup track file: {e}")
if output_path and os.path.exists(output_path):
try:
os.unlink(output_path)
except Exception as e:
logger.warning(f"Failed to cleanup output file: {e}")
try:
os.rmdir(temp_dir)
except Exception as e:
logger.warning(f"Failed to cleanup temp directory: {e}")
return web_app

View File

@@ -52,10 +52,12 @@ OPUS_DEFAULT_BIT_RATE = 128000
timeout=PADDING_TIMEOUT,
scaledown_window=SCALEDOWN_WINDOW,
image=image,
secrets=[modal.Secret.from_name("reflector-gpu")],
)
@modal.asgi_app()
def web():
from fastapi import FastAPI, Request, HTTPException
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel
class PaddingRequest(BaseModel):
@@ -70,7 +72,18 @@ def web():
web_app = FastAPI()
@web_app.post("/pad")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
if apikey == os.environ["REFLECTOR_GPU_APIKEY"]:
return
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
headers={"WWW-Authenticate": "Bearer"},
)
@web_app.post("/pad", dependencies=[Depends(apikey_auth)])
async def pad_track_endpoint(request: Request, req: PaddingRequest) -> PaddingResponse:
"""Modal web endpoint for padding audio tracks with disconnect detection.
"""

View File

@@ -3,6 +3,7 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI
from .routers.diarization import router as diarization_router
from .routers.mixdown import router as mixdown_router
from .routers.padding import router as padding_router
from .routers.transcription import router as transcription_router
from .routers.translation import router as translation_router
@@ -29,4 +30,5 @@ def create_app() -> FastAPI:
app.include_router(translation_router)
app.include_router(diarization_router)
app.include_router(padding_router)
app.include_router(mixdown_router)
return app

View File

@@ -0,0 +1,288 @@
"""
Audio mixdown endpoint for selfhosted GPU service.
CPU-intensive multi-track audio mixing service for combining N audio tracks
into a single MP3 using PyAV amix filter graph.
IMPORTANT: This mixdown logic is duplicated from server/reflector/utils/audio_mixdown.py
for deployment isolation (self_hosted can't import from server/reflector/). If you modify
the PyAV filter graph or mixdown algorithm, you MUST update both:
- gpu/self_hosted/app/routers/mixdown.py (this file)
- server/reflector/utils/audio_mixdown.py
Constants duplicated from server/reflector/utils/audio_constants.py for same reason.
"""
import logging
import os
import tempfile
from fractions import Fraction
import av
import requests
from av.audio.resampler import AudioResampler
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from ..auth import apikey_auth
logger = logging.getLogger(__name__)
router = APIRouter(tags=["mixdown"])
S3_TIMEOUT = 120
class MixdownRequest(BaseModel):
track_urls: list[str]
output_url: str
target_sample_rate: int | None = None
offsets_seconds: list[float] | None = None
class MixdownResponse(BaseModel):
size: int
duration_ms: float = 0.0
cancelled: bool = False
@router.post("/mixdown", dependencies=[Depends(apikey_auth)], response_model=MixdownResponse)
def mixdown_tracks(req: MixdownRequest):
"""Mix multiple audio tracks into single MP3 using PyAV amix filter graph."""
valid_urls = [u for u in req.track_urls if u]
if not valid_urls:
raise HTTPException(status_code=400, detail="No valid track URLs provided")
if req.offsets_seconds is not None:
if len(req.offsets_seconds) != len(req.track_urls):
raise HTTPException(
status_code=400,
detail=f"offsets_seconds length ({len(req.offsets_seconds)}) "
f"must match track_urls ({len(req.track_urls)})",
)
if any(o > 18000 for o in req.offsets_seconds):
raise HTTPException(
status_code=400, detail="offsets_seconds exceeds maximum 18000s (5 hours)"
)
if not req.output_url:
raise HTTPException(status_code=400, detail="output_url cannot be empty")
logger.info("Mixdown request: %d tracks", len(valid_urls))
temp_dir = tempfile.mkdtemp()
track_paths = []
output_path = None
try:
# --- Download all tracks ---
for i, url in enumerate(valid_urls):
logger.info("Downloading track %d", i)
response = requests.get(url, stream=True, timeout=S3_TIMEOUT)
response.raise_for_status()
track_path = os.path.join(temp_dir, f"track_{i}.webm")
total_bytes = 0
with open(track_path, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
total_bytes += len(chunk)
track_paths.append(track_path)
logger.info("Track %d downloaded: %d bytes", i, total_bytes)
if not track_paths:
raise HTTPException(status_code=400, detail="No tracks could be downloaded")
# --- Detect sample rate ---
target_sample_rate = req.target_sample_rate
if target_sample_rate is None:
for path in track_paths:
try:
container = av.open(path)
for frame in container.decode(audio=0):
target_sample_rate = frame.sample_rate
container.close()
break
else:
container.close()
continue
break
except Exception:
continue
if target_sample_rate is None:
raise HTTPException(
status_code=400, detail="Could not detect sample rate from any track"
)
logger.info("Target sample rate: %d", target_sample_rate)
# --- Calculate per-input delays ---
input_offsets_seconds = None
if req.offsets_seconds is not None:
input_offsets_seconds = [
req.offsets_seconds[i] for i, url in enumerate(req.track_urls) if url
]
delays_ms = []
if input_offsets_seconds is not None:
base = min(input_offsets_seconds) if input_offsets_seconds else 0.0
delays_ms = [max(0, int(round((o - base) * 1000))) for o in input_offsets_seconds]
else:
delays_ms = [0 for _ in track_paths]
# --- Build filter graph ---
# N abuffer -> optional adelay -> amix -> aformat -> abuffersink
graph = av.filter.Graph()
inputs = []
for idx in range(len(track_paths)):
args = (
f"time_base=1/{target_sample_rate}:"
f"sample_rate={target_sample_rate}:"
f"sample_fmt=s32:"
f"channel_layout=stereo"
)
in_ctx = graph.add("abuffer", args=args, name=f"in{idx}")
inputs.append(in_ctx)
mixer = graph.add("amix", args=f"inputs={len(inputs)}:normalize=0", name="mix")
fmt = graph.add(
"aformat",
args=f"sample_fmts=s32:channel_layouts=stereo:sample_rates={target_sample_rate}",
name="fmt",
)
sink = graph.add("abuffersink", name="out")
for idx, in_ctx in enumerate(inputs):
delay_ms = delays_ms[idx] if idx < len(delays_ms) else 0
if delay_ms > 0:
adelay = graph.add(
"adelay",
args=f"delays={delay_ms}|{delay_ms}:all=1",
name=f"delay{idx}",
)
in_ctx.link_to(adelay)
adelay.link_to(mixer, 0, idx)
else:
in_ctx.link_to(mixer, 0, idx)
mixer.link_to(fmt)
fmt.link_to(sink)
graph.configure()
# --- Open all containers and decode ---
containers = []
output_path = os.path.join(temp_dir, "mixed.mp3")
try:
for path in track_paths:
containers.append(av.open(path))
decoders = [c.decode(audio=0) for c in containers]
active = [True] * len(decoders)
resamplers = [
AudioResampler(format="s32", layout="stereo", rate=target_sample_rate)
for _ in decoders
]
# Open output MP3
out_container = av.open(output_path, "w", format="mp3")
out_stream = out_container.add_stream("libmp3lame", rate=target_sample_rate)
total_duration = 0
while any(active):
for i, (dec, is_active) in enumerate(zip(decoders, active)):
if not is_active:
continue
try:
frame = next(dec)
except StopIteration:
active[i] = False
inputs[i].push(None)
continue
if frame.sample_rate != target_sample_rate:
continue
out_frames = resamplers[i].resample(frame) or []
for rf in out_frames:
rf.sample_rate = target_sample_rate
rf.time_base = Fraction(1, target_sample_rate)
inputs[i].push(rf)
while True:
try:
mixed = sink.pull()
except Exception:
break
mixed.sample_rate = target_sample_rate
mixed.time_base = Fraction(1, target_sample_rate)
for packet in out_stream.encode(mixed):
out_container.mux(packet)
total_duration += packet.duration
# Flush filter graph
while True:
try:
mixed = sink.pull()
except Exception:
break
mixed.sample_rate = target_sample_rate
mixed.time_base = Fraction(1, target_sample_rate)
for packet in out_stream.encode(mixed):
out_container.mux(packet)
total_duration += packet.duration
# Flush encoder
for packet in out_stream.encode(None):
out_container.mux(packet)
total_duration += packet.duration
# Calculate duration in ms
last_tb = out_stream.time_base
duration_ms = 0.0
if last_tb and total_duration > 0:
duration_ms = round(float(total_duration * last_tb * 1000), 2)
out_container.close()
finally:
for c in containers:
try:
c.close()
except Exception:
pass
file_size = os.path.getsize(output_path)
logger.info("Mixdown complete: %d bytes, %.2fms", file_size, duration_ms)
# --- Upload result ---
logger.info("Uploading mixed audio to S3")
with open(output_path, "rb") as f:
upload_response = requests.put(req.output_url, data=f, timeout=S3_TIMEOUT)
upload_response.raise_for_status()
logger.info("Upload complete: %d bytes", file_size)
return MixdownResponse(size=file_size, duration_ms=duration_ms)
except HTTPException:
raise
except Exception as e:
logger.error("Mixdown failed: %s", e, exc_info=True)
raise HTTPException(status_code=500, detail=f"Mixdown failed: {e}") from e
finally:
for path in track_paths:
if os.path.exists(path):
try:
os.unlink(path)
except Exception as e:
logger.warning("Failed to cleanup track file: %s", e)
if output_path and os.path.exists(output_path):
try:
os.unlink(output_path)
except Exception as e:
logger.warning("Failed to cleanup output file: %s", e)
try:
os.rmdir(temp_dir)
except Exception as e:
logger.warning("Failed to cleanup temp directory: %s", e)