From d683a83906e616d7a4f84986eee0eee23b3eca3a Mon Sep 17 00:00:00 2001
From: Igor Loskutov
Date: Wed, 17 Dec 2025 14:48:23 -0500
Subject: [PATCH] DRY shared audio helpers between Hatchet and Celery pipelines

Extract the duplicated PyAV padding/mixdown code and the task-output
normalizer into shared modules (reflector.hatchet.utils,
reflector.utils.audio_padding, reflector.utils.audio_mixdown), and call
them from both the Hatchet workflows and the Celery multitrack pipeline.

---
 .../hatchet/workflows/diarization_pipeline.py | 191 ++--------
 .../hatchet/workflows/track_processing.py     | 151 +-------
 .../pipelines/main_multitrack_pipeline.py     | 338 ++----------------
 3 files changed, 77 insertions(+), 603 deletions(-)

diff --git a/server/reflector/hatchet/workflows/diarization_pipeline.py b/server/reflector/hatchet/workflows/diarization_pipeline.py
index ffc89d2d..f8121901 100644
--- a/server/reflector/hatchet/workflows/diarization_pipeline.py
+++ b/server/reflector/hatchet/workflows/diarization_pipeline.py
@@ -14,13 +14,10 @@ import functools
 import tempfile
 from contextlib import asynccontextmanager
 from datetime import timedelta
-from fractions import Fraction
 from pathlib import Path
 from typing import Callable
 
-import av
 import httpx
-from av.audio.resampler import AudioResampler
 from hatchet_sdk import Context
 from pydantic import BaseModel
 
@@ -30,6 +27,7 @@ from reflector.hatchet.broadcast import (
     set_status_and_broadcast,
 )
 from reflector.hatchet.client import HatchetClientManager
+from reflector.hatchet.utils import to_dict
 from reflector.hatchet.workflows.models import (
     ConsentResult,
     FinalizeResult,
@@ -62,6 +60,10 @@ from reflector.utils.audio_constants import (
     PRESIGNED_URL_EXPIRATION_SECONDS,
     WAVEFORM_SEGMENTS,
 )
+from reflector.utils.audio_mixdown import (
+    detect_sample_rate_from_tracks,
+    mixdown_tracks_pyav,
+)
 from reflector.utils.audio_waveform import get_audio_waveform
 from reflector.utils.daily import (
     filter_cam_audio_tracks,
@@ -146,17 +148,6 @@ def _get_storage():
     )
 
 
-def _to_dict(output) -> dict:
-    """Convert task output to dict, handling both dict and Pydantic model returns.
-
-    Hatchet SDK returns Pydantic models when tasks have typed return annotations,
-    but older code expects dicts. This helper normalizes the output.
-    """
-    if isinstance(output, dict):
-        return output
-    return output.model_dump()
-
-
 def with_error_handling(step_name: str, set_error_status: bool = True) -> Callable:
     """Decorator that handles task failures uniformly.
@@ -242,7 +233,7 @@ async def get_participants(input: PipelineInput, ctx: Context) -> ParticipantsRe ctx.log(f"get_participants: transcript_id={input.transcript_id}") logger.info("[Hatchet] get_participants", transcript_id=input.transcript_id) - recording_data = _to_dict(ctx.task_output(get_recording)) + recording_data = to_dict(ctx.task_output(get_recording)) mtg_session_id = recording_data.get("mtg_session_id") async with fresh_db_connection(): @@ -341,7 +332,7 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes transcript_id=input.transcript_id, ) - participants_data = _to_dict(ctx.task_output(get_participants)) + participants_data = to_dict(ctx.task_output(get_participants)) source_language = participants_data.get("source_language", "en") child_coroutines = [ @@ -417,7 +408,7 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult: ctx.log("mixdown_tracks: mixing padded tracks into single audio file") logger.info("[Hatchet] mixdown_tracks", transcript_id=input.transcript_id) - track_data = _to_dict(ctx.task_output(process_tracks)) + track_data = to_dict(ctx.task_output(process_tracks)) padded_tracks_data = track_data.get("padded_tracks", []) if not padded_tracks_data: @@ -428,7 +419,7 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult: # Presign URLs on demand (avoids stale URLs on workflow replay) padded_urls = [] for track_info in padded_tracks_data: - # Handle both dict (from _to_dict) and PaddedTrackInfo + # Handle both dict (from to_dict) and PaddedTrackInfo if isinstance(track_info, dict): key = track_info.get("key") bucket = track_info.get("bucket_name") @@ -445,149 +436,36 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult: ) padded_urls.append(url) - # Use PipelineMainMultitrack.mixdown_tracks which uses PyAV filter graph valid_urls = [url for url in padded_urls if url] if not valid_urls: raise ValueError("No valid padded tracks to mixdown") - target_sample_rate = None - for url in valid_urls: - container = None - try: - container = av.open(url) - for frame in container.decode(audio=0): - target_sample_rate = frame.sample_rate - break - except Exception: - continue - finally: - if container is not None: - container.close() - if target_sample_rate: - break - + # Detect sample rate from tracks + target_sample_rate = detect_sample_rate_from_tracks(valid_urls, logger=logger) if not target_sample_rate: + logger.error("Mixdown failed - no decodable audio frames found") raise ValueError("No decodable audio frames in any track") - # Build PyAV filter graph: N abuffer -> amix -> aformat -> sink - graph = av.filter.Graph() - inputs = [] - - for idx, url in enumerate(valid_urls): - args = ( - f"time_base=1/{target_sample_rate}:" - f"sample_rate={target_sample_rate}:" - f"sample_fmt=s32:" - f"channel_layout=stereo" - ) - in_ctx = graph.add("abuffer", args=args, name=f"in{idx}") - inputs.append(in_ctx) - - mixer = graph.add("amix", args=f"inputs={len(inputs)}:normalize=0", name="mix") - fmt = graph.add( - "aformat", - args=f"sample_fmts=s32:channel_layouts=stereo:sample_rates={target_sample_rate}", - name="fmt", - ) - sink = graph.add("abuffersink", name="out") - - for idx, in_ctx in enumerate(inputs): - in_ctx.link_to(mixer, 0, idx) - mixer.link_to(fmt) - fmt.link_to(sink) - graph.configure() - + # Create temp file and writer for MP3 output output_path = tempfile.mktemp(suffix=".mp3") - containers = [] + duration_ms = [0.0] # Mutable container for callback capture - try: - 
for url in valid_urls: - try: - c = av.open( - url, - options={ - "reconnect": "1", - "reconnect_streamed": "1", - "reconnect_delay_max": "5", - }, - ) - containers.append(c) - except Exception as e: - logger.warning( - "[Hatchet] mixdown: failed to open container", - url=url, - error=str(e), - ) + async def capture_duration(d): + duration_ms[0] = d - if not containers: - raise ValueError("Could not open any track containers") + writer = AudioFileWriterProcessor(path=output_path, on_duration=capture_duration) - # Create AudioFileWriterProcessor for MP3 output with duration capture - duration_ms = [0.0] # Mutable container for callback capture - - async def capture_duration(d): - duration_ms[0] = d - - writer = AudioFileWriterProcessor( - path=output_path, on_duration=capture_duration - ) - - decoders = [c.decode(audio=0) for c in containers] - active = [True] * len(decoders) - resamplers = [ - AudioResampler(format="s32", layout="stereo", rate=target_sample_rate) - for _ in decoders - ] - - while any(active): - for i, (dec, is_active) in enumerate(zip(decoders, active)): - if not is_active: - continue - try: - frame = next(dec) - except StopIteration: - active[i] = False - inputs[i].push(None) - continue - - if frame.sample_rate != target_sample_rate: - continue - out_frames = resamplers[i].resample(frame) or [] - for rf in out_frames: - rf.sample_rate = target_sample_rate - rf.time_base = Fraction(1, target_sample_rate) - inputs[i].push(rf) - - while True: - try: - mixed = sink.pull() - except Exception: - break - mixed.sample_rate = target_sample_rate - mixed.time_base = Fraction(1, target_sample_rate) - await writer.push(mixed) - - # Flush remaining frames - while True: - try: - mixed = sink.pull() - except Exception: - break - mixed.sample_rate = target_sample_rate - mixed.time_base = Fraction(1, target_sample_rate) - await writer.push(mixed) - - await writer.flush() - - # Duration is captured via callback in milliseconds (from AudioFileWriterProcessor) - - finally: - for c in containers: - try: - c.close() - except Exception: - pass + # Run mixdown using shared utility + await mixdown_tracks_pyav( + valid_urls, + writer, + target_sample_rate, + offsets_seconds=None, + logger=logger, + ) + await writer.flush() + # Upload to storage file_size = Path(output_path).stat().st_size storage_path = f"{input.transcript_id}/audio.mp3" @@ -596,6 +474,7 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult: Path(output_path).unlink(missing_ok=True) + # Update DB with audio location async with fresh_db_connection(): from reflector.db.transcripts import transcripts_controller # noqa: PLC0415 @@ -633,7 +512,7 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResul ) # Cleanup temporary padded S3 files (deferred until after mixdown) - track_data = _to_dict(ctx.task_output(process_tracks)) + track_data = to_dict(ctx.task_output(process_tracks)) created_padded_files = track_data.get("created_padded_files", []) if created_padded_files: logger.info( @@ -653,7 +532,7 @@ async def generate_waveform(input: PipelineInput, ctx: Context) -> WaveformResul error=str(result), ) - mixdown_data = _to_dict(ctx.task_output(mixdown_tracks)) + mixdown_data = to_dict(ctx.task_output(mixdown_tracks)) audio_key = mixdown_data.get("audio_key") storage = _get_storage() @@ -703,7 +582,7 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult: ctx.log("detect_topics: analyzing transcript for topics") logger.info("[Hatchet] detect_topics", 
transcript_id=input.transcript_id) - track_data = _to_dict(ctx.task_output(process_tracks)) + track_data = to_dict(ctx.task_output(process_tracks)) words = track_data.get("all_words", []) target_language = track_data.get("target_language", "en") @@ -762,7 +641,7 @@ async def generate_title(input: PipelineInput, ctx: Context) -> TitleResult: ctx.log("generate_title: generating title from topics") logger.info("[Hatchet] generate_title", transcript_id=input.transcript_id) - topics_data = _to_dict(ctx.task_output(detect_topics)) + topics_data = to_dict(ctx.task_output(detect_topics)) topics = topics_data.get("topics", []) from reflector.db.transcripts import ( # noqa: PLC0415 @@ -813,7 +692,7 @@ async def generate_summary(input: PipelineInput, ctx: Context) -> SummaryResult: ctx.log("generate_summary: generating long and short summaries") logger.info("[Hatchet] generate_summary", transcript_id=input.transcript_id) - topics_data = _to_dict(ctx.task_output(detect_topics)) + topics_data = to_dict(ctx.task_output(detect_topics)) topics = topics_data.get("topics", []) from reflector.db.transcripts import ( # noqa: PLC0415 @@ -895,8 +774,8 @@ async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult: ctx.log("finalize: saving transcript and setting status to 'ended'") logger.info("[Hatchet] finalize", transcript_id=input.transcript_id) - mixdown_data = _to_dict(ctx.task_output(mixdown_tracks)) - track_data = _to_dict(ctx.task_output(process_tracks)) + mixdown_data = to_dict(ctx.task_output(mixdown_tracks)) + track_data = to_dict(ctx.task_output(process_tracks)) duration = mixdown_data.get("duration", 0) all_words = track_data.get("all_words", []) diff --git a/server/reflector/hatchet/workflows/track_processing.py b/server/reflector/hatchet/workflows/track_processing.py index b709578c..b5aaa87e 100644 --- a/server/reflector/hatchet/workflows/track_processing.py +++ b/server/reflector/hatchet/workflows/track_processing.py @@ -14,34 +14,25 @@ Hatchet workers run in forked processes; fresh imports per task ensure storage/DB connections are not shared across forks. """ -import math import tempfile from datetime import timedelta -from fractions import Fraction from pathlib import Path import av -from av.audio.resampler import AudioResampler from hatchet_sdk import Context from pydantic import BaseModel from reflector.hatchet.client import HatchetClientManager +from reflector.hatchet.utils import to_dict from reflector.hatchet.workflows.models import PadTrackResult, TranscribeTrackResult from reflector.logger import logger -from reflector.utils.audio_constants import ( - OPUS_DEFAULT_BIT_RATE, - OPUS_STANDARD_SAMPLE_RATE, - PRESIGNED_URL_EXPIRATION_SECONDS, +from reflector.utils.audio_constants import PRESIGNED_URL_EXPIRATION_SECONDS +from reflector.utils.audio_padding import ( + apply_audio_padding_to_file, + extract_stream_start_time_from_container, ) -def _to_dict(output) -> dict: - """Convert task output to dict, handling both dict and Pydantic model returns.""" - if isinstance(output, dict): - return output - return output.model_dump() - - class TrackInput(BaseModel): """Input for individual track processing.""" @@ -58,124 +49,6 @@ hatchet = HatchetClientManager.get_client() track_workflow = hatchet.workflow(name="TrackProcessing", input_validator=TrackInput) -def _extract_stream_start_time_from_container(container, track_idx: int) -> float: - """Extract meeting-relative start time from WebM stream metadata. - - Uses PyAV to read stream.start_time from WebM container. 
- More accurate than filename timestamps by ~209ms due to network/encoding delays. - """ - start_time_seconds = 0.0 - try: - audio_streams = [s for s in container.streams if s.type == "audio"] - stream = audio_streams[0] if audio_streams else container.streams[0] - - # 1) Try stream-level start_time (most reliable for Daily.co tracks) - if stream.start_time is not None and stream.time_base is not None: - start_time_seconds = float(stream.start_time * stream.time_base) - - # 2) Fallback to container-level start_time - if (start_time_seconds <= 0) and (container.start_time is not None): - start_time_seconds = float(container.start_time * av.time_base) - - # 3) Fallback to first packet DTS - if start_time_seconds <= 0: - for packet in container.demux(stream): - if packet.dts is not None: - start_time_seconds = float(packet.dts * stream.time_base) - break - except Exception as e: - logger.warning( - "PyAV metadata read failed; assuming 0 start_time", - track_idx=track_idx, - error=str(e), - ) - start_time_seconds = 0.0 - - logger.info( - f"Track {track_idx} stream metadata: start_time={start_time_seconds:.3f}s", - track_idx=track_idx, - ) - return start_time_seconds - - -def _apply_audio_padding_to_file( - in_container, - output_path: str, - start_time_seconds: float, - track_idx: int, -) -> None: - """Apply silence padding to audio track using PyAV filter graph.""" - delay_ms = math.floor(start_time_seconds * 1000) - - logger.info( - f"Padding track {track_idx} with {delay_ms}ms delay using PyAV", - track_idx=track_idx, - delay_ms=delay_ms, - ) - - with av.open(output_path, "w", format="webm") as out_container: - in_stream = next((s for s in in_container.streams if s.type == "audio"), None) - if in_stream is None: - raise Exception("No audio stream in input") - - out_stream = out_container.add_stream("libopus", rate=OPUS_STANDARD_SAMPLE_RATE) - out_stream.bit_rate = OPUS_DEFAULT_BIT_RATE - graph = av.filter.Graph() - - abuf_args = ( - f"time_base=1/{OPUS_STANDARD_SAMPLE_RATE}:" - f"sample_rate={OPUS_STANDARD_SAMPLE_RATE}:" - f"sample_fmt=s16:" - f"channel_layout=stereo" - ) - src = graph.add("abuffer", args=abuf_args, name="src") - aresample_f = graph.add("aresample", args="async=1", name="ares") - delays_arg = f"{delay_ms}|{delay_ms}" - adelay_f = graph.add("adelay", args=f"delays={delays_arg}:all=1", name="delay") - sink = graph.add("abuffersink", name="sink") - - src.link_to(aresample_f) - aresample_f.link_to(adelay_f) - adelay_f.link_to(sink) - graph.configure() - - resampler = AudioResampler( - format="s16", layout="stereo", rate=OPUS_STANDARD_SAMPLE_RATE - ) - - for frame in in_container.decode(in_stream): - out_frames = resampler.resample(frame) or [] - for rframe in out_frames: - rframe.sample_rate = OPUS_STANDARD_SAMPLE_RATE - rframe.time_base = Fraction(1, OPUS_STANDARD_SAMPLE_RATE) - src.push(rframe) - - while True: - try: - f_out = sink.pull() - except Exception: - break - f_out.sample_rate = OPUS_STANDARD_SAMPLE_RATE - f_out.time_base = Fraction(1, OPUS_STANDARD_SAMPLE_RATE) - for packet in out_stream.encode(f_out): - out_container.mux(packet) - - # Flush remaining frames - src.push(None) - while True: - try: - f_out = sink.pull() - except Exception: - break - f_out.sample_rate = OPUS_STANDARD_SAMPLE_RATE - f_out.time_base = Fraction(1, OPUS_STANDARD_SAMPLE_RATE) - for packet in out_stream.encode(f_out): - out_container.mux(packet) - - for packet in out_stream.encode(None): - out_container.mux(packet) - - @track_workflow.task(execution_timeout=timedelta(seconds=300), retries=3) 
async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult: """Pad single audio track with silence for alignment. @@ -213,8 +86,8 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult: # Open container and extract start time with av.open(source_url) as in_container: - start_time_seconds = _extract_stream_start_time_from_container( - in_container, input.track_index + start_time_seconds = extract_stream_start_time_from_container( + in_container, input.track_index, logger=logger ) # If no padding needed, return original S3 key @@ -235,8 +108,12 @@ async def pad_track(input: TrackInput, ctx: Context) -> PadTrackResult: temp_path = temp_file.name try: - _apply_audio_padding_to_file( - in_container, temp_path, start_time_seconds, input.track_index + apply_audio_padding_to_file( + in_container, + temp_path, + start_time_seconds, + input.track_index, + logger=logger, ) file_size = Path(temp_path).stat().st_size @@ -293,7 +170,7 @@ async def transcribe_track(input: TrackInput, ctx: Context) -> TranscribeTrackRe ) try: - pad_result = _to_dict(ctx.task_output(pad_track)) + pad_result = to_dict(ctx.task_output(pad_track)) padded_key = pad_result.get("padded_key") bucket_name = pad_result.get("bucket_name") diff --git a/server/reflector/pipelines/main_multitrack_pipeline.py b/server/reflector/pipelines/main_multitrack_pipeline.py index 26f42c4f..72efbf5a 100644 --- a/server/reflector/pipelines/main_multitrack_pipeline.py +++ b/server/reflector/pipelines/main_multitrack_pipeline.py @@ -1,11 +1,8 @@ import asyncio -import math import tempfile -from fractions import Fraction from pathlib import Path import av -from av.audio.resampler import AudioResampler from celery import chain, shared_task from reflector.asynctask import asynctask @@ -32,10 +29,14 @@ from reflector.processors.audio_waveform_processor import AudioWaveformProcessor from reflector.processors.types import TitleSummary from reflector.processors.types import Transcript as TranscriptType from reflector.storage import Storage, get_transcripts_storage -from reflector.utils.audio_constants import ( - OPUS_DEFAULT_BIT_RATE, - OPUS_STANDARD_SAMPLE_RATE, - PRESIGNED_URL_EXPIRATION_SECONDS, +from reflector.utils.audio_constants import PRESIGNED_URL_EXPIRATION_SECONDS +from reflector.utils.audio_mixdown import ( + detect_sample_rate_from_tracks, + mixdown_tracks_pyav, +) +from reflector.utils.audio_padding import ( + apply_audio_padding_to_file, + extract_stream_start_time_from_container, ) from reflector.utils.daily import ( filter_cam_audio_tracks, @@ -123,8 +124,8 @@ class PipelineMainMultitrack(PipelineMainBase): try: # PyAV streams input from S3 URL efficiently (2-5MB fixed overhead for codec/filters) with av.open(track_url) as in_container: - start_time_seconds = self._extract_stream_start_time_from_container( - in_container, track_idx + start_time_seconds = extract_stream_start_time_from_container( + in_container, track_idx, logger=self.logger ) if start_time_seconds <= 0: @@ -142,8 +143,12 @@ class PipelineMainMultitrack(PipelineMainBase): temp_path = temp_file.name try: - self._apply_audio_padding_to_file( - in_container, temp_path, start_time_seconds, track_idx + apply_audio_padding_to_file( + in_container, + temp_path, + start_time_seconds, + track_idx, + logger=self.logger, ) storage_path = ( @@ -184,317 +189,30 @@ class PipelineMainMultitrack(PipelineMainBase): f"Track {track_idx} padding failed - transcript would have incorrect timestamps" ) from e - def _extract_stream_start_time_from_container( - self, 
container, track_idx: int - ) -> float: - """ - Extract meeting-relative start time from WebM stream metadata. - Uses PyAV to read stream.start_time from WebM container. - More accurate than filename timestamps by ~209ms due to network/encoding delays. - """ - start_time_seconds = 0.0 - try: - audio_streams = [s for s in container.streams if s.type == "audio"] - stream = audio_streams[0] if audio_streams else container.streams[0] - - # 1) Try stream-level start_time (most reliable for Daily.co tracks) - if stream.start_time is not None and stream.time_base is not None: - start_time_seconds = float(stream.start_time * stream.time_base) - - # 2) Fallback to container-level start_time (in av.time_base units) - if (start_time_seconds <= 0) and (container.start_time is not None): - start_time_seconds = float(container.start_time * av.time_base) - - # 3) Fallback to first packet DTS in stream.time_base - if start_time_seconds <= 0: - for packet in container.demux(stream): - if packet.dts is not None: - start_time_seconds = float(packet.dts * stream.time_base) - break - except Exception as e: - self.logger.warning( - "PyAV metadata read failed; assuming 0 start_time", - track_idx=track_idx, - error=str(e), - ) - start_time_seconds = 0.0 - - self.logger.info( - f"Track {track_idx} stream metadata: start_time={start_time_seconds:.3f}s", - track_idx=track_idx, - ) - return start_time_seconds - - def _apply_audio_padding_to_file( - self, - in_container, - output_path: str, - start_time_seconds: float, - track_idx: int, - ) -> None: - """Apply silence padding to audio track using PyAV filter graph, writing to file""" - delay_ms = math.floor(start_time_seconds * 1000) - - self.logger.info( - f"Padding track {track_idx} with {delay_ms}ms delay using PyAV", - track_idx=track_idx, - delay_ms=delay_ms, - ) - - try: - with av.open(output_path, "w", format="webm") as out_container: - in_stream = next( - (s for s in in_container.streams if s.type == "audio"), None - ) - if in_stream is None: - raise Exception("No audio stream in input") - - out_stream = out_container.add_stream( - "libopus", rate=OPUS_STANDARD_SAMPLE_RATE - ) - out_stream.bit_rate = OPUS_DEFAULT_BIT_RATE - graph = av.filter.Graph() - - abuf_args = ( - f"time_base=1/{OPUS_STANDARD_SAMPLE_RATE}:" - f"sample_rate={OPUS_STANDARD_SAMPLE_RATE}:" - f"sample_fmt=s16:" - f"channel_layout=stereo" - ) - src = graph.add("abuffer", args=abuf_args, name="src") - aresample_f = graph.add("aresample", args="async=1", name="ares") - # adelay requires one delay value per channel separated by '|' - delays_arg = f"{delay_ms}|{delay_ms}" - adelay_f = graph.add( - "adelay", args=f"delays={delays_arg}:all=1", name="delay" - ) - sink = graph.add("abuffersink", name="sink") - - src.link_to(aresample_f) - aresample_f.link_to(adelay_f) - adelay_f.link_to(sink) - graph.configure() - - resampler = AudioResampler( - format="s16", layout="stereo", rate=OPUS_STANDARD_SAMPLE_RATE - ) - # Decode -> resample -> push through graph -> encode Opus - for frame in in_container.decode(in_stream): - out_frames = resampler.resample(frame) or [] - for rframe in out_frames: - rframe.sample_rate = OPUS_STANDARD_SAMPLE_RATE - rframe.time_base = Fraction(1, OPUS_STANDARD_SAMPLE_RATE) - src.push(rframe) - - while True: - try: - f_out = sink.pull() - except Exception: - break - f_out.sample_rate = OPUS_STANDARD_SAMPLE_RATE - f_out.time_base = Fraction(1, OPUS_STANDARD_SAMPLE_RATE) - for packet in out_stream.encode(f_out): - out_container.mux(packet) - - src.push(None) - while True: - try: - 
f_out = sink.pull() - except Exception: - break - f_out.sample_rate = OPUS_STANDARD_SAMPLE_RATE - f_out.time_base = Fraction(1, OPUS_STANDARD_SAMPLE_RATE) - for packet in out_stream.encode(f_out): - out_container.mux(packet) - - for packet in out_stream.encode(None): - out_container.mux(packet) - except Exception as e: - self.logger.error( - "PyAV padding failed for track", - track_idx=track_idx, - delay_ms=delay_ms, - error=str(e), - exc_info=True, - ) - raise - async def mixdown_tracks( self, track_urls: list[str], writer: AudioFileWriterProcessor, offsets_seconds: list[float] | None = None, ) -> None: - """Multi-track mixdown using PyAV filter graph (amix), reading from S3 presigned URLs""" - - target_sample_rate: int | None = None - for url in track_urls: - if not url: - continue - container = None - try: - container = av.open(url) - for frame in container.decode(audio=0): - target_sample_rate = frame.sample_rate - break - except Exception: - continue - finally: - if container is not None: - container.close() - if target_sample_rate: - break - + """Multi-track mixdown using PyAV filter graph (amix), reading from S3 presigned URLs.""" + # Detect sample rate from tracks + target_sample_rate = detect_sample_rate_from_tracks( + track_urls, logger=self.logger + ) if not target_sample_rate: self.logger.error("Mixdown failed - no decodable audio frames found") raise Exception("Mixdown failed: No decodable audio frames in any track") - # Build PyAV filter graph: - # N abuffer (s32/stereo) - # -> optional adelay per input (for alignment) - # -> amix (s32) - # -> aformat(s16) - # -> sink - graph = av.filter.Graph() - inputs = [] - valid_track_urls = [url for url in track_urls if url] - input_offsets_seconds = None - if offsets_seconds is not None: - input_offsets_seconds = [ - offsets_seconds[i] for i, url in enumerate(track_urls) if url - ] - for idx, url in enumerate(valid_track_urls): - args = ( - f"time_base=1/{target_sample_rate}:" - f"sample_rate={target_sample_rate}:" - f"sample_fmt=s32:" - f"channel_layout=stereo" - ) - in_ctx = graph.add("abuffer", args=args, name=f"in{idx}") - inputs.append(in_ctx) - if not inputs: - self.logger.error("Mixdown failed - no valid inputs for graph") - raise Exception("Mixdown failed: No valid inputs for filter graph") - - mixer = graph.add("amix", args=f"inputs={len(inputs)}:normalize=0", name="mix") - - fmt = graph.add( - "aformat", - args=( - f"sample_fmts=s32:channel_layouts=stereo:sample_rates={target_sample_rate}" - ), - name="fmt", + # Run mixdown using shared utility + await mixdown_tracks_pyav( + track_urls, + writer, + target_sample_rate, + offsets_seconds=offsets_seconds, + logger=self.logger, ) - sink = graph.add("abuffersink", name="out") - - # Optional per-input delay before mixing - delays_ms: list[int] = [] - if input_offsets_seconds is not None: - base = min(input_offsets_seconds) if input_offsets_seconds else 0.0 - delays_ms = [ - max(0, int(round((o - base) * 1000))) for o in input_offsets_seconds - ] - else: - delays_ms = [0 for _ in inputs] - - for idx, in_ctx in enumerate(inputs): - delay_ms = delays_ms[idx] if idx < len(delays_ms) else 0 - if delay_ms > 0: - # adelay requires one value per channel; use same for stereo - adelay = graph.add( - "adelay", - args=f"delays={delay_ms}|{delay_ms}:all=1", - name=f"delay{idx}", - ) - in_ctx.link_to(adelay) - adelay.link_to(mixer, 0, idx) - else: - in_ctx.link_to(mixer, 0, idx) - mixer.link_to(fmt) - fmt.link_to(sink) - graph.configure() - - containers = [] - try: - # Open all containers with 
cleanup guaranteed - for i, url in enumerate(valid_track_urls): - try: - c = av.open( - url, - options={ - # it's trying to stream from s3 by default - "reconnect": "1", - "reconnect_streamed": "1", - "reconnect_delay_max": "5", - }, - ) - containers.append(c) - except Exception as e: - self.logger.warning( - "Mixdown: failed to open container from URL", - input=i, - url=url, - error=str(e), - ) - - if not containers: - self.logger.error("Mixdown failed - no valid containers opened") - raise Exception("Mixdown failed: Could not open any track containers") - - decoders = [c.decode(audio=0) for c in containers] - active = [True] * len(decoders) - resamplers = [ - AudioResampler(format="s32", layout="stereo", rate=target_sample_rate) - for _ in decoders - ] - - while any(active): - for i, (dec, is_active) in enumerate(zip(decoders, active)): - if not is_active: - continue - try: - frame = next(dec) - except StopIteration: - active[i] = False - # causes stream to move on / unclogs memory - inputs[i].push(None) - continue - - if frame.sample_rate != target_sample_rate: - continue - out_frames = resamplers[i].resample(frame) or [] - for rf in out_frames: - rf.sample_rate = target_sample_rate - rf.time_base = Fraction(1, target_sample_rate) - inputs[i].push(rf) - - while True: - try: - mixed = sink.pull() - except Exception: - break - mixed.sample_rate = target_sample_rate - mixed.time_base = Fraction(1, target_sample_rate) - await writer.push(mixed) - - while True: - try: - mixed = sink.pull() - except Exception: - break - mixed.sample_rate = target_sample_rate - mixed.time_base = Fraction(1, target_sample_rate) - await writer.push(mixed) - finally: - # Cleanup all containers, even if processing failed - for c in containers: - if c is not None: - try: - c.close() - except Exception: - pass # Best effort cleanup - @broadcast_to_sockets async def set_status(self, transcript_id: str, status: TranscriptStatus): async with self.lock_transaction():
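
Note for reviewers: the shared modules this patch imports from are not part of
the diff (the diffstat lists only the three call sites), so the helpers
themselves live elsewhere in the tree. Their expected shape can be
reconstructed from the code deleted above; the sketches that follow are
inferred from those deletions and the new call-site signatures, not copied
from the actual modules. First, the task-output normalizer that replaces the
two deleted _to_dict copies:

# server/reflector/hatchet/utils.py (sketch, inferred from the deleted _to_dict)


def to_dict(output) -> dict:
    """Convert task output to dict, handling both dict and Pydantic model returns.

    Hatchet SDK returns Pydantic models when tasks have typed return annotations,
    but older code expects dicts. This helper normalizes the output.
    """
    if isinstance(output, dict):
        return output
    return output.model_dump()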
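The padding helpers deleted from track_processing.py and
main_multitrack_pipeline.py move to reflector/utils/audio_padding.py; both
call sites now pass an explicit logger= (previously each copy used the module
logger or self.logger). A sketch consistent with the deleted code:

# server/reflector/utils/audio_padding.py (sketch, mirrors the deleted helpers)
import math
from fractions import Fraction

import av
from av.audio.resampler import AudioResampler

from reflector.utils.audio_constants import (
    OPUS_DEFAULT_BIT_RATE,
    OPUS_STANDARD_SAMPLE_RATE,
)


def extract_stream_start_time_from_container(container, track_idx: int, logger=None) -> float:
    """Extract meeting-relative start time (seconds) from WebM stream metadata."""
    start_time_seconds = 0.0
    try:
        audio_streams = [s for s in container.streams if s.type == "audio"]
        stream = audio_streams[0] if audio_streams else container.streams[0]
        # 1) Stream-level start_time (most reliable for Daily.co tracks)
        if stream.start_time is not None and stream.time_base is not None:
            start_time_seconds = float(stream.start_time * stream.time_base)
        # 2) Fallback: container-level start_time (av.time_base units)
        if start_time_seconds <= 0 and container.start_time is not None:
            start_time_seconds = float(container.start_time * av.time_base)
        # 3) Fallback: first packet DTS in stream.time_base
        if start_time_seconds <= 0:
            for packet in container.demux(stream):
                if packet.dts is not None:
                    start_time_seconds = float(packet.dts * stream.time_base)
                    break
    except Exception as e:
        if logger:
            logger.warning(
                "PyAV metadata read failed; assuming 0 start_time",
                track_idx=track_idx,
                error=str(e),
            )
        start_time_seconds = 0.0
    return start_time_seconds


def apply_audio_padding_to_file(
    in_container,
    output_path: str,
    start_time_seconds: float,
    track_idx: int,
    logger=None,
) -> None:
    """Prepend silence via an aresample -> adelay graph, re-encoding to Opus/WebM."""
    delay_ms = math.floor(start_time_seconds * 1000)
    with av.open(output_path, "w", format="webm") as out_container:
        in_stream = next((s for s in in_container.streams if s.type == "audio"), None)
        if in_stream is None:
            raise Exception("No audio stream in input")
        out_stream = out_container.add_stream("libopus", rate=OPUS_STANDARD_SAMPLE_RATE)
        out_stream.bit_rate = OPUS_DEFAULT_BIT_RATE

        graph = av.filter.Graph()
        src = graph.add(
            "abuffer",
            args=(
                f"time_base=1/{OPUS_STANDARD_SAMPLE_RATE}:"
                f"sample_rate={OPUS_STANDARD_SAMPLE_RATE}:"
                f"sample_fmt=s16:channel_layout=stereo"
            ),
            name="src",
        )
        ares = graph.add("aresample", args="async=1", name="ares")
        # adelay wants one delay per channel, '|'-separated
        adelay = graph.add("adelay", args=f"delays={delay_ms}|{delay_ms}:all=1", name="delay")
        sink = graph.add("abuffersink", name="sink")
        src.link_to(ares)
        ares.link_to(adelay)
        adelay.link_to(sink)
        graph.configure()

        resampler = AudioResampler(format="s16", layout="stereo", rate=OPUS_STANDARD_SAMPLE_RATE)

        def drain():
            while True:
                try:
                    f_out = sink.pull()
                except Exception:  # sink raises when no frame is ready
                    return
                f_out.sample_rate = OPUS_STANDARD_SAMPLE_RATE
                f_out.time_base = Fraction(1, OPUS_STANDARD_SAMPLE_RATE)
                for packet in out_stream.encode(f_out):
                    out_container.mux(packet)

        for frame in in_container.decode(in_stream):
            for rframe in resampler.resample(frame) or []:
                rframe.sample_rate = OPUS_STANDARD_SAMPLE_RATE
                rframe.time_base = Fraction(1, OPUS_STANDARD_SAMPLE_RATE)
                src.push(rframe)
            drain()

        src.push(None)  # EOF flushes the filter chain
        drain()
        for packet in out_stream.encode(None):  # flush encoder
            out_container.mux(packet)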
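Finally, the mixdown helpers that replace the duplicated amix graph in both
pipelines. The sketch below condenses the deleted code and differs in one
detail: containers are opened before the graph is built, so each abuffer
input stays paired with a container that actually opened (in the deleted
code a failed open could misalign decoders and inputs). Treat the real
module as authoritative:

# server/reflector/utils/audio_mixdown.py (condensed sketch of the deleted code)
from fractions import Fraction

import av
from av.audio.resampler import AudioResampler


def detect_sample_rate_from_tracks(track_urls: list[str], logger=None) -> int | None:
    """Return the sample rate of the first decodable audio frame, else None."""
    for url in track_urls:
        if not url:
            continue
        container = None
        try:
            container = av.open(url)
            for frame in container.decode(audio=0):
                return frame.sample_rate
        except Exception as e:
            if logger:
                logger.debug("sample-rate probe failed", url=url, error=str(e))
            continue
        finally:
            if container is not None:
                container.close()
    return None

The mixer itself builds N abuffer inputs -> optional adelay -> amix ->
aformat -> abuffersink, pushing mixed frames into the
AudioFileWriterProcessor supplied by the caller:

async def mixdown_tracks_pyav(
    track_urls: list[str],
    writer,  # AudioFileWriterProcessor: mixed frames go to `await writer.push(...)`
    target_sample_rate: int,
    offsets_seconds: list[float] | None = None,
    logger=None,
) -> None:
    """Mix N tracks: abuffer -> (adelay) -> amix -> aformat -> abuffersink."""
    # Open containers first so graph inputs stay 1:1 with opened tracks
    containers, offsets = [], []
    for i, url in enumerate(track_urls):
        if not url:
            continue
        try:
            containers.append(av.open(url, options={
                "reconnect": "1",  # FFmpeg re-opens dropped S3 streams
                "reconnect_streamed": "1",
                "reconnect_delay_max": "5",
            }))
            offsets.append(offsets_seconds[i] if offsets_seconds else 0.0)
        except Exception as e:
            if logger:
                logger.warning("mixdown: failed to open container", url=url, error=str(e))
    if not containers:
        raise ValueError("Could not open any track containers")

    try:
        graph = av.filter.Graph()
        abuf_args = (
            f"time_base=1/{target_sample_rate}:sample_rate={target_sample_rate}:"
            f"sample_fmt=s32:channel_layout=stereo"
        )
        inputs = [graph.add("abuffer", args=abuf_args, name=f"in{n}")
                  for n in range(len(containers))]
        mixer = graph.add("amix", args=f"inputs={len(inputs)}:normalize=0", name="mix")
        fmt = graph.add(
            "aformat",
            args=f"sample_fmts=s32:channel_layouts=stereo:sample_rates={target_sample_rate}",
            name="fmt",
        )
        sink = graph.add("abuffersink", name="out")

        base = min(offsets)  # delays are relative to the earliest track
        for n, in_ctx in enumerate(inputs):
            delay_ms = max(0, int(round((offsets[n] - base) * 1000)))
            if delay_ms > 0:
                adelay = graph.add("adelay", args=f"delays={delay_ms}|{delay_ms}:all=1",
                                   name=f"delay{n}")
                in_ctx.link_to(adelay)
                adelay.link_to(mixer, 0, n)
            else:
                in_ctx.link_to(mixer, 0, n)
        mixer.link_to(fmt)
        fmt.link_to(sink)
        graph.configure()

        decoders = [c.decode(audio=0) for c in containers]
        active = [True] * len(decoders)
        resamplers = [AudioResampler(format="s32", layout="stereo", rate=target_sample_rate)
                      for _ in decoders]

        async def drain():
            while True:
                try:
                    mixed = sink.pull()
                except Exception:  # sink raises when no frame is ready
                    return
                mixed.sample_rate = target_sample_rate
                mixed.time_base = Fraction(1, target_sample_rate)
                await writer.push(mixed)

        while any(active):
            for i, dec in enumerate(decoders):
                if not active[i]:
                    continue
                try:
                    frame = next(dec)
                except StopIteration:
                    active[i] = False
                    inputs[i].push(None)  # EOF lets amix finish without this input
                    continue
                if frame.sample_rate != target_sample_rate:
                    continue
                for rf in resamplers[i].resample(frame) or []:
                    rf.sample_rate = target_sample_rate
                    rf.time_base = Fraction(1, target_sample_rate)
                    inputs[i].push(rf)
            await drain()
        await drain()  # final flush after all inputs hit EOF
    finally:
        for c in containers:
            try:
                c.close()
            except Exception:
                pass  # best-effort cleanup

At the Hatchet call site this now runs with offsets_seconds=None, since
tracks were already padded onto a common timeline by pad_track; the Celery
path still forwards its per-track offsets.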