mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-04-08 14:56:49 +00:00
feat: mixdown modal services + processor pattern (#936)
* allow memory flags and per service config * feat: mixdown modal services + processor pattern
This commit is contained in:
committed by
GitHub
parent
12bf0c2d77
commit
d164e486cc
@@ -3,6 +3,7 @@ from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
|
||||
from .routers.diarization import router as diarization_router
|
||||
from .routers.mixdown import router as mixdown_router
|
||||
from .routers.padding import router as padding_router
|
||||
from .routers.transcription import router as transcription_router
|
||||
from .routers.translation import router as translation_router
|
||||
@@ -29,4 +30,5 @@ def create_app() -> FastAPI:
|
||||
app.include_router(translation_router)
|
||||
app.include_router(diarization_router)
|
||||
app.include_router(padding_router)
|
||||
app.include_router(mixdown_router)
|
||||
return app
|
||||
|
||||
288
gpu/self_hosted/app/routers/mixdown.py
Normal file
288
gpu/self_hosted/app/routers/mixdown.py
Normal file
@@ -0,0 +1,288 @@
|
||||
"""
|
||||
Audio mixdown endpoint for selfhosted GPU service.
|
||||
|
||||
CPU-intensive multi-track audio mixing service for combining N audio tracks
|
||||
into a single MP3 using PyAV amix filter graph.
|
||||
|
||||
IMPORTANT: This mixdown logic is duplicated from server/reflector/utils/audio_mixdown.py
|
||||
for deployment isolation (self_hosted can't import from server/reflector/). If you modify
|
||||
the PyAV filter graph or mixdown algorithm, you MUST update both:
|
||||
- gpu/self_hosted/app/routers/mixdown.py (this file)
|
||||
- server/reflector/utils/audio_mixdown.py
|
||||
|
||||
Constants duplicated from server/reflector/utils/audio_constants.py for same reason.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from fractions import Fraction
|
||||
|
||||
import av
|
||||
import requests
|
||||
from av.audio.resampler import AudioResampler
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..auth import apikey_auth
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["mixdown"])
|
||||
|
||||
S3_TIMEOUT = 120
|
||||
|
||||
|
||||
class MixdownRequest(BaseModel):
|
||||
track_urls: list[str]
|
||||
output_url: str
|
||||
target_sample_rate: int | None = None
|
||||
offsets_seconds: list[float] | None = None
|
||||
|
||||
|
||||
class MixdownResponse(BaseModel):
|
||||
size: int
|
||||
duration_ms: float = 0.0
|
||||
cancelled: bool = False
|
||||
|
||||
|
||||
@router.post("/mixdown", dependencies=[Depends(apikey_auth)], response_model=MixdownResponse)
|
||||
def mixdown_tracks(req: MixdownRequest):
|
||||
"""Mix multiple audio tracks into single MP3 using PyAV amix filter graph."""
|
||||
valid_urls = [u for u in req.track_urls if u]
|
||||
if not valid_urls:
|
||||
raise HTTPException(status_code=400, detail="No valid track URLs provided")
|
||||
if req.offsets_seconds is not None:
|
||||
if len(req.offsets_seconds) != len(req.track_urls):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"offsets_seconds length ({len(req.offsets_seconds)}) "
|
||||
f"must match track_urls ({len(req.track_urls)})",
|
||||
)
|
||||
if any(o > 18000 for o in req.offsets_seconds):
|
||||
raise HTTPException(
|
||||
status_code=400, detail="offsets_seconds exceeds maximum 18000s (5 hours)"
|
||||
)
|
||||
if not req.output_url:
|
||||
raise HTTPException(status_code=400, detail="output_url cannot be empty")
|
||||
|
||||
logger.info("Mixdown request: %d tracks", len(valid_urls))
|
||||
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
track_paths = []
|
||||
output_path = None
|
||||
|
||||
try:
|
||||
# --- Download all tracks ---
|
||||
for i, url in enumerate(valid_urls):
|
||||
logger.info("Downloading track %d", i)
|
||||
response = requests.get(url, stream=True, timeout=S3_TIMEOUT)
|
||||
response.raise_for_status()
|
||||
|
||||
track_path = os.path.join(temp_dir, f"track_{i}.webm")
|
||||
total_bytes = 0
|
||||
with open(track_path, "wb") as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
total_bytes += len(chunk)
|
||||
|
||||
track_paths.append(track_path)
|
||||
logger.info("Track %d downloaded: %d bytes", i, total_bytes)
|
||||
|
||||
if not track_paths:
|
||||
raise HTTPException(status_code=400, detail="No tracks could be downloaded")
|
||||
|
||||
# --- Detect sample rate ---
|
||||
target_sample_rate = req.target_sample_rate
|
||||
if target_sample_rate is None:
|
||||
for path in track_paths:
|
||||
try:
|
||||
container = av.open(path)
|
||||
for frame in container.decode(audio=0):
|
||||
target_sample_rate = frame.sample_rate
|
||||
container.close()
|
||||
break
|
||||
else:
|
||||
container.close()
|
||||
continue
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
if target_sample_rate is None:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Could not detect sample rate from any track"
|
||||
)
|
||||
|
||||
logger.info("Target sample rate: %d", target_sample_rate)
|
||||
|
||||
# --- Calculate per-input delays ---
|
||||
input_offsets_seconds = None
|
||||
if req.offsets_seconds is not None:
|
||||
input_offsets_seconds = [
|
||||
req.offsets_seconds[i] for i, url in enumerate(req.track_urls) if url
|
||||
]
|
||||
|
||||
delays_ms = []
|
||||
if input_offsets_seconds is not None:
|
||||
base = min(input_offsets_seconds) if input_offsets_seconds else 0.0
|
||||
delays_ms = [max(0, int(round((o - base) * 1000))) for o in input_offsets_seconds]
|
||||
else:
|
||||
delays_ms = [0 for _ in track_paths]
|
||||
|
||||
# --- Build filter graph ---
|
||||
# N abuffer -> optional adelay -> amix -> aformat -> abuffersink
|
||||
graph = av.filter.Graph()
|
||||
inputs = []
|
||||
|
||||
for idx in range(len(track_paths)):
|
||||
args = (
|
||||
f"time_base=1/{target_sample_rate}:"
|
||||
f"sample_rate={target_sample_rate}:"
|
||||
f"sample_fmt=s32:"
|
||||
f"channel_layout=stereo"
|
||||
)
|
||||
in_ctx = graph.add("abuffer", args=args, name=f"in{idx}")
|
||||
inputs.append(in_ctx)
|
||||
|
||||
mixer = graph.add("amix", args=f"inputs={len(inputs)}:normalize=0", name="mix")
|
||||
fmt = graph.add(
|
||||
"aformat",
|
||||
args=f"sample_fmts=s32:channel_layouts=stereo:sample_rates={target_sample_rate}",
|
||||
name="fmt",
|
||||
)
|
||||
sink = graph.add("abuffersink", name="out")
|
||||
|
||||
for idx, in_ctx in enumerate(inputs):
|
||||
delay_ms = delays_ms[idx] if idx < len(delays_ms) else 0
|
||||
if delay_ms > 0:
|
||||
adelay = graph.add(
|
||||
"adelay",
|
||||
args=f"delays={delay_ms}|{delay_ms}:all=1",
|
||||
name=f"delay{idx}",
|
||||
)
|
||||
in_ctx.link_to(adelay)
|
||||
adelay.link_to(mixer, 0, idx)
|
||||
else:
|
||||
in_ctx.link_to(mixer, 0, idx)
|
||||
|
||||
mixer.link_to(fmt)
|
||||
fmt.link_to(sink)
|
||||
graph.configure()
|
||||
|
||||
# --- Open all containers and decode ---
|
||||
containers = []
|
||||
output_path = os.path.join(temp_dir, "mixed.mp3")
|
||||
|
||||
try:
|
||||
for path in track_paths:
|
||||
containers.append(av.open(path))
|
||||
|
||||
decoders = [c.decode(audio=0) for c in containers]
|
||||
active = [True] * len(decoders)
|
||||
resamplers = [
|
||||
AudioResampler(format="s32", layout="stereo", rate=target_sample_rate)
|
||||
for _ in decoders
|
||||
]
|
||||
|
||||
# Open output MP3
|
||||
out_container = av.open(output_path, "w", format="mp3")
|
||||
out_stream = out_container.add_stream("libmp3lame", rate=target_sample_rate)
|
||||
total_duration = 0
|
||||
|
||||
while any(active):
|
||||
for i, (dec, is_active) in enumerate(zip(decoders, active)):
|
||||
if not is_active:
|
||||
continue
|
||||
try:
|
||||
frame = next(dec)
|
||||
except StopIteration:
|
||||
active[i] = False
|
||||
inputs[i].push(None)
|
||||
continue
|
||||
|
||||
if frame.sample_rate != target_sample_rate:
|
||||
continue
|
||||
|
||||
out_frames = resamplers[i].resample(frame) or []
|
||||
for rf in out_frames:
|
||||
rf.sample_rate = target_sample_rate
|
||||
rf.time_base = Fraction(1, target_sample_rate)
|
||||
inputs[i].push(rf)
|
||||
|
||||
while True:
|
||||
try:
|
||||
mixed = sink.pull()
|
||||
except Exception:
|
||||
break
|
||||
mixed.sample_rate = target_sample_rate
|
||||
mixed.time_base = Fraction(1, target_sample_rate)
|
||||
for packet in out_stream.encode(mixed):
|
||||
out_container.mux(packet)
|
||||
total_duration += packet.duration
|
||||
|
||||
# Flush filter graph
|
||||
while True:
|
||||
try:
|
||||
mixed = sink.pull()
|
||||
except Exception:
|
||||
break
|
||||
mixed.sample_rate = target_sample_rate
|
||||
mixed.time_base = Fraction(1, target_sample_rate)
|
||||
for packet in out_stream.encode(mixed):
|
||||
out_container.mux(packet)
|
||||
total_duration += packet.duration
|
||||
|
||||
# Flush encoder
|
||||
for packet in out_stream.encode(None):
|
||||
out_container.mux(packet)
|
||||
total_duration += packet.duration
|
||||
|
||||
# Calculate duration in ms
|
||||
last_tb = out_stream.time_base
|
||||
duration_ms = 0.0
|
||||
if last_tb and total_duration > 0:
|
||||
duration_ms = round(float(total_duration * last_tb * 1000), 2)
|
||||
|
||||
out_container.close()
|
||||
|
||||
finally:
|
||||
for c in containers:
|
||||
try:
|
||||
c.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
file_size = os.path.getsize(output_path)
|
||||
logger.info("Mixdown complete: %d bytes, %.2fms", file_size, duration_ms)
|
||||
|
||||
# --- Upload result ---
|
||||
logger.info("Uploading mixed audio to S3")
|
||||
with open(output_path, "rb") as f:
|
||||
upload_response = requests.put(req.output_url, data=f, timeout=S3_TIMEOUT)
|
||||
upload_response.raise_for_status()
|
||||
logger.info("Upload complete: %d bytes", file_size)
|
||||
|
||||
return MixdownResponse(size=file_size, duration_ms=duration_ms)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Mixdown failed: %s", e, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Mixdown failed: {e}") from e
|
||||
finally:
|
||||
for path in track_paths:
|
||||
if os.path.exists(path):
|
||||
try:
|
||||
os.unlink(path)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to cleanup track file: %s", e)
|
||||
if output_path and os.path.exists(output_path):
|
||||
try:
|
||||
os.unlink(output_path)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to cleanup output file: %s", e)
|
||||
try:
|
||||
os.rmdir(temp_dir)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to cleanup temp directory: %s", e)
|
||||
Reference in New Issue
Block a user