mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-02-04 09:56:47 +00:00
feat: modal padding (#837)
* Add Modal backend for audio padding - Create reflector_padding.py Modal deployment (CPU-based) - Add PaddingWorkflow with conditional Modal/local backend - Update deploy-all.sh to include padding deployment --------- Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
This commit is contained in:
277
gpu/modal_deployments/reflector_padding.py
Normal file
277
gpu/modal_deployments/reflector_padding.py
Normal file
@@ -0,0 +1,277 @@
|
||||
"""
|
||||
Reflector GPU backend - audio padding
|
||||
======================================
|
||||
|
||||
CPU-intensive audio padding service for adding silence to audio tracks.
|
||||
Uses PyAV filter graph (adelay) for precise track synchronization.
|
||||
|
||||
IMPORTANT: This padding logic is duplicated from server/reflector/utils/audio_padding.py
|
||||
for Modal deployment isolation (Modal can't import from server/reflector/). If you modify
|
||||
the PyAV filter graph or padding algorithm, you MUST update both:
|
||||
- gpu/modal_deployments/reflector_padding.py (this file)
|
||||
- server/reflector/utils/audio_padding.py
|
||||
|
||||
Constants duplicated from server/reflector/utils/audio_constants.py for same reason.
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from fractions import Fraction
|
||||
import math
|
||||
import asyncio
|
||||
|
||||
import modal
|
||||
|
||||
S3_TIMEOUT = 60 # happens 2 times
|
||||
PADDING_TIMEOUT = 600 + (S3_TIMEOUT * 2)
|
||||
SCALEDOWN_WINDOW = 60 # The maximum duration (in seconds) that individual containers can remain idle when scaling down.
|
||||
DISCONNECT_CHECK_INTERVAL = 2 # Check for client disconnect
|
||||
|
||||
|
||||
app = modal.App("reflector-padding")
|
||||
|
||||
# CPU-based image
|
||||
image = (
|
||||
modal.Image.debian_slim(python_version="3.12")
|
||||
.apt_install("ffmpeg") # Required by PyAV
|
||||
.pip_install(
|
||||
"av==13.1.0", # PyAV for audio processing
|
||||
"requests==2.32.3", # HTTP for presigned URL downloads/uploads
|
||||
"fastapi==0.115.12", # API framework
|
||||
)
|
||||
)
|
||||
|
||||
# ref B0F71CE8-FC59-4AA5-8414-DAFB836DB711
|
||||
OPUS_STANDARD_SAMPLE_RATE = 48000
|
||||
# ref B0F71CE8-FC59-4AA5-8414-DAFB836DB711
|
||||
OPUS_DEFAULT_BIT_RATE = 128000
|
||||
|
||||
|
||||
@app.function(
|
||||
cpu=2.0,
|
||||
timeout=PADDING_TIMEOUT,
|
||||
scaledown_window=SCALEDOWN_WINDOW,
|
||||
image=image,
|
||||
)
|
||||
@modal.asgi_app()
|
||||
def web():
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
class PaddingRequest(BaseModel):
|
||||
track_url: str
|
||||
output_url: str
|
||||
start_time_seconds: float
|
||||
track_index: int
|
||||
|
||||
class PaddingResponse(BaseModel):
|
||||
size: int
|
||||
cancelled: bool = False
|
||||
|
||||
web_app = FastAPI()
|
||||
|
||||
@web_app.post("/pad")
|
||||
async def pad_track_endpoint(request: Request, req: PaddingRequest) -> PaddingResponse:
|
||||
"""Modal web endpoint for padding audio tracks with disconnect detection.
|
||||
"""
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if not req.track_url:
|
||||
raise HTTPException(status_code=400, detail="track_url cannot be empty")
|
||||
if not req.output_url:
|
||||
raise HTTPException(status_code=400, detail="output_url cannot be empty")
|
||||
if req.start_time_seconds <= 0:
|
||||
raise HTTPException(status_code=400, detail=f"start_time_seconds must be positive, got {req.start_time_seconds}")
|
||||
if req.start_time_seconds > 18000:
|
||||
raise HTTPException(status_code=400, detail=f"start_time_seconds exceeds maximum 18000s (5 hours)")
|
||||
|
||||
logger.info(f"Padding request: track {req.track_index}, delay={req.start_time_seconds}s")
|
||||
|
||||
# Thread-safe cancellation flag shared between async disconnect checker and blocking thread
|
||||
import threading
|
||||
cancelled = threading.Event()
|
||||
|
||||
async def check_disconnect():
|
||||
"""Background task to check for client disconnect every 2 seconds."""
|
||||
while not cancelled.is_set():
|
||||
await asyncio.sleep(DISCONNECT_CHECK_INTERVAL)
|
||||
if await request.is_disconnected():
|
||||
logger.warning("Client disconnected, setting cancellation flag")
|
||||
cancelled.set()
|
||||
break
|
||||
|
||||
# Start disconnect checker in background
|
||||
disconnect_task = asyncio.create_task(check_disconnect())
|
||||
|
||||
try:
|
||||
result = await asyncio.get_event_loop().run_in_executor(
|
||||
None, _pad_track_blocking, req, cancelled, logger
|
||||
)
|
||||
return PaddingResponse(**result)
|
||||
finally:
|
||||
cancelled.set()
|
||||
disconnect_task.cancel()
|
||||
try:
|
||||
await disconnect_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
def _pad_track_blocking(req, cancelled, logger) -> dict:
|
||||
"""Blocking CPU-bound padding work with periodic cancellation checks.
|
||||
|
||||
Args:
|
||||
cancelled: threading.Event for thread-safe cancellation signaling
|
||||
"""
|
||||
import av
|
||||
import requests
|
||||
from av.audio.resampler import AudioResampler
|
||||
import time
|
||||
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
input_path = None
|
||||
output_path = None
|
||||
last_check = time.time()
|
||||
|
||||
try:
|
||||
logger.info("Downloading track for padding")
|
||||
response = requests.get(req.track_url, stream=True, timeout=S3_TIMEOUT)
|
||||
response.raise_for_status()
|
||||
|
||||
input_path = os.path.join(temp_dir, "track.webm")
|
||||
total_bytes = 0
|
||||
chunk_count = 0
|
||||
with open(input_path, "wb") as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
total_bytes += len(chunk)
|
||||
chunk_count += 1
|
||||
|
||||
# Check for cancellation every arbitrary amount of chunks
|
||||
if chunk_count % 12 == 0:
|
||||
now = time.time()
|
||||
if now - last_check >= DISCONNECT_CHECK_INTERVAL:
|
||||
if cancelled.is_set():
|
||||
logger.info("Cancelled during download, exiting early")
|
||||
return {"size": 0, "cancelled": True}
|
||||
last_check = now
|
||||
logger.info(f"Track downloaded: {total_bytes} bytes")
|
||||
|
||||
if cancelled.is_set():
|
||||
logger.info("Cancelled after download, exiting early")
|
||||
return {"size": 0, "cancelled": True}
|
||||
|
||||
# Apply padding using PyAV
|
||||
output_path = os.path.join(temp_dir, "padded.webm")
|
||||
delay_ms = math.floor(req.start_time_seconds * 1000)
|
||||
logger.info(f"Padding track {req.track_index} with {delay_ms}ms delay using PyAV")
|
||||
|
||||
in_container = av.open(input_path)
|
||||
in_stream = next((s for s in in_container.streams if s.type == "audio"), None)
|
||||
if in_stream is None:
|
||||
raise ValueError("No audio stream in input")
|
||||
|
||||
with av.open(output_path, "w", format="webm") as out_container:
|
||||
out_stream = out_container.add_stream("libopus", rate=OPUS_STANDARD_SAMPLE_RATE)
|
||||
out_stream.bit_rate = OPUS_DEFAULT_BIT_RATE
|
||||
graph = av.filter.Graph()
|
||||
|
||||
abuf_args = (
|
||||
f"time_base=1/{OPUS_STANDARD_SAMPLE_RATE}:"
|
||||
f"sample_rate={OPUS_STANDARD_SAMPLE_RATE}:"
|
||||
f"sample_fmt=s16:"
|
||||
f"channel_layout=stereo"
|
||||
)
|
||||
src = graph.add("abuffer", args=abuf_args, name="src")
|
||||
aresample_f = graph.add("aresample", args="async=1", name="ares")
|
||||
delays_arg = f"{delay_ms}|{delay_ms}"
|
||||
adelay_f = graph.add("adelay", args=f"delays={delays_arg}:all=1", name="delay")
|
||||
sink = graph.add("abuffersink", name="sink")
|
||||
|
||||
src.link_to(aresample_f)
|
||||
aresample_f.link_to(adelay_f)
|
||||
adelay_f.link_to(sink)
|
||||
graph.configure()
|
||||
|
||||
resampler = AudioResampler(
|
||||
format="s16", layout="stereo", rate=OPUS_STANDARD_SAMPLE_RATE
|
||||
)
|
||||
|
||||
for frame in in_container.decode(in_stream):
|
||||
# Check for cancellation periodically
|
||||
now = time.time()
|
||||
if now - last_check >= DISCONNECT_CHECK_INTERVAL:
|
||||
if cancelled.is_set():
|
||||
logger.info("Cancelled during processing, exiting early")
|
||||
in_container.close()
|
||||
return {"size": 0, "cancelled": True}
|
||||
last_check = now
|
||||
|
||||
out_frames = resampler.resample(frame) or []
|
||||
for rframe in out_frames:
|
||||
rframe.sample_rate = OPUS_STANDARD_SAMPLE_RATE
|
||||
rframe.time_base = Fraction(1, OPUS_STANDARD_SAMPLE_RATE)
|
||||
src.push(rframe)
|
||||
|
||||
while True:
|
||||
try:
|
||||
f_out = sink.pull()
|
||||
except Exception:
|
||||
break
|
||||
f_out.sample_rate = OPUS_STANDARD_SAMPLE_RATE
|
||||
f_out.time_base = Fraction(1, OPUS_STANDARD_SAMPLE_RATE)
|
||||
for packet in out_stream.encode(f_out):
|
||||
out_container.mux(packet)
|
||||
|
||||
# Flush filter graph
|
||||
src.push(None)
|
||||
while True:
|
||||
try:
|
||||
f_out = sink.pull()
|
||||
except Exception:
|
||||
break
|
||||
f_out.sample_rate = OPUS_STANDARD_SAMPLE_RATE
|
||||
f_out.time_base = Fraction(1, OPUS_STANDARD_SAMPLE_RATE)
|
||||
for packet in out_stream.encode(f_out):
|
||||
out_container.mux(packet)
|
||||
|
||||
# Flush encoder
|
||||
for packet in out_stream.encode(None):
|
||||
out_container.mux(packet)
|
||||
|
||||
in_container.close()
|
||||
|
||||
file_size = os.path.getsize(output_path)
|
||||
logger.info(f"Padding complete: {file_size} bytes")
|
||||
|
||||
logger.info("Uploading padded track to S3")
|
||||
|
||||
with open(output_path, "rb") as f:
|
||||
upload_response = requests.put(req.output_url, data=f, timeout=S3_TIMEOUT)
|
||||
|
||||
upload_response.raise_for_status()
|
||||
logger.info(f"Upload complete: {file_size} bytes")
|
||||
|
||||
return {"size": file_size}
|
||||
|
||||
finally:
|
||||
if input_path and os.path.exists(input_path):
|
||||
try:
|
||||
os.unlink(input_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cleanup input file: {e}")
|
||||
if output_path and os.path.exists(output_path):
|
||||
try:
|
||||
os.unlink(output_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cleanup output file: {e}")
|
||||
try:
|
||||
os.rmdir(temp_dir)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cleanup temp directory: {e}")
|
||||
|
||||
return web_app
|
||||
|
||||
Reference in New Issue
Block a user