mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-02-04 09:56:47 +00:00
* Add Modal backend for audio padding - Create reflector_padding.py Modal deployment (CPU-based) - Add PaddingWorkflow with conditional Modal/local backend - Update deploy-all.sh to include padding deployment --------- Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
278 lines
11 KiB
Python
278 lines
11 KiB
Python
"""
|
|
Reflector GPU backend - audio padding
|
|
======================================
|
|
|
|
CPU-intensive audio padding service for adding silence to audio tracks.
|
|
Uses PyAV filter graph (adelay) for precise track synchronization.
|
|
|
|
IMPORTANT: This padding logic is duplicated from server/reflector/utils/audio_padding.py
|
|
for Modal deployment isolation (Modal can't import from server/reflector/). If you modify
|
|
the PyAV filter graph or padding algorithm, you MUST update both:
|
|
- gpu/modal_deployments/reflector_padding.py (this file)
|
|
- server/reflector/utils/audio_padding.py
|
|
|
|
Constants duplicated from server/reflector/utils/audio_constants.py for same reason.
|
|
"""
|
|
|
|
import os
|
|
import tempfile
|
|
from fractions import Fraction
|
|
import math
|
|
import asyncio
|
|
|
|
import modal
|
|
|
|
# Timeout (seconds) for a single S3 HTTP transfer. It is applied twice per
# request: once for the download and once for the upload.
S3_TIMEOUT = 60

# Overall Modal function timeout (seconds): a 600 s base allowance plus the
# two S3 transfers.
PADDING_TIMEOUT = 600 + 2 * S3_TIMEOUT

# The maximum duration (in seconds) that individual containers can remain
# idle when scaling down.
SCALEDOWN_WINDOW = 60

# Minimum interval (seconds) between checks for a client disconnect.
DISCONNECT_CHECK_INTERVAL = 2
|
|
|
|
|
|
# Modal application under which the padding function is registered.
app = modal.App("reflector-padding")

# CPU-based container image: Debian slim with Python 3.12, the ffmpeg system
# package (required by PyAV) and the pinned Python dependencies.
image = (
    modal.Image.debian_slim(python_version="3.12")
    .apt_install("ffmpeg")
    .pip_install(
        # PyAV for audio processing
        "av==13.1.0",
        # HTTP for presigned URL downloads/uploads
        "requests==2.32.3",
        # API framework
        "fastapi==0.115.12",
    )
)
|
|
|
|
# Sample rate (Hz) used for the filter graph and the Opus output stream.
# ref B0F71CE8-FC59-4AA5-8414-DAFB836DB711
OPUS_STANDARD_SAMPLE_RATE = 48_000

# Bit rate (bits per second) assigned to the Opus output stream.
# ref B0F71CE8-FC59-4AA5-8414-DAFB836DB711
OPUS_DEFAULT_BIT_RATE = 128_000
|
|
|
|
|
|
@app.function(
    cpu=2.0,
    timeout=PADDING_TIMEOUT,
    scaledown_window=SCALEDOWN_WINDOW,
    image=image,
)
@modal.asgi_app()
def web():
    """Build the FastAPI application that Modal serves for audio padding.

    The app exposes ``POST /pad``. A request downloads a track from
    ``track_url``, delays it by ``start_time_seconds`` using a PyAV ``adelay``
    filter graph, re-encodes it as Opus in a WebM container and uploads the
    result to ``output_url`` with an HTTP PUT.

    Third-party imports are kept function-scoped, as in the original layout
    (presumably because those packages are only installed in the Modal image).

    Returns:
        The configured ``FastAPI`` application.
    """
    import logging
    import threading
    import time

    from fastapi import FastAPI, HTTPException, Request
    from pydantic import BaseModel

    # Configure logging once when the app is built instead of on every request.
    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
    )
    logger = logging.getLogger(__name__)

    class PaddingRequest(BaseModel):
        # URL the source track is downloaded from (HTTP GET).
        track_url: str
        # URL the padded track is uploaded to (HTTP PUT).
        output_url: str
        # Delay to prepend to the track, in seconds.
        start_time_seconds: float
        # Track identifier; only used in log messages here.
        track_index: int

    class PaddingResponse(BaseModel):
        # Size of the padded file in bytes (0 when cancelled).
        size: int
        # True when processing stopped early because the client disconnected.
        cancelled: bool = False

    web_app = FastAPI()

    def _pad_track_blocking(req, cancelled, logger) -> dict:
        """Blocking CPU-bound padding work with periodic cancellation checks.

        Args:
            req: PaddingRequest describing the track, output URL and delay.
            cancelled: threading.Event for thread-safe cancellation signaling.
            logger: logger used for progress and cleanup messages.

        Returns:
            ``{"size": <bytes>}`` on success, or
            ``{"size": 0, "cancelled": True}`` when cancellation was observed.

        Raises:
            ValueError: if the downloaded file has no audio stream.
            requests.HTTPError: if the download or upload returns an error status.
        """
        import av
        import requests
        from av.audio.resampler import AudioResampler

        frame_time_base = Fraction(1, OPUS_STANDARD_SAMPLE_RATE)

        def encode_available_frames(sink, out_stream, out_container):
            """Pull every frame the sink can currently provide, encode and mux it."""
            while True:
                try:
                    filtered = sink.pull()
                except (av.error.BlockingIOError, av.error.EOFError):
                    # Only "no frame ready yet" and "graph fully flushed" end
                    # the drain; any other error propagates so a broken
                    # pipeline is not uploaded as a successful result.
                    return
                filtered.sample_rate = OPUS_STANDARD_SAMPLE_RATE
                filtered.time_base = frame_time_base
                for packet in out_stream.encode(filtered):
                    out_container.mux(packet)

        temp_dir = tempfile.mkdtemp()
        input_path = os.path.join(temp_dir, "track.webm")
        output_path = os.path.join(temp_dir, "padded.webm")
        # Monotonic clock: only elapsed intervals matter for the check cadence.
        last_check = time.monotonic()

        try:
            logger.info("Downloading track for padding")
            total_bytes = 0
            chunk_count = 0
            # Context manager releases the streamed connection on every exit
            # path, including the early return on cancellation.
            with requests.get(
                req.track_url, stream=True, timeout=S3_TIMEOUT
            ) as response:
                response.raise_for_status()
                with open(input_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if not chunk:
                            continue
                        f.write(chunk)
                        total_bytes += len(chunk)
                        chunk_count += 1

                        # Check for cancellation every arbitrary amount of chunks
                        if chunk_count % 12 != 0:
                            continue
                        now = time.monotonic()
                        if now - last_check >= DISCONNECT_CHECK_INTERVAL:
                            if cancelled.is_set():
                                logger.info("Cancelled during download, exiting early")
                                return {"size": 0, "cancelled": True}
                            last_check = now
            logger.info(f"Track downloaded: {total_bytes} bytes")

            if cancelled.is_set():
                logger.info("Cancelled after download, exiting early")
                return {"size": 0, "cancelled": True}

            # Apply padding using PyAV
            delay_ms = math.floor(req.start_time_seconds * 1000)
            logger.info(
                f"Padding track {req.track_index} with {delay_ms}ms delay using PyAV"
            )

            # Both containers are context-managed so they are closed even when
            # decoding or encoding raises, or when cancellation returns early.
            with av.open(input_path) as in_container:
                in_stream = next(
                    (s for s in in_container.streams if s.type == "audio"), None
                )
                if in_stream is None:
                    raise ValueError("No audio stream in input")

                with av.open(output_path, "w", format="webm") as out_container:
                    out_stream = out_container.add_stream(
                        "libopus", rate=OPUS_STANDARD_SAMPLE_RATE
                    )
                    out_stream.bit_rate = OPUS_DEFAULT_BIT_RATE

                    # Filter graph: abuffer -> aresample(async=1) -> adelay -> abuffersink
                    graph = av.filter.Graph()
                    abuf_args = (
                        f"time_base=1/{OPUS_STANDARD_SAMPLE_RATE}:"
                        f"sample_rate={OPUS_STANDARD_SAMPLE_RATE}:"
                        f"sample_fmt=s16:"
                        f"channel_layout=stereo"
                    )
                    src = graph.add("abuffer", args=abuf_args, name="src")
                    aresample_f = graph.add("aresample", args="async=1", name="ares")
                    # Same delay for both stereo channels.
                    delays_arg = f"{delay_ms}|{delay_ms}"
                    adelay_f = graph.add(
                        "adelay", args=f"delays={delays_arg}:all=1", name="delay"
                    )
                    sink = graph.add("abuffersink", name="sink")

                    src.link_to(aresample_f)
                    aresample_f.link_to(adelay_f)
                    adelay_f.link_to(sink)
                    graph.configure()

                    # Normalize decoded frames to the format declared for abuffer.
                    resampler = AudioResampler(
                        format="s16", layout="stereo", rate=OPUS_STANDARD_SAMPLE_RATE
                    )

                    for frame in in_container.decode(in_stream):
                        # Check for cancellation periodically
                        now = time.monotonic()
                        if now - last_check >= DISCONNECT_CHECK_INTERVAL:
                            if cancelled.is_set():
                                logger.info("Cancelled during processing, exiting early")
                                return {"size": 0, "cancelled": True}
                            last_check = now

                        for rframe in resampler.resample(frame) or []:
                            rframe.sample_rate = OPUS_STANDARD_SAMPLE_RATE
                            rframe.time_base = frame_time_base
                            src.push(rframe)

                        encode_available_frames(sink, out_stream, out_container)

                    # Flush filter graph
                    src.push(None)
                    encode_available_frames(sink, out_stream, out_container)

                    # Flush encoder
                    for packet in out_stream.encode(None):
                        out_container.mux(packet)

            file_size = os.path.getsize(output_path)
            logger.info(f"Padding complete: {file_size} bytes")

            logger.info("Uploading padded track to S3")
            with open(output_path, "rb") as f:
                upload_response = requests.put(
                    req.output_url, data=f, timeout=S3_TIMEOUT
                )
            upload_response.raise_for_status()
            logger.info(f"Upload complete: {file_size} bytes")

            return {"size": file_size}

        finally:
            # Best-effort cleanup: failures are logged, never raised.
            for label, path in (("input", input_path), ("output", output_path)):
                if os.path.exists(path):
                    try:
                        os.unlink(path)
                    except OSError as e:
                        logger.warning(f"Failed to cleanup {label} file: {e}")
            try:
                os.rmdir(temp_dir)
            except OSError as e:
                logger.warning(f"Failed to cleanup temp directory: {e}")

    @web_app.post("/pad")
    async def pad_track_endpoint(
        request: Request, req: PaddingRequest
    ) -> PaddingResponse:
        """Modal web endpoint for padding audio tracks with disconnect detection.

        Validates the request, runs the blocking padding work in the default
        executor and, in parallel, polls for a client disconnect so the worker
        can stop early.

        Raises:
            HTTPException: 400 when a URL is empty or ``start_time_seconds``
                is not a positive number no greater than 18000.
        """
        if not req.track_url:
            raise HTTPException(status_code=400, detail="track_url cannot be empty")
        if not req.output_url:
            raise HTTPException(status_code=400, detail="output_url cannot be empty")
        # NaN compares False against everything, so reject it explicitly;
        # otherwise it slips past both range checks and fails in math.floor.
        if not math.isfinite(req.start_time_seconds) or req.start_time_seconds <= 0:
            raise HTTPException(
                status_code=400,
                detail=f"start_time_seconds must be positive, got {req.start_time_seconds}",
            )
        if req.start_time_seconds > 18000:
            raise HTTPException(
                status_code=400,
                detail="start_time_seconds exceeds maximum 18000s (5 hours)",
            )

        logger.info(
            f"Padding request: track {req.track_index}, delay={req.start_time_seconds}s"
        )

        # Thread-safe cancellation flag shared between async disconnect checker and blocking thread
        cancelled = threading.Event()

        async def check_disconnect():
            """Background task polling for a client disconnect."""
            while not cancelled.is_set():
                await asyncio.sleep(DISCONNECT_CHECK_INTERVAL)
                if await request.is_disconnected():
                    logger.warning("Client disconnected, setting cancellation flag")
                    cancelled.set()
                    break

        # Start disconnect checker in background
        disconnect_task = asyncio.create_task(check_disconnect())

        try:
            # get_running_loop() is the supported way to get the loop in a coroutine.
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                None, _pad_track_blocking, req, cancelled, logger
            )
            return PaddingResponse(**result)
        finally:
            # Stop the checker whether the work finished, failed or was cancelled.
            cancelled.set()
            disconnect_task.cancel()
            try:
                await disconnect_task
            except asyncio.CancelledError:
                pass

    return web_app
|
|
|