Files
reflector/gpu/self_hosted/app/routers/mixdown.py
Juan Diego García d164e486cc feat: mixdown modal services + processor pattern (#936)
* allow memory flags and per service config

* feat: mixdown modal services + processor pattern
2026-03-30 17:38:23 -05:00

289 lines
10 KiB
Python

"""
Audio mixdown endpoint for the self-hosted GPU service.
CPU-intensive multi-track audio mixing service for combining N audio tracks
into a single MP3 using PyAV amix filter graph.
IMPORTANT: This mixdown logic is duplicated from server/reflector/utils/audio_mixdown.py
for deployment isolation (self_hosted can't import from server/reflector/). If you modify
the PyAV filter graph or mixdown algorithm, you MUST update both:
- gpu/self_hosted/app/routers/mixdown.py (this file)
- server/reflector/utils/audio_mixdown.py
Constants are duplicated from server/reflector/utils/audio_constants.py for the same reason.
"""
import logging
import os
import tempfile
from fractions import Fraction
import av
import requests
from av.audio.resampler import AudioResampler
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from ..auth import apikey_auth
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router on which the /mixdown endpoint in this file is registered (tagged "mixdown").
router = APIRouter(tags=["mixdown"])
# Passed as `timeout=` to every `requests` call in this module (track downloads
# and the result upload); `requests` interprets the value in seconds.
S3_TIMEOUT = 120
class MixdownRequest(BaseModel):
    """Request body accepted by ``POST /mixdown``."""

    # URLs fetched with HTTP GET, one per input track. Falsy entries (e.g. "")
    # are skipped by the endpoint, but still count towards the length that
    # `offsets_seconds` is validated against.
    track_urls: list[str]
    # URL the mixed MP3 is uploaded to with HTTP PUT; must be non-empty.
    output_url: str
    # Sample rate used for the filter graph and the MP3 encoder. When omitted,
    # the endpoint takes it from the first decodable audio frame it finds.
    target_sample_rate: int | None = None
    # Optional start offset per entry of `track_urls` (same length). Each value
    # may be at most 18000; offsets are turned into delays relative to the
    # smallest offset among the tracks that are actually mixed.
    offsets_seconds: list[float] | None = None
class MixdownResponse(BaseModel):
    """Response body returned by ``POST /mixdown``."""

    # Size of the mixed MP3 file on disk, in bytes.
    size: int
    # Duration derived from the summed packet durations of the encoded
    # output, in milliseconds; stays 0.0 when no duration could be computed.
    duration_ms: float = 0.0
    # Not set by the endpoint in this module; defaults to False.
    cancelled: bool = False
def _download_track(url: str, dest_path: str) -> int:
    """Stream ``url`` into ``dest_path`` and return the number of bytes written."""
    total_bytes = 0
    # The context manager releases the streamed connection even when the
    # status check or the transfer fails part-way through.
    with requests.get(url, stream=True, timeout=S3_TIMEOUT) as response:
        response.raise_for_status()
        with open(dest_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)
                    total_bytes += len(chunk)
    return total_bytes


def _detect_sample_rate(track_paths: list[str]) -> int | None:
    """Return the sample rate of the first decodable audio frame, or None."""
    for path in track_paths:
        try:
            # `with` closes the probe container even if decoding raises.
            with av.open(path) as container:
                for frame in container.decode(audio=0):
                    return frame.sample_rate
        except Exception as e:
            # Best effort: an unreadable track must not prevent probing the others.
            logger.warning("Could not read sample rate from track: %s", e)
    return None


def _compute_delays_ms(
    track_urls: list[str],
    offsets_seconds: list[float] | None,
    track_count: int,
) -> list[int]:
    """Return one non-negative delay (ms) per mixed track, relative to the earliest offset."""
    if offsets_seconds is None:
        return [0] * track_count
    # Keep only the offsets whose URL is actually mixed (falsy URLs are skipped).
    kept_offsets = [
        offset for url, offset in zip(track_urls, offsets_seconds) if url
    ]
    base = min(kept_offsets) if kept_offsets else 0.0
    return [max(0, int(round((o - base) * 1000))) for o in kept_offsets]


def _build_filter_graph(track_count: int, sample_rate: int, delays_ms: list[int]):
    """Build and configure: N abuffer -> optional adelay -> amix -> aformat -> abuffersink.

    Returns ``(graph, inputs, sink)``. The graph is returned as well so the
    caller keeps it referenced for as long as frames are pushed and pulled.
    """
    graph = av.filter.Graph()
    buffer_args = (
        f"time_base=1/{sample_rate}:"
        f"sample_rate={sample_rate}:"
        f"sample_fmt=s32:"
        f"channel_layout=stereo"
    )
    inputs = [
        graph.add("abuffer", args=buffer_args, name=f"in{idx}")
        for idx in range(track_count)
    ]
    mixer = graph.add("amix", args=f"inputs={len(inputs)}:normalize=0", name="mix")
    fmt = graph.add(
        "aformat",
        args=f"sample_fmts=s32:channel_layouts=stereo:sample_rates={sample_rate}",
        name="fmt",
    )
    sink = graph.add("abuffersink", name="out")
    for idx, in_ctx in enumerate(inputs):
        delay_ms = delays_ms[idx] if idx < len(delays_ms) else 0
        if delay_ms > 0:
            # One delay value per stereo channel.
            adelay = graph.add(
                "adelay",
                args=f"delays={delay_ms}|{delay_ms}:all=1",
                name=f"delay{idx}",
            )
            in_ctx.link_to(adelay)
            adelay.link_to(mixer, 0, idx)
        else:
            in_ctx.link_to(mixer, 0, idx)
    mixer.link_to(fmt)
    fmt.link_to(sink)
    graph.configure()
    return graph, inputs, sink


def _drain_sink(sink, out_container, out_stream, sample_rate: int) -> int:
    """Encode and mux every frame currently available from ``sink``.

    Returns the summed ``duration`` of the muxed packets.
    """
    drained_duration = 0
    while True:
        try:
            mixed = sink.pull()
        except Exception:
            # Any failure to pull is treated as "nothing more available for now".
            break
        mixed.sample_rate = sample_rate
        mixed.time_base = Fraction(1, sample_rate)
        for packet in out_stream.encode(mixed):
            out_container.mux(packet)
            drained_duration += packet.duration
    return drained_duration


def _mix_tracks_to_mp3(
    track_paths: list[str],
    output_path: str,
    inputs,
    sink,
    sample_rate: int,
) -> float:
    """Decode all tracks through the filter graph into an MP3 at ``output_path``.

    Returns the output duration in milliseconds (0.0 if it cannot be computed).
    """
    containers = []
    try:
        for path in track_paths:
            containers.append(av.open(path))
        decoders = [c.decode(audio=0) for c in containers]
        active = [True] * len(decoders)
        resamplers = [
            AudioResampler(format="s32", layout="stereo", rate=sample_rate)
            for _ in decoders
        ]
        # `with` closes the output container on failure too, not only on success.
        with av.open(output_path, "w", format="mp3") as out_container:
            out_stream = out_container.add_stream("libmp3lame", rate=sample_rate)
            total_duration = 0
            # Round-robin: feed one frame per still-active track, then drain.
            while any(active):
                for i, dec in enumerate(decoders):
                    if not active[i]:
                        continue
                    try:
                        frame = next(dec)
                    except StopIteration:
                        active[i] = False
                        # Signal end-of-stream on this graph input.
                        inputs[i].push(None)
                        continue
                    # NOTE(review): frames whose rate differs from the target are
                    # dropped rather than resampled; confirm this is intended
                    # and matches the server-side implementation.
                    if frame.sample_rate != sample_rate:
                        continue
                    for rf in resamplers[i].resample(frame) or []:
                        rf.sample_rate = sample_rate
                        rf.time_base = Fraction(1, sample_rate)
                        inputs[i].push(rf)
                    total_duration += _drain_sink(
                        sink, out_container, out_stream, sample_rate
                    )
            # Flush filter graph
            total_duration += _drain_sink(sink, out_container, out_stream, sample_rate)
            # Flush encoder
            for packet in out_stream.encode(None):
                out_container.mux(packet)
                total_duration += packet.duration
            # Packet durations are in stream time_base units; convert to ms.
            time_base = out_stream.time_base
            duration_ms = 0.0
            if time_base and total_duration > 0:
                duration_ms = round(float(total_duration * time_base * 1000), 2)
        return duration_ms
    finally:
        for c in containers:
            try:
                c.close()
            except Exception as e:
                logger.warning("Failed to close input container: %s", e)


@router.post("/mixdown", dependencies=[Depends(apikey_auth)], response_model=MixdownResponse)
def mixdown_tracks(req: MixdownRequest):
    """Mix multiple audio tracks into single MP3 using PyAV amix filter graph.

    Downloads every non-empty URL in ``req.track_urls`` into a temporary
    directory, mixes the tracks (optionally delayed according to
    ``req.offsets_seconds``) and uploads the resulting MP3 to
    ``req.output_url`` with an HTTP PUT.

    Args:
        req: Mixdown parameters; see ``MixdownRequest``.

    Returns:
        MixdownResponse with the output size in bytes and duration in ms.

    Raises:
        HTTPException: 400 for invalid input (no usable URLs, mismatched or
            too-large offsets, empty output URL, undetectable sample rate);
            500 for any other failure during download, mixing or upload.
    """
    valid_urls = [u for u in req.track_urls if u]
    if not valid_urls:
        raise HTTPException(status_code=400, detail="No valid track URLs provided")
    if req.offsets_seconds is not None:
        if len(req.offsets_seconds) != len(req.track_urls):
            raise HTTPException(
                status_code=400,
                detail=f"offsets_seconds length ({len(req.offsets_seconds)}) "
                f"must match track_urls ({len(req.track_urls)})",
            )
        if any(o > 18000 for o in req.offsets_seconds):
            raise HTTPException(
                status_code=400, detail="offsets_seconds exceeds maximum 18000s (5 hours)"
            )
    if not req.output_url:
        raise HTTPException(status_code=400, detail="output_url cannot be empty")
    logger.info("Mixdown request: %d tracks", len(valid_urls))
    temp_dir = tempfile.mkdtemp()
    track_paths = []
    output_path = None
    try:
        # --- Download all tracks ---
        for i, url in enumerate(valid_urls):
            logger.info("Downloading track %d", i)
            track_path = os.path.join(temp_dir, f"track_{i}.webm")
            # Register the path before writing so a partially downloaded file
            # is still removed by the cleanup in `finally`.
            track_paths.append(track_path)
            total_bytes = _download_track(url, track_path)
            logger.info("Track %d downloaded: %d bytes", i, total_bytes)
        # --- Detect sample rate ---
        target_sample_rate = req.target_sample_rate
        if target_sample_rate is None:
            target_sample_rate = _detect_sample_rate(track_paths)
        if target_sample_rate is None:
            raise HTTPException(
                status_code=400, detail="Could not detect sample rate from any track"
            )
        logger.info("Target sample rate: %d", target_sample_rate)
        # --- Calculate per-input delays ---
        delays_ms = _compute_delays_ms(
            req.track_urls, req.offsets_seconds, len(track_paths)
        )
        # --- Build filter graph ---
        # `graph` is held here only to keep it referenced while mixing runs.
        graph, inputs, sink = _build_filter_graph(
            len(track_paths), target_sample_rate, delays_ms
        )
        # --- Open all containers and decode ---
        output_path = os.path.join(temp_dir, "mixed.mp3")
        duration_ms = _mix_tracks_to_mp3(
            track_paths, output_path, inputs, sink, target_sample_rate
        )
        file_size = os.path.getsize(output_path)
        logger.info("Mixdown complete: %d bytes, %.2fms", file_size, duration_ms)
        # --- Upload result ---
        logger.info("Uploading mixed audio to S3")
        with open(output_path, "rb") as f:
            upload_response = requests.put(req.output_url, data=f, timeout=S3_TIMEOUT)
        upload_response.raise_for_status()
        logger.info("Upload complete: %d bytes", file_size)
        return MixdownResponse(size=file_size, duration_ms=duration_ms)
    except HTTPException:
        # Already carries the intended status code; pass through unchanged.
        raise
    except Exception as e:
        logger.error("Mixdown failed: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=f"Mixdown failed: {e}") from e
    finally:
        # Best-effort cleanup: failures are logged, never raised.
        for path in track_paths:
            if os.path.exists(path):
                try:
                    os.unlink(path)
                except Exception as e:
                    logger.warning("Failed to cleanup track file: %s", e)
        if output_path and os.path.exists(output_path):
            try:
                os.unlink(output_path)
            except Exception as e:
                logger.warning("Failed to cleanup output file: %s", e)
        try:
            os.rmdir(temp_dir)
        except Exception as e:
            logger.warning("Failed to cleanup temp directory: %s", e)