# Mirror of https://github.com/Monadical-SAS/reflector.git
# Synced 2026-02-04 09:56:47 +00:00
# Size reported by the mirror: 380 lines, 13 KiB (Python)
"""
Reflector GPU backend - audio mixdown
======================================

CPU-intensive audio mixdown service for combining multiple audio tracks.
Uses PyAV filter graph (amix) for high-quality audio mixing.
"""
|
|
|
|
import os
|
|
import tempfile
|
|
import time
|
|
from fractions import Fraction
|
|
|
|
import modal
|
|
|
|
# Per-request time budget for the mixdown function, in seconds (15 minutes).
MIXDOWN_TIMEOUT = 15 * 60
# Idle time, in seconds, before the container shuts down (1 minute).
SCALEDOWN_WINDOW = 60

app = modal.App("reflector-mixdown")

# Pinned Python dependencies installed into the container image.
_PIP_PACKAGES = (
    "av==13.1.0",  # PyAV for audio processing
    "requests==2.32.3",  # HTTP for presigned URL downloads/uploads
    "fastapi==0.115.12",  # API framework
)

# CPU-based image (no GPU needed for audio processing); ffmpeg is required
# by PyAV.
image = (
    modal.Image.debian_slim(python_version="3.12")
    .apt_install("ffmpeg")
    .pip_install(*_PIP_PACKAGES)
)
|
|
|
|
|
|
@app.function(
    cpu=4.0,  # 4 CPU cores for audio processing
    timeout=MIXDOWN_TIMEOUT,
    scaledown_window=SCALEDOWN_WINDOW,
    secrets=[modal.Secret.from_name("reflector-gpu")],
    image=image,
)
@modal.concurrent(max_inputs=10)
@modal.asgi_app()
def web():
    """Build and return the FastAPI application served by this Modal function.

    The app exposes a single authenticated endpoint,
    ``POST /v1/audio/mixdown``, which downloads the requested tracks, mixes
    them into one MP3 with a PyAV filter graph and uploads the result to the
    given output URL.

    Raises:
        RuntimeError: If ``REFLECTOR_GPU_APIKEY`` is not set in the environment.
    """
    # Third-party imports are local so they are only resolved where this
    # function actually runs.
    import logging
    import secrets
    import shutil

    import av
    import requests
    from av.audio.resampler import AudioResampler
    from fastapi import Depends, FastAPI, HTTPException, status
    from fastapi.security import OAuth2PasswordBearer
    from pydantic import BaseModel

    # Setup logging
    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
    )
    logger = logging.getLogger(__name__)

    # Named differently from the module-level Modal ``app`` to avoid shadowing.
    fastapi_app = FastAPI()
    oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

    # Validate API key exists at startup
    API_KEY = os.environ.get("REFLECTOR_GPU_APIKEY")
    if not API_KEY:
        raise RuntimeError("REFLECTOR_GPU_APIKEY not configured in Modal secrets")
    # compare_digest() only accepts ASCII-only str; comparing bytes keeps a
    # non-ASCII bearer token from raising TypeError (which surfaced as a 500).
    api_key_bytes = API_KEY.encode("utf-8")

    def apikey_auth(apikey: str = Depends(oauth2_scheme)):
        """Reject the request with 401 unless the bearer token is the API key."""
        # Use constant-time comparison to prevent timing attacks
        if secrets.compare_digest(apikey.encode("utf-8"), api_key_bytes):
            return
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
            headers={"WWW-Authenticate": "Bearer"},
        )

    class MixdownRequest(BaseModel):
        """Request body for the mixdown endpoint."""

        # URLs the input tracks are downloaded from (HTTP GET).
        track_urls: list[str]
        # URL the mixed MP3 is uploaded to (HTTP PUT).
        output_url: str
        # Sample rate (Hz) used for the filter graph and the encoded output.
        target_sample_rate: int = 48000
        # Fallback total duration, used only for progress reporting when no
        # container reports a duration.
        expected_duration_sec: float | None = None

    class MixdownResponse(BaseModel):
        """Response body for the mixdown endpoint."""

        duration_ms: float
        tracks_mixed: int
        audio_uploaded: bool

    def download_track(url: str, temp_dir: str, index: int) -> str:
        """Download track from presigned URL to temp file using streaming.

        Args:
            url: URL to fetch the track from.
            temp_dir: Directory the track file is written into.
            index: Zero-based track index, used for the file name and messages.

        Returns:
            Path of the downloaded file.

        Raises:
            HTTPException: 404 or 403 when the download answers with that status.
            requests.HTTPError: For any other unsuccessful status.
        """
        logger.info(f"Downloading track {index + 1}")

        # The context manager releases the streamed connection on every path,
        # including the early HTTPException exits below.
        with requests.get(url, stream=True, timeout=300) as response:
            if response.status_code == 404:
                raise HTTPException(status_code=404, detail=f"Track {index} not found")
            if response.status_code == 403:
                raise HTTPException(
                    status_code=403, detail=f"Track {index} presigned URL expired"
                )
            response.raise_for_status()

            temp_path = os.path.join(temp_dir, f"track_{index}.webm")
            total_bytes = 0
            with open(temp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
                        total_bytes += len(chunk)

        logger.info(f"Track {index + 1} downloaded: {total_bytes} bytes")
        return temp_path

    def _build_mix_graph(num_inputs: int, sample_rate: int):
        """Build the graph: N abuffer -> amix -> aformat -> abuffersink.

        Returns:
            ``(graph, inputs, sink)``. The graph is returned so the caller can
            keep it referenced for as long as the contexts are in use.
        """
        graph = av.filter.Graph()
        abuffer_args = (
            f"time_base=1/{sample_rate}:"
            f"sample_rate={sample_rate}:"
            "sample_fmt=s32:"
            "channel_layout=stereo"
        )
        inputs = [
            graph.add("abuffer", args=abuffer_args, name=f"in{idx}")
            for idx in range(num_inputs)
        ]
        mixer = graph.add("amix", args=f"inputs={num_inputs}:normalize=0", name="mix")
        fmt = graph.add(
            "aformat",
            args=f"sample_fmts=s32:channel_layouts=stereo:sample_rates={sample_rate}",
            name="fmt",
        )
        sink = graph.add("abuffersink", name="out")

        # Connect inputs to mixer (no delays for Modal implementation)
        for idx, in_ctx in enumerate(inputs):
            in_ctx.link_to(mixer, 0, idx)
        mixer.link_to(fmt)
        fmt.link_to(sink)
        graph.configure()
        return graph, inputs, sink

    def _drain_sink(sink, out_stream, out_container, sample_rate: int) -> int:
        """Encode and mux every frame currently available from the sink.

        Returns:
            Sum of the muxed packets' durations.
        """
        muxed_duration = 0
        while True:
            try:
                mixed = sink.pull()
            except (BlockingIOError, EOFError):
                # No frame ready yet, or the graph is finished. Anything else
                # is a real failure and must propagate.
                break
            mixed.sample_rate = sample_rate
            mixed.time_base = Fraction(1, sample_rate)
            for packet in out_stream.encode(mixed):
                out_container.mux(packet)
                muxed_duration += packet.duration or 0
        return muxed_duration

    def mixdown_tracks_modal(
        track_paths: list[str],
        output_path: str,
        target_sample_rate: int,
        expected_duration_sec: float | None,
        logger,
    ) -> float:
        """Mix multiple audio tracks using PyAV filter graph.

        Args:
            track_paths: List of local file paths to audio tracks
            output_path: Local path for output MP3 file
            target_sample_rate: Sample rate for output (Hz)
            expected_duration_sec: Optional fallback duration if container metadata unavailable
            logger: Logger instance for progress tracking

        Returns:
            Duration in milliseconds

        Raises:
            ValueError: If none of the tracks could be opened.
        """
        logger.info(f"Starting mixdown of {len(track_paths)} tracks")

        containers = []
        out_container = None
        try:
            # Open all containers; tracks that fail to open are skipped.
            for i, path in enumerate(track_paths):
                try:
                    containers.append(av.open(path))
                except Exception as e:
                    logger.warning(
                        f"Failed to open container {i}: {e}",
                    )

            if not containers:
                raise ValueError("Could not open any track containers")

            # Size the graph from the containers that actually opened, so every
            # amix input has a decoder feeding it and eventually receives EOF.
            # ``graph`` is held in a local to keep it referenced during mixing.
            graph, inputs, sink = _build_mix_graph(len(containers), target_sample_rate)

            # Calculate total duration for progress reporting
            max_duration_sec = 0.0
            for c in containers:
                if c.duration is not None:
                    max_duration_sec = max(max_duration_sec, c.duration / av.time_base)
            if max_duration_sec == 0.0 and expected_duration_sec:
                max_duration_sec = expected_duration_sec

            # Setup output container
            out_container = av.open(output_path, "w", format="mp3")
            out_stream = out_container.add_stream("libmp3lame", rate=target_sample_rate)

            decoders = [c.decode(audio=0) for c in containers]
            active = [True] * len(decoders)
            resamplers = [
                AudioResampler(format="s32", layout="stereo", rate=target_sample_rate)
                for _ in decoders
            ]

            current_max_time = 0.0
            last_log_time = time.monotonic()
            start_time = time.monotonic()
            total_duration = 0

            # Round-robin: take one frame from each still-active decoder.
            while any(active):
                for i, dec in enumerate(decoders):
                    if not active[i]:
                        continue
                    try:
                        frame = next(dec)
                    except StopIteration:
                        active[i] = False
                        inputs[i].push(None)  # Signal end of stream
                        continue

                    # NOTE(review): frames at a different sample rate are
                    # dropped rather than resampled; preserved from the
                    # original -- confirm this is intended.
                    if frame.sample_rate != target_sample_rate:
                        continue

                    # Progress logging (every 5 seconds)
                    if frame.time is not None:
                        current_max_time = max(current_max_time, frame.time)
                        now = time.monotonic()
                        if now - last_log_time >= 5.0:
                            elapsed = now - start_time
                            if max_duration_sec > 0:
                                progress_pct = min(
                                    100.0, (current_max_time / max_duration_sec) * 100
                                )
                                logger.info(
                                    f"Mixdown progress: {progress_pct:.1f}% @ {current_max_time:.1f}s (elapsed: {elapsed:.1f}s)"
                                )
                            else:
                                logger.info(
                                    f"Mixdown progress: @ {current_max_time:.1f}s (elapsed: {elapsed:.1f}s)"
                                )
                            last_log_time = now

                    for rf in resamplers[i].resample(frame) or []:
                        rf.sample_rate = target_sample_rate
                        rf.time_base = Fraction(1, target_sample_rate)
                        inputs[i].push(rf)

                    # Pull mixed frames from sink and encode
                    total_duration += _drain_sink(
                        sink, out_stream, out_container, target_sample_rate
                    )

            # Flush remaining frames from filter graph
            total_duration += _drain_sink(
                sink, out_stream, out_container, target_sample_rate
            )

            # Flush encoder
            for packet in out_stream.encode():
                out_container.mux(packet)
                total_duration += packet.duration or 0

            # Calculate duration in milliseconds
            if total_duration > 0:
                duration_ms = round(
                    float(total_duration * out_stream.time_base * 1000), 2
                )
            else:
                duration_ms = 0.0

            out_container.close()
            out_container = None  # already closed; skip the cleanup below
            logger.info(f"Mixdown complete: duration={duration_ms}ms")
            return duration_ms

        finally:
            # On failure the output container is still open; release it.
            if out_container is not None:
                try:
                    out_container.close()
                except Exception as e:
                    logger.warning(f"Failed to close output container: {e}")

            # Cleanup all containers (best effort).
            for c in containers:
                try:
                    c.close()
                except Exception as e:
                    logger.warning(f"Failed to close input container: {e}")

    @fastapi_app.post("/v1/audio/mixdown", dependencies=[Depends(apikey_auth)])
    def mixdown(request: MixdownRequest) -> MixdownResponse:
        """Mix multiple audio tracks into a single MP3 file.

        Tracks are downloaded from presigned S3 URLs, mixed using PyAV,
        and uploaded to a presigned S3 PUT URL.

        Raises:
            HTTPException: 400 when no track URLs are given, 403/404 for
                download or upload failures reported with those statuses,
                500 for any other failure.
        """
        if not request.track_urls:
            raise HTTPException(status_code=400, detail="No track URLs provided")

        logger.info(f"Mixdown request: {len(request.track_urls)} tracks")

        # Every intermediate file lives in this directory, so one rmtree in
        # ``finally`` cleans up the tracks and the mixed output.
        temp_dir = tempfile.mkdtemp()

        try:
            # Download all tracks
            temp_files = [
                download_track(url, temp_dir, i)
                for i, url in enumerate(request.track_urls)
            ]

            # Mix tracks
            output_mp3_path = os.path.join(temp_dir, "mixed.mp3")
            duration_ms = mixdown_tracks_modal(
                temp_files,
                output_mp3_path,
                request.target_sample_rate,
                request.expected_duration_sec,
                logger,
            )

            # Upload result to S3
            logger.info("Uploading result to S3")
            file_size = os.path.getsize(output_mp3_path)
            with open(output_mp3_path, "rb") as f:
                upload_response = requests.put(
                    request.output_url, data=f, timeout=300
                )

            if upload_response.status_code == 403:
                raise HTTPException(
                    status_code=403, detail="Output presigned URL expired"
                )

            upload_response.raise_for_status()
            logger.info(f"Upload complete: {file_size} bytes")

            return MixdownResponse(
                duration_ms=duration_ms,
                tracks_mixed=len(request.track_urls),
                audio_uploaded=True,
            )

        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Mixdown failed: {e}", exc_info=True)
            raise HTTPException(
                status_code=500, detail=f"Mixdown failed: {str(e)}"
            ) from e

        finally:
            try:
                shutil.rmtree(temp_dir)
            except Exception as e:
                logger.warning(f"Failed to cleanup temp directory {temp_dir}: {e}")

    return fastapi_app
|