dry hatchet with celery: deduplicate audio helpers shared between the Hatchet and Celery pipelines

Igor Loskutov
2025-12-17 14:48:23 -05:00
parent e77f38a12a
commit d683a83906
3 changed files with 77 additions and 603 deletions
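The shared helper modules this commit introduces are not themselves shown in the diff below, so here is a minimal sketch of what reflector/hatchet/utils.py presumably contains, reconstructed from the two per-module _to_dict copies deleted below (treat the module path and exact docstring as assumptions):

    def to_dict(output) -> dict:
        """Convert task output to dict, handling both dict and Pydantic model returns.

        Hatchet SDK returns Pydantic models when tasks have typed return
        annotations, but older code expects dicts; this normalizes the output.
        """
        # Reconstructed from the deleted _to_dict helpers; not the actual file.
        if isinstance(output, dict):
            return output
        return output.model_dump()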

View File

@@ -14,13 +14,10 @@ import functools
import tempfile
from contextlib import asynccontextmanager
from datetime import timedelta
from fractions import Fraction
from pathlib import Path
from typing import Callable
import av
import httpx
from av.audio.resampler import AudioResampler
from hatchet_sdk import Context
from pydantic import BaseModel
@@ -30,6 +27,7 @@ from reflector.hatchet.broadcast import (
set_status_and_broadcast,
)
from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.utils import to_dict
from reflector.hatchet.workflows.models import (
ConsentResult,
FinalizeResult,
@@ -62,6 +60,10 @@ from reflector.utils.audio_constants import (
PRESIGNED_URL_EXPIRATION_SECONDS,
WAVEFORM_SEGMENTS,
)
from reflector.utils.audio_mixdown import (
detect_sample_rate_from_tracks,
mixdown_tracks_pyav,
)
from reflector.utils.audio_waveform import get_audio_waveform
from reflector.utils.daily import (
filter_cam_audio_tracks,
@@ -146,17 +148,6 @@ def _get_storage():
)
def _to_dict(output) -> dict:
"""Convert task output to dict, handling both dict and Pydantic model returns.
Hatchet SDK returns Pydantic models when tasks have typed return annotations,
but older code expects dicts. This helper normalizes the output.
"""
if isinstance(output, dict):
return output
return output.model_dump()
def with_error_handling(step_name: str, set_error_status: bool = True) -> Callable:
"""Decorator that handles task failures uniformly.
@@ -242,7 +233,7 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe
ctx.log(f"get_participants: transcript_id={input.transcript_id}")
logger.info("[Hatchet] get_participants", transcript_id=input.transcript_id)
recording_data = _to_dict(ctx.task_output(get_recording))
recording_data = to_dict(ctx.task_output(get_recording))
mtg_session_id = recording_data.get("mtg_session_id")
async with fresh_db_connection():
@@ -341,7 +332,7 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
transcript_id=input.transcript_id,
)
participants_data = _to_dict(ctx.task_output(get_participants))
participants_data = to_dict(ctx.task_output(get_participants))
source_language = participants_data.get("source_language", "en")
child_coroutines = [
@@ -417,7 +408,7 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
ctx.log("mixdown_tracks: mixing padded tracks into single audio file")
logger.info("[Hatchet] mixdown_tracks", transcript_id=input.transcript_id)
track_data = _to_dict(ctx.task_output(process_tracks))
track_data = to_dict(ctx.task_output(process_tracks))
padded_tracks_data = track_data.get("padded_tracks", [])
if not padded_tracks_data:
@@ -428,7 +419,7 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
# Presign URLs on demand (avoids stale URLs on workflow replay)
padded_urls = []
for track_info in padded_tracks_data:
# Handle both dict (from _to_dict) and PaddedTrackInfo
# Handle both dict (from to_dict) and PaddedTrackInfo
if isinstance(track_info, dict):
key = track_info.get("key")
bucket = track_info.get("bucket_name")
@@ -445,149 +436,36 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
)
padded_urls.append(url)
# Use PipelineMainMultitrack.mixdown_tracks which uses PyAV filter graph
valid_urls = [url for url in padded_urls if url]
if not valid_urls:
raise ValueError("No valid padded tracks to mixdown")
target_sample_rate = None
for url in valid_urls:
container = None
try:
container = av.open(url)
for frame in container.decode(audio=0):
target_sample_rate = frame.sample_rate
break
except Exception:
continue
finally:
if container is not None:
container.close()
if target_sample_rate:
break
# Detect sample rate from tracks
target_sample_rate = detect_sample_rate_from_tracks(valid_urls, logger=logger)
if not target_sample_rate:
logger.error("Mixdown failed - no decodable audio frames found")
raise ValueError("No decodable audio frames in any track")
# Build PyAV filter graph: N abuffer -> amix -> aformat -> sink
graph = av.filter.Graph()
inputs = []
for idx, url in enumerate(valid_urls):
args = (
f"time_base=1/{target_sample_rate}:"
f"sample_rate={target_sample_rate}:"
f"sample_fmt=s32:"
f"channel_layout=stereo"
)
in_ctx = graph.add("abuffer", args=args, name=f"in{idx}")
inputs.append(in_ctx)
mixer = graph.add("amix", args=f"inputs={len(inputs)}:normalize=0", name="mix")
fmt = graph.add(
"aformat",
args=f"sample_fmts=s32:channel_layouts=stereo:sample_rates={target_sample_rate}",
name="fmt",
)
sink = graph.add("abuffersink", name="out")
for idx, in_ctx in enumerate(inputs):
in_ctx.link_to(mixer, 0, idx)
mixer.link_to(fmt)
fmt.link_to(sink)
graph.configure()
# Create temp file and writer for MP3 output
output_path = tempfile.mktemp(suffix=".mp3")
containers = []
duration_ms = [0.0] # Mutable container for callback capture
try:
for url in valid_urls:
try:
c = av.open(
url,
options={
"reconnect": "1",
"reconnect_streamed": "1",
"reconnect_delay_max": "5",
},
)
containers.append(c)
except Exception as e:
logger.warning(
"[Hatchet] mixdown: failed to open container",
url=url,
error=str(e),
)
async def capture_duration(d):
duration_ms[0] = d
if not containers:
raise ValueError("Could not open any track containers")
writer = AudioFileWriterProcessor(path=output_path, on_duration=capture_duration)
# Create AudioFileWriterProcessor for MP3 output with duration capture
duration_ms = [0.0] # Mutable container for callback capture
async def capture_duration(d):
duration_ms[0] = d
writer = AudioFileWriterProcessor(
path=output_path, on_duration=capture_duration
)
decoders = [c.decode(audio=0) for c in containers]
active = [True] * len(decoders)
resamplers = [
AudioResampler(format="s32", layout="stereo", rate=target_sample_rate)
for _ in decoders
]
while any(active):
for i, (dec, is_active) in enumerate(zip(decoders, active)):
if not is_active:
continue
try:
frame = next(dec)
except StopIteration:
active[i] = False
inputs[i].push(None)
continue
if frame.sample_rate != target_sample_rate:
continue
out_frames = resamplers[i].resample(frame) or []
for rf in out_frames:
rf.sample_rate = target_sample_rate
rf.time_base = Fraction(1, target_sample_rate)
inputs[i].push(rf)
while True:
try:
mixed = sink.pull()
except Exception:
break
mixed.sample_rate = target_sample_rate
mixed.time_base = Fraction(1, target_sample_rate)
await writer.push(mixed)
# Flush remaining frames
while True:
try:
mixed = sink.pull()
except Exception:
break
mixed.sample_rate = target_sample_rate
mixed.time_base = Fraction(1, target_sample_rate)
await writer.push(mixed)
await writer.flush()
# Duration is captured via callback in milliseconds (from AudioFileWriterProcessor)
finally:
for c in containers:
try:
c.close()
except Exception:
pass
# Run mixdown using shared utility
await mixdown_tracks_pyav(
valid_urls,
writer,
target_sample_rate,
offsets_seconds=None,
logger=logger,
)
await writer.flush()
# Upload to storage
file_size = Path(output_path).stat().st_size
storage_path = f"{input.transcript_id}/audio.mp3"
@@ -596,6 +474,7 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
Path(output_path).unlink(missing_ok=True)
# Update DB with audio location
async with fresh_db_connection():
from reflector.db.transcripts import transcripts_controller # noqa: PLC0415
@@ -633,7 +512,7 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResul
)
# Cleanup temporary padded S3 files (deferred until after mixdown)
track_data = _to_dict(ctx.task_output(process_tracks))
track_data = to_dict(ctx.task_output(process_tracks))
created_padded_files = track_data.get("created_padded_files", [])
if created_padded_files:
logger.info(
@@ -653,7 +532,7 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResul
error=str(result),
)
mixdown_data = _to_dict(ctx.task_output(mixdown_tracks))
mixdown_data = to_dict(ctx.task_output(mixdown_tracks))
audio_key = mixdown_data.get("audio_key")
storage = _get_storage()
@@ -703,7 +582,7 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult:
ctx.log("detect_topics: analyzing transcript for topics")
logger.info("[Hatchet] detect_topics", transcript_id=input.transcript_id)
track_data = _to_dict(ctx.task_output(process_tracks))
track_data = to_dict(ctx.task_output(process_tracks))
words = track_data.get("all_words", [])
target_language = track_data.get("target_language", "en")
@@ -762,7 +641,7 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult:
ctx.log("generate_title: generating title from topics")
logger.info("[Hatchet] generate_title", transcript_id=input.transcript_id)
topics_data = _to_dict(ctx.task_output(detect_topics))
topics_data = to_dict(ctx.task_output(detect_topics))
topics = topics_data.get("topics", [])
from reflector.db.transcripts import ( # noqa: PLC0415
@@ -813,7 +692,7 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult:
ctx.log("generate_summary: generating long and short summaries")
logger.info("[Hatchet] generate_summary", transcript_id=input.transcript_id)
topics_data = _to_dict(ctx.task_output(detect_topics))
topics_data = to_dict(ctx.task_output(detect_topics))
topics = topics_data.get("topics", [])
from reflector.db.transcripts import ( # noqa: PLC0415
@@ -895,8 +774,8 @@ async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult:
ctx.log("finalize: saving transcript and setting status to 'ended'")
logger.info("[Hatchet] finalize", transcript_id=input.transcript_id)
mixdown_data = _to_dict(ctx.task_output(mixdown_tracks))
track_data = _to_dict(ctx.task_output(process_tracks))
mixdown_data = to_dict(ctx.task_output(mixdown_tracks))
track_data = to_dict(ctx.task_output(process_tracks))
duration = mixdown_data.get("duration", 0)
all_words = track_data.get("all_words", [])

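For reference, a minimal sketch of the shared mixdown module that both pipelines now import. The module path (reflector/utils/audio_mixdown.py) and the signatures are inferred from the call sites in this diff and from the two deleted copies of the logic; per-URL error handling is simplified, so treat this as a reconstruction rather than the actual file:

    from fractions import Fraction

    import av
    from av.audio.resampler import AudioResampler


    def detect_sample_rate_from_tracks(track_urls, logger=None):
        """Return the sample rate of the first decodable audio frame, else None."""
        for url in track_urls:
            if not url:
                continue
            container = None
            try:
                container = av.open(url)
                for frame in container.decode(audio=0):
                    return frame.sample_rate
            except Exception:
                continue
            finally:
                if container is not None:
                    container.close()
        return None


    async def mixdown_tracks_pyav(track_urls, writer, target_sample_rate,
                                  offsets_seconds=None, logger=None):
        """Mix N tracks: abuffer -> optional adelay -> amix -> aformat -> sink."""
        valid_urls = [u for u in track_urls if u]
        # Offsets are normalized so the earliest track gets zero delay
        offsets = (
            [offsets_seconds[i] for i, u in enumerate(track_urls) if u]
            if offsets_seconds is not None
            else [0.0] * len(valid_urls)
        )
        base = min(offsets) if offsets else 0.0
        graph = av.filter.Graph()
        abuf_args = (
            f"time_base=1/{target_sample_rate}:sample_rate={target_sample_rate}:"
            f"sample_fmt=s32:channel_layout=stereo"
        )
        inputs = [graph.add("abuffer", args=abuf_args, name=f"in{i}")
                  for i in range(len(valid_urls))]
        mixer = graph.add("amix", args=f"inputs={len(inputs)}:normalize=0", name="mix")
        fmt = graph.add(
            "aformat",
            args=f"sample_fmts=s32:channel_layouts=stereo:sample_rates={target_sample_rate}",
            name="fmt",
        )
        sink = graph.add("abuffersink", name="out")
        for i, in_ctx in enumerate(inputs):
            delay_ms = max(0, int(round((offsets[i] - base) * 1000)))
            if delay_ms > 0:
                # adelay takes one value per channel, '|'-separated
                adelay = graph.add("adelay",
                                   args=f"delays={delay_ms}|{delay_ms}:all=1",
                                   name=f"delay{i}")
                in_ctx.link_to(adelay)
                adelay.link_to(mixer, 0, i)
            else:
                in_ctx.link_to(mixer, 0, i)
        mixer.link_to(fmt)
        fmt.link_to(sink)
        graph.configure()

        containers = []
        try:
            for url in valid_urls:
                # Reconnect options keep long-running S3 streams alive;
                # the real code tolerates per-URL open failures, simplified here
                containers.append(av.open(url, options={
                    "reconnect": "1",
                    "reconnect_streamed": "1",
                    "reconnect_delay_max": "5",
                }))
            decoders = [c.decode(audio=0) for c in containers]
            active = [True] * len(decoders)
            resamplers = [AudioResampler(format="s32", layout="stereo",
                                         rate=target_sample_rate) for _ in decoders]
            while any(active):
                for i, dec in enumerate(decoders):
                    if not active[i]:
                        continue
                    try:
                        frame = next(dec)
                    except StopIteration:
                        active[i] = False
                        inputs[i].push(None)  # EOF marker so amix can finish
                        continue
                    for rf in resamplers[i].resample(frame) or []:
                        rf.sample_rate = target_sample_rate
                        rf.time_base = Fraction(1, target_sample_rate)
                        inputs[i].push(rf)
                while True:
                    try:
                        mixed = sink.pull()
                    except Exception:  # sink raises when it needs more input
                        break
                    mixed.sample_rate = target_sample_rate
                    mixed.time_base = Fraction(1, target_sample_rate)
                    await writer.push(mixed)
        finally:
            for c in containers:
                try:
                    c.close()
                except Exception:
                    pass

The normalize=0 on amix matches the deleted code: it stops FFmpeg from rescaling each input's gain, so per-speaker levels survive the mix.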
View File

@@ -14,34 +14,25 @@ Hatchet workers run in forked processes; fresh imports per task ensure
storage/DB connections are not shared across forks.
"""
import math
import tempfile
from datetime import timedelta
from fractions import Fraction
from pathlib import Path
import av
from av.audio.resampler import AudioResampler
from hatchet_sdk import Context
from pydantic import BaseModel
from reflector.hatchet.client import HatchetClientManager
from reflector.hatchet.utils import to_dict
from reflector.hatchet.workflows.models import PadTrackResult, TranscribeTrackResult
from reflector.logger import logger
from reflector.utils.audio_constants import (
OPUS_DEFAULT_BIT_RATE,
OPUS_STANDARD_SAMPLE_RATE,
PRESIGNED_URL_EXPIRATION_SECONDS,
)
from reflector.utils.audio_constants import PRESIGNED_URL_EXPIRATION_SECONDS
from reflector.utils.audio_padding import (
apply_audio_padding_to_file,
extract_stream_start_time_from_container,
)
def _to_dict(output) -> dict:
"""Convert task output to dict, handling both dict and Pydantic model returns."""
if isinstance(output, dict):
return output
return output.model_dump()
class TrackInput(BaseModel):
"""Input for individual track processing."""
@@ -58,124 +49,6 @@ hatchet = HatchetClientManager.get_client()
track_workflow = hatchet.workflow(name="TrackProcessing", input_validator=TrackInput)
def _extract_stream_start_time_from_container(container, track_idx: int) -> float:
"""Extract meeting-relative start time from WebM stream metadata.
Uses PyAV to read stream.start_time from WebM container.
More accurate than filename timestamps by ~209ms due to network/encoding delays.
"""
start_time_seconds = 0.0
try:
audio_streams = [s for s in container.streams if s.type == "audio"]
stream = audio_streams[0] if audio_streams else container.streams[0]
# 1) Try stream-level start_time (most reliable for Daily.co tracks)
if stream.start_time is not None and stream.time_base is not None:
start_time_seconds = float(stream.start_time * stream.time_base)
# 2) Fallback to container-level start_time
if (start_time_seconds <= 0) and (container.start_time is not None):
start_time_seconds = float(container.start_time * av.time_base)
# 3) Fallback to first packet DTS
if start_time_seconds <= 0:
for packet in container.demux(stream):
if packet.dts is not None:
start_time_seconds = float(packet.dts * stream.time_base)
break
except Exception as e:
logger.warning(
"PyAV metadata read failed; assuming 0 start_time",
track_idx=track_idx,
error=str(e),
)
start_time_seconds = 0.0
logger.info(
f"Track {track_idx} stream metadata: start_time={start_time_seconds:.3f}s",
track_idx=track_idx,
)
return start_time_seconds
def _apply_audio_padding_to_file(
in_container,
output_path: str,
start_time_seconds: float,
track_idx: int,
) -> None:
"""Apply silence padding to audio track using PyAV filter graph."""
delay_ms = math.floor(start_time_seconds * 1000)
logger.info(
f"Padding track {track_idx} with {delay_ms}ms delay using PyAV",
track_idx=track_idx,
delay_ms=delay_ms,
)
with av.open(output_path, "w", format="webm") as out_container:
in_stream = next((s for s in in_container.streams if s.type == "audio"), None)
if in_stream is None:
raise Exception("No audio stream in input")
out_stream = out_container.add_stream("libopus", rate=OPUS_STANDARD_SAMPLE_RATE)
out_stream.bit_rate = OPUS_DEFAULT_BIT_RATE
graph = av.filter.Graph()
abuf_args = (
f"time_base=1/{OPUS_STANDARD_SAMPLE_RATE}:"
f"sample_rate={OPUS_STANDARD_SAMPLE_RATE}:"
f"sample_fmt=s16:"
f"channel_layout=stereo"
)
src = graph.add("abuffer", args=abuf_args, name="src")
aresample_f = graph.add("aresample", args="async=1", name="ares")
delays_arg = f"{delay_ms}|{delay_ms}"
adelay_f = graph.add("adelay", args=f"delays={delays_arg}:all=1", name="delay")
sink = graph.add("abuffersink", name="sink")
src.link_to(aresample_f)
aresample_f.link_to(adelay_f)
adelay_f.link_to(sink)
graph.configure()
resampler = AudioResampler(
format="s16", layout="stereo", rate=OPUS_STANDARD_SAMPLE_RATE
)
for frame in in_container.decode(in_stream):
out_frames = resampler.resample(frame) or []
for rframe in out_frames:
rframe.sample_rate = OPUS_STANDARD_SAMPLE_RATE
rframe.time_base = Fraction(1, OPUS_STANDARD_SAMPLE_RATE)
src.push(rframe)
while True:
try:
f_out = sink.pull()
except Exception:
break
f_out.sample_rate = OPUS_STANDARD_SAMPLE_RATE
f_out.time_base = Fraction(1, OPUS_STANDARD_SAMPLE_RATE)
for packet in out_stream.encode(f_out):
out_container.mux(packet)
# Flush remaining frames
src.push(None)
while True:
try:
f_out = sink.pull()
except Exception:
break
f_out.sample_rate = OPUS_STANDARD_SAMPLE_RATE
f_out.time_base = Fraction(1, OPUS_STANDARD_SAMPLE_RATE)
for packet in out_stream.encode(f_out):
out_container.mux(packet)
for packet in out_stream.encode(None):
out_container.mux(packet)
@track_workflow.task(execution_timeout=timedelta(seconds=300), retries=3)
async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
"""Pad single audio track with silence for alignment.
@@ -213,8 +86,8 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
# Open container and extract start time
with av.open(source_url) as in_container:
start_time_seconds = _extract_stream_start_time_from_container(
in_container, input.track_index
)
start_time_seconds = extract_stream_start_time_from_container(
in_container, input.track_index, logger=logger
)
# If no padding needed, return original S3 key
@@ -235,8 +108,12 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult:
temp_path = temp_file.name
try:
_apply_audio_padding_to_file(
in_container, temp_path, start_time_seconds, input.track_index
)
apply_audio_padding_to_file(
in_container,
temp_path,
start_time_seconds,
input.track_index,
logger=logger,
)
file_size = Path(temp_path).stat().st_size
@@ -293,7 +170,7 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackRe
)
try:
pad_result = _to_dict(ctx.task_output(pad_track))
pad_result = to_dict(ctx.task_output(pad_track))
padded_key = pad_result.get("padded_key")
bucket_name = pad_result.get("bucket_name")

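Likewise, a sketch of the shared padding helpers both files now import, reconstructed from the deleted _extract_stream_start_time_from_container and _apply_audio_padding_to_file above. The module path (reflector/utils/audio_padding.py) is taken from the new imports; the constants import is an assumption, since both call sites dropped their OPUS_* imports:

    import math
    from fractions import Fraction

    import av
    from av.audio.resampler import AudioResampler

    # Assumed: the OPUS_* constants moved here along with the code
    from reflector.utils.audio_constants import (
        OPUS_DEFAULT_BIT_RATE,
        OPUS_STANDARD_SAMPLE_RATE,
    )


    def extract_stream_start_time_from_container(container, track_idx, logger=None):
        """Meeting-relative start time from WebM metadata, with three fallbacks."""
        start_time_seconds = 0.0
        try:
            audio_streams = [s for s in container.streams if s.type == "audio"]
            stream = audio_streams[0] if audio_streams else container.streams[0]
            # 1) Stream-level start_time (most reliable for Daily.co tracks)
            if stream.start_time is not None and stream.time_base is not None:
                start_time_seconds = float(stream.start_time * stream.time_base)
            # 2) Container-level start_time (in av.time_base units)
            if start_time_seconds <= 0 and container.start_time is not None:
                start_time_seconds = float(container.start_time * av.time_base)
            # 3) First packet DTS
            if start_time_seconds <= 0:
                for packet in container.demux(stream):
                    if packet.dts is not None:
                        start_time_seconds = float(packet.dts * stream.time_base)
                        break
        except Exception:
            start_time_seconds = 0.0  # unreadable metadata: assume no offset
        return start_time_seconds


    def apply_audio_padding_to_file(in_container, output_path, start_time_seconds,
                                    track_idx, logger=None):
        """Prepend silence via an adelay filter graph, re-encoding to Opus/WebM."""
        delay_ms = math.floor(start_time_seconds * 1000)
        with av.open(output_path, "w", format="webm") as out_container:
            in_stream = next((s for s in in_container.streams if s.type == "audio"), None)
            if in_stream is None:
                raise ValueError("No audio stream in input")
            out_stream = out_container.add_stream("libopus", rate=OPUS_STANDARD_SAMPLE_RATE)
            out_stream.bit_rate = OPUS_DEFAULT_BIT_RATE
            graph = av.filter.Graph()
            src = graph.add("abuffer", args=(
                f"time_base=1/{OPUS_STANDARD_SAMPLE_RATE}:"
                f"sample_rate={OPUS_STANDARD_SAMPLE_RATE}:"
                f"sample_fmt=s16:channel_layout=stereo"), name="src")
            ares = graph.add("aresample", args="async=1", name="ares")
            # adelay needs one delay per channel, '|'-separated
            adelay = graph.add("adelay", args=f"delays={delay_ms}|{delay_ms}:all=1",
                               name="delay")
            sink = graph.add("abuffersink", name="sink")
            src.link_to(ares)
            ares.link_to(adelay)
            adelay.link_to(sink)
            graph.configure()
            resampler = AudioResampler(format="s16", layout="stereo",
                                       rate=OPUS_STANDARD_SAMPLE_RATE)

            def drain():
                while True:
                    try:
                        f_out = sink.pull()
                    except Exception:  # sink needs more input
                        return
                    f_out.sample_rate = OPUS_STANDARD_SAMPLE_RATE
                    f_out.time_base = Fraction(1, OPUS_STANDARD_SAMPLE_RATE)
                    for packet in out_stream.encode(f_out):
                        out_container.mux(packet)

            for frame in in_container.decode(in_stream):
                for rframe in resampler.resample(frame) or []:
                    rframe.sample_rate = OPUS_STANDARD_SAMPLE_RATE
                    rframe.time_base = Fraction(1, OPUS_STANDARD_SAMPLE_RATE)
                    src.push(rframe)
                drain()
            src.push(None)  # flush the filter graph
            drain()
            for packet in out_stream.encode(None):  # flush the encoder
                out_container.mux(packet)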
View File

@@ -1,11 +1,8 @@
import asyncio
import math
import tempfile
from fractions import Fraction
from pathlib import Path
import av
from av.audio.resampler import AudioResampler
from celery import chain, shared_task
from reflector.asynctask import asynctask
@@ -32,10 +29,14 @@ from reflector.processors.audio_waveform_processor import AudioWaveformProcessor
from reflector.processors.types import TitleSummary
from reflector.processors.types import Transcript as TranscriptType
from reflector.storage import Storage, get_transcripts_storage
from reflector.utils.audio_constants import (
OPUS_DEFAULT_BIT_RATE,
OPUS_STANDARD_SAMPLE_RATE,
PRESIGNED_URL_EXPIRATION_SECONDS,
)
from reflector.utils.audio_constants import PRESIGNED_URL_EXPIRATION_SECONDS
from reflector.utils.audio_mixdown import (
detect_sample_rate_from_tracks,
mixdown_tracks_pyav,
)
from reflector.utils.audio_padding import (
apply_audio_padding_to_file,
extract_stream_start_time_from_container,
)
from reflector.utils.daily import (
filter_cam_audio_tracks,
@@ -123,8 +124,8 @@ class PipelineMainMultitrack(PipelineMainBase):
try:
# PyAV streams input from S3 URL efficiently (2-5MB fixed overhead for codec/filters)
with av.open(track_url) as in_container:
start_time_seconds = self._extract_stream_start_time_from_container(
in_container, track_idx
)
start_time_seconds = extract_stream_start_time_from_container(
in_container, track_idx, logger=self.logger
)
if start_time_seconds <= 0:
@@ -142,8 +143,12 @@ class PipelineMainMultitrack(PipelineMainBase):
temp_path = temp_file.name
try:
self._apply_audio_padding_to_file(
in_container, temp_path, start_time_seconds, track_idx
)
apply_audio_padding_to_file(
in_container,
temp_path,
start_time_seconds,
track_idx,
logger=self.logger,
)
storage_path = (
@@ -184,317 +189,30 @@ class PipelineMainMultitrack(PipelineMainBase):
f"Track {track_idx} padding failed - transcript would have incorrect timestamps"
) from e
def _extract_stream_start_time_from_container(
self, container, track_idx: int
) -> float:
"""
Extract meeting-relative start time from WebM stream metadata.
Uses PyAV to read stream.start_time from WebM container.
More accurate than filename timestamps by ~209ms due to network/encoding delays.
"""
start_time_seconds = 0.0
try:
audio_streams = [s for s in container.streams if s.type == "audio"]
stream = audio_streams[0] if audio_streams else container.streams[0]
# 1) Try stream-level start_time (most reliable for Daily.co tracks)
if stream.start_time is not None and stream.time_base is not None:
start_time_seconds = float(stream.start_time * stream.time_base)
# 2) Fallback to container-level start_time (in av.time_base units)
if (start_time_seconds <= 0) and (container.start_time is not None):
start_time_seconds = float(container.start_time * av.time_base)
# 3) Fallback to first packet DTS in stream.time_base
if start_time_seconds <= 0:
for packet in container.demux(stream):
if packet.dts is not None:
start_time_seconds = float(packet.dts * stream.time_base)
break
except Exception as e:
self.logger.warning(
"PyAV metadata read failed; assuming 0 start_time",
track_idx=track_idx,
error=str(e),
)
start_time_seconds = 0.0
self.logger.info(
f"Track {track_idx} stream metadata: start_time={start_time_seconds:.3f}s",
track_idx=track_idx,
)
return start_time_seconds
def _apply_audio_padding_to_file(
self,
in_container,
output_path: str,
start_time_seconds: float,
track_idx: int,
) -> None:
"""Apply silence padding to audio track using PyAV filter graph, writing to file"""
delay_ms = math.floor(start_time_seconds * 1000)
self.logger.info(
f"Padding track {track_idx} with {delay_ms}ms delay using PyAV",
track_idx=track_idx,
delay_ms=delay_ms,
)
try:
with av.open(output_path, "w", format="webm") as out_container:
in_stream = next(
(s for s in in_container.streams if s.type == "audio"), None
)
if in_stream is None:
raise Exception("No audio stream in input")
out_stream = out_container.add_stream(
"libopus", rate=OPUS_STANDARD_SAMPLE_RATE
)
out_stream.bit_rate = OPUS_DEFAULT_BIT_RATE
graph = av.filter.Graph()
abuf_args = (
f"time_base=1/{OPUS_STANDARD_SAMPLE_RATE}:"
f"sample_rate={OPUS_STANDARD_SAMPLE_RATE}:"
f"sample_fmt=s16:"
f"channel_layout=stereo"
)
src = graph.add("abuffer", args=abuf_args, name="src")
aresample_f = graph.add("aresample", args="async=1", name="ares")
# adelay requires one delay value per channel separated by '|'
delays_arg = f"{delay_ms}|{delay_ms}"
adelay_f = graph.add(
"adelay", args=f"delays={delays_arg}:all=1", name="delay"
)
sink = graph.add("abuffersink", name="sink")
src.link_to(aresample_f)
aresample_f.link_to(adelay_f)
adelay_f.link_to(sink)
graph.configure()
resampler = AudioResampler(
format="s16", layout="stereo", rate=OPUS_STANDARD_SAMPLE_RATE
)
# Decode -> resample -> push through graph -> encode Opus
for frame in in_container.decode(in_stream):
out_frames = resampler.resample(frame) or []
for rframe in out_frames:
rframe.sample_rate = OPUS_STANDARD_SAMPLE_RATE
rframe.time_base = Fraction(1, OPUS_STANDARD_SAMPLE_RATE)
src.push(rframe)
while True:
try:
f_out = sink.pull()
except Exception:
break
f_out.sample_rate = OPUS_STANDARD_SAMPLE_RATE
f_out.time_base = Fraction(1, OPUS_STANDARD_SAMPLE_RATE)
for packet in out_stream.encode(f_out):
out_container.mux(packet)
src.push(None)
while True:
try:
f_out = sink.pull()
except Exception:
break
f_out.sample_rate = OPUS_STANDARD_SAMPLE_RATE
f_out.time_base = Fraction(1, OPUS_STANDARD_SAMPLE_RATE)
for packet in out_stream.encode(f_out):
out_container.mux(packet)
for packet in out_stream.encode(None):
out_container.mux(packet)
except Exception as e:
self.logger.error(
"PyAV padding failed for track",
track_idx=track_idx,
delay_ms=delay_ms,
error=str(e),
exc_info=True,
)
raise
async def mixdown_tracks(
self,
track_urls: list[str],
writer: AudioFileWriterProcessor,
offsets_seconds: list[float] | None = None,
) -> None:
"""Multi-track mixdown using PyAV filter graph (amix), reading from S3 presigned URLs"""
target_sample_rate: int | None = None
for url in track_urls:
if not url:
continue
container = None
try:
container = av.open(url)
for frame in container.decode(audio=0):
target_sample_rate = frame.sample_rate
break
except Exception:
continue
finally:
if container is not None:
container.close()
if target_sample_rate:
break
"""Multi-track mixdown using PyAV filter graph (amix), reading from S3 presigned URLs."""
# Detect sample rate from tracks
target_sample_rate = detect_sample_rate_from_tracks(
track_urls, logger=self.logger
)
if not target_sample_rate:
self.logger.error("Mixdown failed - no decodable audio frames found")
raise Exception("Mixdown failed: No decodable audio frames in any track")
# Build PyAV filter graph:
# N abuffer (s32/stereo)
# -> optional adelay per input (for alignment)
# -> amix (s32)
# -> aformat(s16)
# -> sink
graph = av.filter.Graph()
inputs = []
valid_track_urls = [url for url in track_urls if url]
input_offsets_seconds = None
if offsets_seconds is not None:
input_offsets_seconds = [
offsets_seconds[i] for i, url in enumerate(track_urls) if url
]
for idx, url in enumerate(valid_track_urls):
args = (
f"time_base=1/{target_sample_rate}:"
f"sample_rate={target_sample_rate}:"
f"sample_fmt=s32:"
f"channel_layout=stereo"
)
in_ctx = graph.add("abuffer", args=args, name=f"in{idx}")
inputs.append(in_ctx)
if not inputs:
self.logger.error("Mixdown failed - no valid inputs for graph")
raise Exception("Mixdown failed: No valid inputs for filter graph")
mixer = graph.add("amix", args=f"inputs={len(inputs)}:normalize=0", name="mix")
fmt = graph.add(
"aformat",
args=(
f"sample_fmts=s32:channel_layouts=stereo:sample_rates={target_sample_rate}"
),
name="fmt",
# Run mixdown using shared utility
await mixdown_tracks_pyav(
track_urls,
writer,
target_sample_rate,
offsets_seconds=offsets_seconds,
logger=self.logger,
)
sink = graph.add("abuffersink", name="out")
# Optional per-input delay before mixing
delays_ms: list[int] = []
if input_offsets_seconds is not None:
base = min(input_offsets_seconds) if input_offsets_seconds else 0.0
delays_ms = [
max(0, int(round((o - base) * 1000))) for o in input_offsets_seconds
]
else:
delays_ms = [0 for _ in inputs]
for idx, in_ctx in enumerate(inputs):
delay_ms = delays_ms[idx] if idx < len(delays_ms) else 0
if delay_ms > 0:
# adelay requires one value per channel; use same for stereo
adelay = graph.add(
"adelay",
args=f"delays={delay_ms}|{delay_ms}:all=1",
name=f"delay{idx}",
)
in_ctx.link_to(adelay)
adelay.link_to(mixer, 0, idx)
else:
in_ctx.link_to(mixer, 0, idx)
mixer.link_to(fmt)
fmt.link_to(sink)
graph.configure()
containers = []
try:
# Open all containers with cleanup guaranteed
for i, url in enumerate(valid_track_urls):
try:
c = av.open(
url,
options={
# it's trying to stream from s3 by default
"reconnect": "1",
"reconnect_streamed": "1",
"reconnect_delay_max": "5",
},
)
containers.append(c)
except Exception as e:
self.logger.warning(
"Mixdown: failed to open container from URL",
input=i,
url=url,
error=str(e),
)
if not containers:
self.logger.error("Mixdown failed - no valid containers opened")
raise Exception("Mixdown failed: Could not open any track containers")
decoders = [c.decode(audio=0) for c in containers]
active = [True] * len(decoders)
resamplers = [
AudioResampler(format="s32", layout="stereo", rate=target_sample_rate)
for _ in decoders
]
while any(active):
for i, (dec, is_active) in enumerate(zip(decoders, active)):
if not is_active:
continue
try:
frame = next(dec)
except StopIteration:
active[i] = False
# causes stream to move on / unclogs memory
inputs[i].push(None)
continue
if frame.sample_rate != target_sample_rate:
continue
out_frames = resamplers[i].resample(frame) or []
for rf in out_frames:
rf.sample_rate = target_sample_rate
rf.time_base = Fraction(1, target_sample_rate)
inputs[i].push(rf)
while True:
try:
mixed = sink.pull()
except Exception:
break
mixed.sample_rate = target_sample_rate
mixed.time_base = Fraction(1, target_sample_rate)
await writer.push(mixed)
while True:
try:
mixed = sink.pull()
except Exception:
break
mixed.sample_rate = target_sample_rate
mixed.time_base = Fraction(1, target_sample_rate)
await writer.push(mixed)
finally:
# Cleanup all containers, even if processing failed
for c in containers:
if c is not None:
try:
c.close()
except Exception:
pass # Best effort cleanup
@broadcast_to_sockets
async def set_status(self, transcript_id: str, status: TranscriptStatus):
async with self.lock_transaction():