mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-04-04 12:56:49 +00:00
feat: mixdown modal services + processor pattern (#936)
* allow memory flags and per service config * feat: mixdown modal services + processor pattern
This commit is contained in:
committed by
GitHub
parent
12bf0c2d77
commit
d164e486cc
@@ -132,13 +132,22 @@ fi
|
||||
echo " -> $DIARIZER_URL"
|
||||
|
||||
echo ""
|
||||
echo "Deploying padding (CPU audio processing via Modal SDK)..."
|
||||
modal deploy reflector_padding.py
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Deploying padding (CPU audio processing)..."
|
||||
PADDING_URL=$(modal deploy reflector_padding.py 2>&1 | grep -o 'https://[^ ]*web.modal.run' | head -1)
|
||||
if [ -z "$PADDING_URL" ]; then
|
||||
echo "Error: Failed to deploy padding. Check Modal dashboard for details."
|
||||
exit 1
|
||||
fi
|
||||
echo " -> reflector-padding.pad_track (Modal SDK function)"
|
||||
echo " -> $PADDING_URL"
|
||||
|
||||
echo ""
|
||||
echo "Deploying mixdown (CPU multi-track audio mixing)..."
|
||||
MIXDOWN_URL=$(modal deploy reflector_mixdown.py 2>&1 | grep -o 'https://[^ ]*web.modal.run' | head -1)
|
||||
if [ -z "$MIXDOWN_URL" ]; then
|
||||
echo "Error: Failed to deploy mixdown. Check Modal dashboard for details."
|
||||
exit 1
|
||||
fi
|
||||
echo " -> $MIXDOWN_URL"
|
||||
|
||||
# --- Output Configuration ---
|
||||
echo ""
|
||||
@@ -157,5 +166,11 @@ echo "DIARIZATION_BACKEND=modal"
|
||||
echo "DIARIZATION_URL=$DIARIZER_URL"
|
||||
echo "DIARIZATION_MODAL_API_KEY=$API_KEY"
|
||||
echo ""
|
||||
echo "# Padding uses Modal SDK (requires MODAL_TOKEN_ID/SECRET in worker containers)"
|
||||
echo "PADDING_BACKEND=modal"
|
||||
echo "PADDING_URL=$PADDING_URL"
|
||||
echo "PADDING_MODAL_API_KEY=$API_KEY"
|
||||
echo ""
|
||||
echo "MIXDOWN_BACKEND=modal"
|
||||
echo "MIXDOWN_URL=$MIXDOWN_URL"
|
||||
echo "MIXDOWN_MODAL_API_KEY=$API_KEY"
|
||||
echo "# --- End Modal Configuration ---"
|
||||
|
||||
385
gpu/modal_deployments/reflector_mixdown.py
Normal file
385
gpu/modal_deployments/reflector_mixdown.py
Normal file
@@ -0,0 +1,385 @@
|
||||
"""
|
||||
Reflector GPU backend - audio mixdown
|
||||
=====================================
|
||||
|
||||
CPU-intensive multi-track audio mixdown service.
|
||||
Mixes N audio tracks into a single MP3 using PyAV amix filter graph.
|
||||
|
||||
IMPORTANT: This mixdown logic is duplicated from server/reflector/utils/audio_mixdown.py
|
||||
for Modal deployment isolation (Modal can't import from server/reflector/). If you modify
|
||||
the PyAV filter graph or mixdown algorithm, you MUST update both:
|
||||
- gpu/modal_deployments/reflector_mixdown.py (this file)
|
||||
- server/reflector/utils/audio_mixdown.py
|
||||
|
||||
Constants duplicated from server/reflector/utils/audio_constants.py for same reason.
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from fractions import Fraction
|
||||
import asyncio
|
||||
|
||||
import modal
|
||||
|
||||
# Timeout in seconds for each presigned-URL download and for the final upload.
# Higher than padding (60s) because mixdown downloads multiple tracks.
S3_TIMEOUT = 120
# Timeout of the Modal function: 1200s plus twice the S3 timeout (1440s total).
MIXDOWN_TIMEOUT = 1200 + 2 * S3_TIMEOUT
# Passed to Modal as ``scaledown_window`` for the web function.
SCALEDOWN_WINDOW = 60
# Seconds between client-disconnect polls and cancellation-flag checks.
DISCONNECT_CHECK_INTERVAL = 2

app = modal.App("reflector-mixdown")

# CPU-based image (mixdown is CPU-bound, no GPU needed).
_base_image = modal.Image.debian_slim(python_version="3.12")
_image_with_ffmpeg = _base_image.apt_install("ffmpeg")  # Required by PyAV
image = _image_with_ffmpeg.pip_install(
    "av==13.1.0",  # PyAV for audio processing
    "requests==2.32.3",  # HTTP for presigned URL downloads/uploads
    "fastapi==0.115.12",  # API framework
)
|
||||
|
||||
|
||||
@app.function(
|
||||
cpu=4.0, # Higher than padding (2.0) for multi-track mixing
|
||||
timeout=MIXDOWN_TIMEOUT,
|
||||
scaledown_window=SCALEDOWN_WINDOW,
|
||||
image=image,
|
||||
secrets=[modal.Secret.from_name("reflector-gpu")],
|
||||
)
|
||||
@modal.asgi_app()
|
||||
def web():
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from pydantic import BaseModel
|
||||
|
||||
class MixdownRequest(BaseModel):
    """Request body accepted by the ``/mixdown`` endpoint."""

    # Download URLs, one per track. Empty strings are filtered out before
    # downloading.
    track_urls: list[str]
    # URL the mixed MP3 is uploaded to with an HTTP PUT.
    output_url: str
    # Sample rate of the mix. When None it is read from the first frame that
    # can be decoded from any track.
    target_sample_rate: int | None = None
    # Per-track offsets in seconds, index-aligned with ``track_urls``.
    # The endpoint rejects values above 18000.
    offsets_seconds: list[float] | None = None
|
||||
|
||||
class MixdownResponse(BaseModel):
    """Result returned by the ``/mixdown`` endpoint."""

    # Size in bytes of the uploaded MP3 (0 when the job was cancelled).
    size: int
    # Sum of the encoded packet durations, converted to milliseconds.
    duration_ms: float = 0.0
    # True when the job stopped early because the cancellation flag was set.
    cancelled: bool = False
|
||||
|
||||
web_app = FastAPI()
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
    """FastAPI dependency that rejects requests lacking the expected API key.

    Args:
        apikey: Bearer token extracted from the request by ``oauth2_scheme``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when the
            token does not equal the ``REFLECTOR_GPU_APIKEY`` environment
            variable.
        KeyError: If ``REFLECTOR_GPU_APIKEY`` is not set.
    """
    import hmac

    expected = os.environ["REFLECTOR_GPU_APIKEY"]
    # Constant-time comparison so response timing does not leak how much of a
    # guessed key was correct. Compare bytes because compare_digest() rejects
    # str arguments containing non-ASCII characters.
    if hmac.compare_digest(apikey.encode("utf-8"), expected.encode("utf-8")):
        return
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid API key",
        headers={"WWW-Authenticate": "Bearer"},
    )
|
||||
|
||||
@web_app.post("/mixdown", dependencies=[Depends(apikey_auth)])
async def mixdown_endpoint(request: Request, req: MixdownRequest) -> MixdownResponse:
    """Modal web endpoint for mixing audio tracks with disconnect detection.

    Validates the request, then runs the blocking mixdown in the default
    executor while a background task polls for client disconnect and sets a
    shared cancellation flag that the worker checks periodically.

    Args:
        request: Incoming request, used only to poll ``is_disconnected()``.
        req: Track URLs, output URL and optional sample rate / offsets.

    Returns:
        MixdownResponse built from the dict returned by the blocking worker.

    Raises:
        HTTPException: 400 when no non-empty track URL is given, when
            ``offsets_seconds`` has a different length than ``track_urls`` or
            contains a value above 18000, or when ``output_url`` is empty.
    """
    import logging
    import threading

    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
    logger = logging.getLogger(__name__)

    valid_urls = [u for u in req.track_urls if u]
    if not valid_urls:
        raise HTTPException(status_code=400, detail="No valid track URLs provided")

    offsets = req.offsets_seconds
    if offsets is not None:
        # Offsets are index-aligned with the unfiltered track_urls list.
        if len(offsets) != len(req.track_urls):
            raise HTTPException(
                status_code=400,
                detail=f"offsets_seconds length ({len(offsets)}) "
                f"must match track_urls ({len(req.track_urls)})",
            )
        if any(o > 18000 for o in offsets):
            raise HTTPException(status_code=400, detail="offsets_seconds exceeds maximum 18000s (5 hours)")

    if not req.output_url:
        raise HTTPException(status_code=400, detail="output_url cannot be empty")

    logger.info(f"Mixdown request: {len(valid_urls)} tracks")

    # threading.Event rather than asyncio.Event: the flag is read from the
    # executor thread that runs the blocking mixdown.
    cancelled = threading.Event()

    async def check_disconnect():
        """Background task to check for client disconnect."""
        while not cancelled.is_set():
            await asyncio.sleep(DISCONNECT_CHECK_INTERVAL)
            if await request.is_disconnected():
                logger.warning("Client disconnected, setting cancellation flag")
                cancelled.set()
                break

    disconnect_task = asyncio.create_task(check_disconnect())

    try:
        # get_running_loop() is the documented way to obtain the loop from
        # inside a coroutine; get_event_loop() is discouraged there.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None, _mixdown_tracks_blocking, req, cancelled, logger
        )
        return MixdownResponse(**result)
    finally:
        # Setting the flag also ends the watcher loop if cancel() races it.
        cancelled.set()
        disconnect_task.cancel()
        try:
            await disconnect_task
        except asyncio.CancelledError:
            pass
|
||||
|
||||
def _mixdown_tracks_blocking(req, cancelled, logger) -> dict:
    """Blocking CPU-bound mixdown work with periodic cancellation checks.

    Downloads all tracks, builds PyAV amix filter graph, encodes to MP3,
    and uploads the result to the presigned output URL.

    Args:
        req: Request object exposing ``track_urls``, ``output_url``,
            ``target_sample_rate`` and ``offsets_seconds``.
        cancelled: Event-like object; the work stops at the next check once
            ``is_set()`` returns True.
        logger: Logger used for progress and cleanup messages.

    Returns:
        ``{"size": <bytes>, "duration_ms": <float>}`` on success, or
        ``{"size": 0, "duration_ms": 0.0, "cancelled": True}`` when cancelled.

    Raises:
        ValueError: If no track was downloaded or no sample rate could be
            detected.
        requests.HTTPError: If a download or the upload returns an error
            status (via ``raise_for_status``).
    """
    import shutil
    import time

    import av
    import requests
    from av.audio.resampler import AudioResampler

    temp_dir = tempfile.mkdtemp()
    track_paths = []
    last_check = time.time()

    try:
        # --- Download all tracks ---
        valid_urls = [u for u in req.track_urls if u]
        for i, url in enumerate(valid_urls):
            if cancelled.is_set():
                logger.info("Cancelled during download phase")
                return {"size": 0, "duration_ms": 0.0, "cancelled": True}

            logger.info(f"Downloading track {i}")
            track_path = os.path.join(temp_dir, f"track_{i}.webm")
            total_bytes = 0
            chunk_count = 0
            # The context manager releases the streamed connection even when
            # we return early because of cancellation.
            with requests.get(url, stream=True, timeout=S3_TIMEOUT) as response:
                response.raise_for_status()
                with open(track_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if not chunk:
                            continue
                        f.write(chunk)
                        total_bytes += len(chunk)
                        chunk_count += 1

                        # Only look at the clock every 12 chunks to keep the
                        # download loop cheap.
                        if chunk_count % 12 == 0:
                            now = time.time()
                            if now - last_check >= DISCONNECT_CHECK_INTERVAL:
                                if cancelled.is_set():
                                    logger.info(f"Cancelled during track {i} download")
                                    return {"size": 0, "duration_ms": 0.0, "cancelled": True}
                                last_check = now

            track_paths.append(track_path)
            logger.info(f"Track {i} downloaded: {total_bytes} bytes")

        if not track_paths:
            raise ValueError("No tracks downloaded")

        # --- Detect sample rate ---
        target_sample_rate = req.target_sample_rate
        if target_sample_rate is None:
            for path in track_paths:
                try:
                    # "with" closes the probe container even if decoding fails.
                    with av.open(path) as container:
                        first_frame = next(container.decode(audio=0), None)
                except Exception:
                    # Unreadable track: try the next one.
                    continue
                if first_frame is not None:
                    target_sample_rate = first_frame.sample_rate
                    break
        if target_sample_rate is None:
            raise ValueError("Could not detect sample rate from any track")

        logger.info(f"Target sample rate: {target_sample_rate}")

        # --- Calculate per-input delays ---
        # Offsets are index-aligned with the unfiltered track_urls, so drop
        # the entries that belong to empty URLs.
        input_offsets_seconds = None
        if req.offsets_seconds is not None:
            input_offsets_seconds = [
                req.offsets_seconds[i] for i, url in enumerate(req.track_urls) if url
            ]

        if input_offsets_seconds is not None:
            # Delays are relative to the earliest track.
            base = min(input_offsets_seconds) if input_offsets_seconds else 0.0
            delays_ms = [max(0, int(round((o - base) * 1000))) for o in input_offsets_seconds]
        else:
            delays_ms = [0 for _ in track_paths]

        # --- Build filter graph ---
        # N abuffer -> optional adelay -> amix -> aformat -> abuffersink
        graph = av.filter.Graph()
        abuffer_args = (
            f"time_base=1/{target_sample_rate}:"
            f"sample_rate={target_sample_rate}:"
            f"sample_fmt=s32:"
            f"channel_layout=stereo"
        )
        inputs = [
            graph.add("abuffer", args=abuffer_args, name=f"in{idx}")
            for idx in range(len(track_paths))
        ]

        mixer = graph.add("amix", args=f"inputs={len(inputs)}:normalize=0", name="mix")
        fmt = graph.add(
            "aformat",
            args=f"sample_fmts=s32:channel_layouts=stereo:sample_rates={target_sample_rate}",
            name="fmt",
        )
        sink = graph.add("abuffersink", name="out")

        for idx, in_ctx in enumerate(inputs):
            delay_ms = delays_ms[idx] if idx < len(delays_ms) else 0
            if delay_ms > 0:
                adelay = graph.add(
                    "adelay",
                    args=f"delays={delay_ms}|{delay_ms}:all=1",
                    name=f"delay{idx}",
                )
                in_ctx.link_to(adelay)
                adelay.link_to(mixer, 0, idx)
            else:
                in_ctx.link_to(mixer, 0, idx)

        mixer.link_to(fmt)
        fmt.link_to(sink)
        graph.configure()

        # --- Open all containers and decode ---
        containers = []
        out_container = None
        output_path = os.path.join(temp_dir, "mixed.mp3")

        try:
            for path in track_paths:
                containers.append(av.open(path))

            decoders = [c.decode(audio=0) for c in containers]
            active = [True] * len(decoders)
            resamplers = [
                AudioResampler(format="s32", layout="stereo", rate=target_sample_rate)
                for _ in decoders
            ]

            # Open output MP3
            out_container = av.open(output_path, "w", format="mp3")
            out_stream = out_container.add_stream("libmp3lame", rate=target_sample_rate)
            total_duration = 0

            def encode_available() -> int:
                """Encode every frame the sink can deliver now; return the summed packet duration."""
                encoded = 0
                while True:
                    try:
                        mixed = sink.pull()
                    except Exception:
                        # pull() raises when no frame is available; any error
                        # ends the drain, as in the original loops.
                        return encoded
                    mixed.sample_rate = target_sample_rate
                    mixed.time_base = Fraction(1, target_sample_rate)
                    for packet in out_stream.encode(mixed):
                        out_container.mux(packet)
                        encoded += packet.duration

            while any(active):
                # Check cancellation periodically
                now = time.time()
                if now - last_check >= DISCONNECT_CHECK_INTERVAL:
                    if cancelled.is_set():
                        logger.info("Cancelled during mixing")
                        # The finally block below closes out_container.
                        return {"size": 0, "duration_ms": 0.0, "cancelled": True}
                    last_check = now

                for i, (dec, is_active) in enumerate(zip(decoders, active)):
                    if not is_active:
                        continue
                    try:
                        frame = next(dec)
                    except StopIteration:
                        active[i] = False
                        # Pushing None marks end-of-stream for this input.
                        inputs[i].push(None)
                        continue

                    # Frames at a different rate are dropped, not resampled
                    # (kept as in the original mixdown logic).
                    if frame.sample_rate != target_sample_rate:
                        continue

                    out_frames = resamplers[i].resample(frame) or []
                    for rf in out_frames:
                        rf.sample_rate = target_sample_rate
                        rf.time_base = Fraction(1, target_sample_rate)
                        inputs[i].push(rf)

                total_duration += encode_available()

            # Flush filter graph
            total_duration += encode_available()

            # Flush encoder
            for packet in out_stream.encode(None):
                out_container.mux(packet)
                total_duration += packet.duration

            # Calculate duration in ms
            last_tb = out_stream.time_base
            duration_ms = 0.0
            if last_tb and total_duration > 0:
                duration_ms = round(float(total_duration * last_tb * 1000), 2)

            out_container.close()
            out_container = None

        finally:
            # Reached with an open output only on cancellation or error.
            if out_container is not None:
                try:
                    out_container.close()
                except Exception as e:
                    logger.warning(f"Failed to close output container: {e}")
            for c in containers:
                try:
                    c.close()
                except Exception:
                    # Best-effort close; the files are deleted below anyway.
                    pass

        file_size = os.path.getsize(output_path)
        logger.info(f"Mixdown complete: {file_size} bytes, {duration_ms}ms")

        if cancelled.is_set():
            logger.info("Cancelled after mixing, before upload")
            return {"size": 0, "duration_ms": 0.0, "cancelled": True}

        # --- Upload result ---
        logger.info("Uploading mixed audio to S3")
        with open(output_path, "rb") as f:
            upload_response = requests.put(req.output_url, data=f, timeout=S3_TIMEOUT)
        upload_response.raise_for_status()
        logger.info(f"Upload complete: {file_size} bytes")

        return {"size": file_size, "duration_ms": duration_ms}

    finally:
        # Remove the whole directory so partially downloaded tracks (which
        # never reach track_paths) are deleted too.
        try:
            shutil.rmtree(temp_dir)
        except OSError as e:
            logger.warning(f"Failed to cleanup temp directory: {e}")
|
||||
|
||||
return web_app
|
||||
@@ -52,10 +52,12 @@ OPUS_DEFAULT_BIT_RATE = 128000
|
||||
timeout=PADDING_TIMEOUT,
|
||||
scaledown_window=SCALEDOWN_WINDOW,
|
||||
image=image,
|
||||
secrets=[modal.Secret.from_name("reflector-gpu")],
|
||||
)
|
||||
@modal.asgi_app()
|
||||
def web():
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from pydantic import BaseModel
|
||||
|
||||
class PaddingRequest(BaseModel):
|
||||
@@ -70,7 +72,18 @@ def web():
|
||||
|
||||
web_app = FastAPI()
|
||||
|
||||
@web_app.post("/pad")
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
|
||||
if apikey == os.environ["REFLECTOR_GPU_APIKEY"]:
|
||||
return
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid API key",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
@web_app.post("/pad", dependencies=[Depends(apikey_auth)])
|
||||
async def pad_track_endpoint(request: Request, req: PaddingRequest) -> PaddingResponse:
|
||||
"""Modal web endpoint for padding audio tracks with disconnect detection.
|
||||
"""
|
||||
|
||||
@@ -3,6 +3,7 @@ from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
|
||||
from .routers.diarization import router as diarization_router
|
||||
from .routers.mixdown import router as mixdown_router
|
||||
from .routers.padding import router as padding_router
|
||||
from .routers.transcription import router as transcription_router
|
||||
from .routers.translation import router as translation_router
|
||||
@@ -29,4 +30,5 @@ def create_app() -> FastAPI:
|
||||
app.include_router(translation_router)
|
||||
app.include_router(diarization_router)
|
||||
app.include_router(padding_router)
|
||||
app.include_router(mixdown_router)
|
||||
return app
|
||||
|
||||
288
gpu/self_hosted/app/routers/mixdown.py
Normal file
288
gpu/self_hosted/app/routers/mixdown.py
Normal file
@@ -0,0 +1,288 @@
|
||||
"""
|
||||
Audio mixdown endpoint for selfhosted GPU service.
|
||||
|
||||
CPU-intensive multi-track audio mixing service for combining N audio tracks
|
||||
into a single MP3 using PyAV amix filter graph.
|
||||
|
||||
IMPORTANT: This mixdown logic is duplicated from server/reflector/utils/audio_mixdown.py
|
||||
for deployment isolation (self_hosted can't import from server/reflector/). If you modify
|
||||
the PyAV filter graph or mixdown algorithm, you MUST update both:
|
||||
- gpu/self_hosted/app/routers/mixdown.py (this file)
|
||||
- server/reflector/utils/audio_mixdown.py
|
||||
|
||||
Constants duplicated from server/reflector/utils/audio_constants.py for same reason.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from fractions import Fraction
|
||||
|
||||
import av
|
||||
import requests
|
||||
from av.audio.resampler import AudioResampler
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..auth import apikey_auth
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["mixdown"])
|
||||
|
||||
S3_TIMEOUT = 120
|
||||
|
||||
|
||||
class MixdownRequest(BaseModel):
    """Request body accepted by the ``/mixdown`` route."""

    # Download URLs, one per track. Empty strings are filtered out before
    # downloading.
    track_urls: list[str]
    # URL the mixed MP3 is uploaded to with an HTTP PUT.
    output_url: str
    # Sample rate of the mix. When None it is read from the first frame that
    # can be decoded from any track.
    target_sample_rate: int | None = None
    # Per-track offsets in seconds, index-aligned with ``track_urls``.
    # The route rejects values above 18000.
    offsets_seconds: list[float] | None = None
|
||||
|
||||
|
||||
class MixdownResponse(BaseModel):
    """Result returned by the ``/mixdown`` route."""

    # Size in bytes of the uploaded MP3.
    size: int
    # Sum of the encoded packet durations, converted to milliseconds.
    duration_ms: float = 0.0
    # Left at its default by this router, which never sets it.
    cancelled: bool = False
|
||||
|
||||
|
||||
@router.post("/mixdown", dependencies=[Depends(apikey_auth)], response_model=MixdownResponse)
def mixdown_tracks(req: MixdownRequest):
    """Mix multiple audio tracks into single MP3 using PyAV amix filter graph.

    Downloads every non-empty track URL, mixes the tracks (optionally delayed
    by their offsets relative to the earliest one), encodes the mix to MP3 and
    uploads it to ``req.output_url`` with an HTTP PUT.

    Args:
        req: Track URLs, output URL and optional sample rate / offsets.

    Returns:
        MixdownResponse with the output size in bytes and duration in ms.

    Raises:
        HTTPException: 400 for invalid input (no usable URL, mismatched or
            too-large offsets, empty output URL, undetectable sample rate);
            500 wrapping any other failure during download, mixing or upload.
    """
    import shutil

    valid_urls = [u for u in req.track_urls if u]
    if not valid_urls:
        raise HTTPException(status_code=400, detail="No valid track URLs provided")

    offsets = req.offsets_seconds
    if offsets is not None:
        # Offsets are index-aligned with the unfiltered track_urls list.
        if len(offsets) != len(req.track_urls):
            raise HTTPException(
                status_code=400,
                detail=f"offsets_seconds length ({len(offsets)}) "
                f"must match track_urls ({len(req.track_urls)})",
            )
        if any(o > 18000 for o in offsets):
            raise HTTPException(
                status_code=400, detail="offsets_seconds exceeds maximum 18000s (5 hours)"
            )

    if not req.output_url:
        raise HTTPException(status_code=400, detail="output_url cannot be empty")

    logger.info("Mixdown request: %d tracks", len(valid_urls))

    temp_dir = tempfile.mkdtemp()
    track_paths = []

    try:
        # --- Download all tracks ---
        for i, url in enumerate(valid_urls):
            logger.info("Downloading track %d", i)
            track_path = os.path.join(temp_dir, f"track_{i}.webm")
            total_bytes = 0
            # The context manager releases the streamed connection even if
            # writing the file fails part-way.
            with requests.get(url, stream=True, timeout=S3_TIMEOUT) as response:
                response.raise_for_status()
                with open(track_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
                            total_bytes += len(chunk)

            track_paths.append(track_path)
            logger.info("Track %d downloaded: %d bytes", i, total_bytes)

        if not track_paths:
            raise HTTPException(status_code=400, detail="No tracks could be downloaded")

        # --- Detect sample rate ---
        target_sample_rate = req.target_sample_rate
        if target_sample_rate is None:
            for path in track_paths:
                try:
                    # "with" closes the probe container even if decoding fails.
                    with av.open(path) as container:
                        first_frame = next(container.decode(audio=0), None)
                except Exception:
                    # Unreadable track: try the next one.
                    continue
                if first_frame is not None:
                    target_sample_rate = first_frame.sample_rate
                    break
        if target_sample_rate is None:
            raise HTTPException(
                status_code=400, detail="Could not detect sample rate from any track"
            )

        logger.info("Target sample rate: %d", target_sample_rate)

        # --- Calculate per-input delays ---
        # Drop the offsets that belong to empty URLs so indices line up with
        # the downloaded tracks.
        input_offsets_seconds = None
        if req.offsets_seconds is not None:
            input_offsets_seconds = [
                req.offsets_seconds[i] for i, url in enumerate(req.track_urls) if url
            ]

        if input_offsets_seconds is not None:
            # Delays are relative to the earliest track.
            base = min(input_offsets_seconds) if input_offsets_seconds else 0.0
            delays_ms = [max(0, int(round((o - base) * 1000))) for o in input_offsets_seconds]
        else:
            delays_ms = [0 for _ in track_paths]

        # --- Build filter graph ---
        # N abuffer -> optional adelay -> amix -> aformat -> abuffersink
        graph = av.filter.Graph()
        abuffer_args = (
            f"time_base=1/{target_sample_rate}:"
            f"sample_rate={target_sample_rate}:"
            f"sample_fmt=s32:"
            f"channel_layout=stereo"
        )
        inputs = [
            graph.add("abuffer", args=abuffer_args, name=f"in{idx}")
            for idx in range(len(track_paths))
        ]

        mixer = graph.add("amix", args=f"inputs={len(inputs)}:normalize=0", name="mix")
        fmt = graph.add(
            "aformat",
            args=f"sample_fmts=s32:channel_layouts=stereo:sample_rates={target_sample_rate}",
            name="fmt",
        )
        sink = graph.add("abuffersink", name="out")

        for idx, in_ctx in enumerate(inputs):
            delay_ms = delays_ms[idx] if idx < len(delays_ms) else 0
            if delay_ms > 0:
                adelay = graph.add(
                    "adelay",
                    args=f"delays={delay_ms}|{delay_ms}:all=1",
                    name=f"delay{idx}",
                )
                in_ctx.link_to(adelay)
                adelay.link_to(mixer, 0, idx)
            else:
                in_ctx.link_to(mixer, 0, idx)

        mixer.link_to(fmt)
        fmt.link_to(sink)
        graph.configure()

        # --- Open all containers and decode ---
        containers = []
        out_container = None
        output_path = os.path.join(temp_dir, "mixed.mp3")

        try:
            for path in track_paths:
                containers.append(av.open(path))

            decoders = [c.decode(audio=0) for c in containers]
            active = [True] * len(decoders)
            resamplers = [
                AudioResampler(format="s32", layout="stereo", rate=target_sample_rate)
                for _ in decoders
            ]

            # Open output MP3
            out_container = av.open(output_path, "w", format="mp3")
            out_stream = out_container.add_stream("libmp3lame", rate=target_sample_rate)
            total_duration = 0

            def encode_available() -> int:
                """Encode every frame the sink can deliver now; return the summed packet duration."""
                encoded = 0
                while True:
                    try:
                        mixed = sink.pull()
                    except Exception:
                        # pull() raises when no frame is available; any error
                        # ends the drain, as in the original loops.
                        return encoded
                    mixed.sample_rate = target_sample_rate
                    mixed.time_base = Fraction(1, target_sample_rate)
                    for packet in out_stream.encode(mixed):
                        out_container.mux(packet)
                        encoded += packet.duration

            while any(active):
                for i, (dec, is_active) in enumerate(zip(decoders, active)):
                    if not is_active:
                        continue
                    try:
                        frame = next(dec)
                    except StopIteration:
                        active[i] = False
                        # Pushing None marks end-of-stream for this input.
                        inputs[i].push(None)
                        continue

                    # Frames at a different rate are dropped, not resampled
                    # (kept as in the original mixdown logic).
                    if frame.sample_rate != target_sample_rate:
                        continue

                    out_frames = resamplers[i].resample(frame) or []
                    for rf in out_frames:
                        rf.sample_rate = target_sample_rate
                        rf.time_base = Fraction(1, target_sample_rate)
                        inputs[i].push(rf)

                total_duration += encode_available()

            # Flush filter graph
            total_duration += encode_available()

            # Flush encoder
            for packet in out_stream.encode(None):
                out_container.mux(packet)
                total_duration += packet.duration

            # Calculate duration in ms
            last_tb = out_stream.time_base
            duration_ms = 0.0
            if last_tb and total_duration > 0:
                duration_ms = round(float(total_duration * last_tb * 1000), 2)

            out_container.close()
            out_container = None

        finally:
            # Reached with an open output only when mixing failed.
            if out_container is not None:
                try:
                    out_container.close()
                except Exception as e:
                    logger.warning("Failed to close output container: %s", e)
            for c in containers:
                try:
                    c.close()
                except Exception:
                    # Best-effort close; the files are deleted below anyway.
                    pass

        file_size = os.path.getsize(output_path)
        logger.info("Mixdown complete: %d bytes, %.2fms", file_size, duration_ms)

        # --- Upload result ---
        logger.info("Uploading mixed audio to S3")
        with open(output_path, "rb") as f:
            upload_response = requests.put(req.output_url, data=f, timeout=S3_TIMEOUT)
        upload_response.raise_for_status()
        logger.info("Upload complete: %d bytes", file_size)

        return MixdownResponse(size=file_size, duration_ms=duration_ms)

    except HTTPException:
        # Already carries the intended status code; do not wrap it as a 500.
        raise
    except Exception as e:
        logger.error("Mixdown failed: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=f"Mixdown failed: {e}") from e
    finally:
        # Remove the whole directory so a partially downloaded track (which
        # never reaches track_paths) is deleted too.
        try:
            shutil.rmtree(temp_dir)
        except OSError as e:
            logger.warning("Failed to cleanup temp directory: %s", e)
|
||||
Reference in New Issue
Block a user