mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-21 12:49:06 +00:00
Mixdown with pyav filter graph
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import audioop
|
|
||||||
import io
|
import io
|
||||||
|
from fractions import Fraction
|
||||||
|
|
||||||
import av
|
import av
|
||||||
import boto3
|
import boto3
|
||||||
@@ -56,6 +56,138 @@ class PipelineMainMultitrack(PipelineMainBase):
|
|||||||
self.logger = logger.bind(transcript_id=self.transcript_id)
|
self.logger = logger.bind(transcript_id=self.transcript_id)
|
||||||
self.empty_pipeline = EmptyPipeline(logger=self.logger)
|
self.empty_pipeline = EmptyPipeline(logger=self.logger)
|
||||||
|
|
||||||
|
async def mixdown_tracks(
|
||||||
|
self, track_datas: list[bytes], writer: AudioFileWriterProcessor
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Minimal multi-track mixdown using a PyAV filter graph (amix), no resampling.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Discover target sample rate from first decodable frame
|
||||||
|
target_sample_rate: int | None = None
|
||||||
|
for data in track_datas:
|
||||||
|
if not data:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
container = av.open(io.BytesIO(data))
|
||||||
|
try:
|
||||||
|
for frame in container.decode(audio=0):
|
||||||
|
target_sample_rate = frame.sample_rate
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
container.close()
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
if target_sample_rate:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not target_sample_rate:
|
||||||
|
self.logger.warning("Mixdown skipped - no decodable audio frames found")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Build PyAV filter graph: N abuffer (s32/stereo) -> amix (s32) -> aformat(s16) -> sink
|
||||||
|
graph = av.filter.Graph()
|
||||||
|
inputs = []
|
||||||
|
for idx, data in enumerate([d for d in track_datas if d]):
|
||||||
|
args = (
|
||||||
|
f"time_base=1/{target_sample_rate}:"
|
||||||
|
f"sample_rate={target_sample_rate}:"
|
||||||
|
f"sample_fmt=s32:"
|
||||||
|
f"channel_layout=stereo"
|
||||||
|
)
|
||||||
|
in_ctx = graph.add("abuffer", args=args, name=f"in{idx}")
|
||||||
|
inputs.append(in_ctx)
|
||||||
|
|
||||||
|
if not inputs:
|
||||||
|
self.logger.warning("Mixdown skipped - no valid inputs for graph")
|
||||||
|
return
|
||||||
|
|
||||||
|
mixer = graph.add("amix", args=f"inputs={len(inputs)}:normalize=0", name="mix")
|
||||||
|
|
||||||
|
fmt = graph.add(
|
||||||
|
"aformat",
|
||||||
|
args=(
|
||||||
|
f"sample_fmts=s32:channel_layouts=stereo:sample_rates={target_sample_rate}"
|
||||||
|
),
|
||||||
|
name="fmt",
|
||||||
|
)
|
||||||
|
|
||||||
|
sink = graph.add("abuffersink", name="out")
|
||||||
|
|
||||||
|
for idx, in_ctx in enumerate(inputs):
|
||||||
|
in_ctx.link_to(mixer, 0, idx)
|
||||||
|
mixer.link_to(fmt)
|
||||||
|
fmt.link_to(sink)
|
||||||
|
graph.configure()
|
||||||
|
|
||||||
|
# Open containers for decoding
|
||||||
|
containers = []
|
||||||
|
for i, d in enumerate([d for d in track_datas if d]):
|
||||||
|
try:
|
||||||
|
c = av.open(io.BytesIO(d))
|
||||||
|
containers.append(c)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(
|
||||||
|
"Mixdown: failed to open container", input=i, error=str(e)
|
||||||
|
)
|
||||||
|
containers.append(None)
|
||||||
|
# Filter out Nones for decoders
|
||||||
|
containers = [c for c in containers if c is not None]
|
||||||
|
decoders = [c.decode(audio=0) for c in containers]
|
||||||
|
active = [True] * len(decoders)
|
||||||
|
# Per-input resamplers to enforce s32/stereo at the same rate (no resample of rate)
|
||||||
|
resamplers = [
|
||||||
|
AudioResampler(format="s32", layout="stereo", rate=target_sample_rate)
|
||||||
|
for _ in decoders
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Round-robin feed frames into graph, pull mixed frames as they become available
|
||||||
|
while any(active):
|
||||||
|
for i, (dec, is_active) in enumerate(zip(decoders, active)):
|
||||||
|
if not is_active:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
frame = next(dec)
|
||||||
|
except StopIteration:
|
||||||
|
active[i] = False
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Enforce same sample rate; convert format/layout to s16/stereo (no resample)
|
||||||
|
if frame.sample_rate != target_sample_rate:
|
||||||
|
# Skip frames with differing rate
|
||||||
|
continue
|
||||||
|
out_frames = resamplers[i].resample(frame) or []
|
||||||
|
for rf in out_frames:
|
||||||
|
rf.sample_rate = target_sample_rate
|
||||||
|
rf.time_base = Fraction(1, target_sample_rate)
|
||||||
|
inputs[i].push(rf)
|
||||||
|
|
||||||
|
# Drain available mixed frames
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
mixed = sink.pull()
|
||||||
|
except Exception:
|
||||||
|
break
|
||||||
|
mixed.sample_rate = target_sample_rate
|
||||||
|
mixed.time_base = Fraction(1, target_sample_rate)
|
||||||
|
await writer.push(mixed)
|
||||||
|
|
||||||
|
# Signal EOF to inputs and drain remaining
|
||||||
|
for in_ctx in inputs:
|
||||||
|
in_ctx.push(None)
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
mixed = sink.pull()
|
||||||
|
except Exception:
|
||||||
|
break
|
||||||
|
mixed.sample_rate = target_sample_rate
|
||||||
|
mixed.time_base = Fraction(1, target_sample_rate)
|
||||||
|
await writer.push(mixed)
|
||||||
|
finally:
|
||||||
|
for c in containers:
|
||||||
|
c.close()
|
||||||
|
|
||||||
async def set_status(self, transcript_id: str, status: TranscriptStatus):
|
async def set_status(self, transcript_id: str, status: TranscriptStatus):
|
||||||
async with self.lock_transaction():
|
async with self.lock_transaction():
|
||||||
return await transcripts_controller.set_status(transcript_id, status)
|
return await transcripts_controller.set_status(transcript_id, status)
|
||||||
@@ -119,83 +251,15 @@ class PipelineMainMultitrack(PipelineMainBase):
|
|||||||
)
|
)
|
||||||
track_datas.append(b"")
|
track_datas.append(b"")
|
||||||
|
|
||||||
# Mixdown all available tracks into transcript.audio_mp3_filename at 16kHz mono
|
# Mixdown all available tracks into transcript.audio_mp3_filename, preserving sample rate
|
||||||
try:
|
try:
|
||||||
mp3_writer = AudioFileWriterProcessor(
|
mp3_writer = AudioFileWriterProcessor(
|
||||||
path=str(transcript.audio_mp3_filename)
|
path=str(transcript.audio_mp3_filename)
|
||||||
)
|
)
|
||||||
|
await self.mixdown_tracks(track_datas, mp3_writer)
|
||||||
# Generators for PCM s16 mono 16kHz per track
|
|
||||||
def pcm_generator(data: bytes):
|
|
||||||
if not data:
|
|
||||||
return
|
|
||||||
container = av.open(io.BytesIO(data))
|
|
||||||
resampler = AudioResampler(format="s16", layout="mono", rate=16000)
|
|
||||||
try:
|
|
||||||
for frame in container.decode(audio=0):
|
|
||||||
rframes = resampler.resample(frame) or []
|
|
||||||
for rf in rframes:
|
|
||||||
# Convert audio plane to raw bytes (PyAV plane supports bytes())
|
|
||||||
yield bytes(rf.planes[0])
|
|
||||||
finally:
|
|
||||||
container.close()
|
|
||||||
|
|
||||||
gens = [pcm_generator(d) for d in track_datas if d]
|
|
||||||
buffers = [bytearray() for _ in gens]
|
|
||||||
active = [True for _ in gens]
|
|
||||||
|
|
||||||
CHUNK_SAMPLES = 16000 # 1 second
|
|
||||||
CHUNK_BYTES = CHUNK_SAMPLES * 2 # s16 mono
|
|
||||||
|
|
||||||
while any(active) or any(len(b) > 0 for b in buffers):
|
|
||||||
# Fill buffers up to CHUNK_BYTES
|
|
||||||
for i, (gen, buf, is_active) in enumerate(zip(gens, buffers, active)):
|
|
||||||
if not is_active:
|
|
||||||
continue
|
|
||||||
while len(buf) < CHUNK_BYTES:
|
|
||||||
try:
|
|
||||||
next_bytes = next(gen)
|
|
||||||
buf.extend(next_bytes)
|
|
||||||
except StopIteration:
|
|
||||||
active[i] = False
|
|
||||||
break
|
|
||||||
|
|
||||||
available_lengths = [len(b) for b in buffers if len(b) > 0]
|
|
||||||
if not available_lengths and not any(active):
|
|
||||||
break
|
|
||||||
if not available_lengths:
|
|
||||||
continue
|
|
||||||
chunk_len = min(min(available_lengths), CHUNK_BYTES)
|
|
||||||
chunk_len -= chunk_len % 2
|
|
||||||
if chunk_len == 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Mix: scale each track by 1/N then sum
|
|
||||||
num_sources = max(1, sum(1 for b in buffers if len(b) >= chunk_len))
|
|
||||||
mixed = bytes(chunk_len)
|
|
||||||
for buf in buffers:
|
|
||||||
if len(buf) >= chunk_len:
|
|
||||||
part = bytes(buf[:chunk_len])
|
|
||||||
del buf[:chunk_len]
|
|
||||||
else:
|
|
||||||
if len(buf) == 0:
|
|
||||||
continue
|
|
||||||
part = bytes(buf)
|
|
||||||
del buf[:]
|
|
||||||
part = part + bytes(chunk_len - len(part))
|
|
||||||
scaled = audioop.mul(part, 2, 1.0 / num_sources)
|
|
||||||
mixed = audioop.add(mixed, scaled, 2)
|
|
||||||
|
|
||||||
# Encode mixed frame to MP3
|
|
||||||
num_samples = chunk_len // 2
|
|
||||||
frame = av.AudioFrame(format="s16", layout="mono", samples=num_samples)
|
|
||||||
frame.sample_rate = 16000
|
|
||||||
frame.planes[0].update(mixed)
|
|
||||||
await mp3_writer.push(frame)
|
|
||||||
|
|
||||||
await mp3_writer.flush()
|
await mp3_writer.flush()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.warning("Mixdown failed", error=str(e))
|
self.logger.error("Mixdown failed", error=str(e))
|
||||||
|
|
||||||
speaker_transcripts: list[TranscriptType] = []
|
speaker_transcripts: list[TranscriptType] = []
|
||||||
for idx, key in enumerate(keys):
|
for idx, key in enumerate(keys):
|
||||||
@@ -205,7 +269,7 @@ class PipelineMainMultitrack(PipelineMainBase):
|
|||||||
obj = s3.get_object(Bucket=bucket_name, Key=key)
|
obj = s3.get_object(Bucket=bucket_name, Key=key)
|
||||||
data = obj["Body"].read()
|
data = obj["Body"].read()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.warning(
|
self.logger.error(
|
||||||
"Skipping track - cannot read S3 object", key=key, error=str(e)
|
"Skipping track - cannot read S3 object", key=key, error=str(e)
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
@@ -215,7 +279,7 @@ class PipelineMainMultitrack(PipelineMainBase):
|
|||||||
await storage.put_file(storage_path, data)
|
await storage.put_file(storage_path, data)
|
||||||
audio_url = await storage.get_file_url(storage_path)
|
audio_url = await storage.get_file_url(storage_path)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.warning(
|
self.logger.error(
|
||||||
"Skipping track - cannot upload to storage", key=key, error=str(e)
|
"Skipping track - cannot upload to storage", key=key, error=str(e)
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
@@ -223,7 +287,7 @@ class PipelineMainMultitrack(PipelineMainBase):
|
|||||||
try:
|
try:
|
||||||
t = await self.transcribe_file(audio_url, transcript.source_language)
|
t = await self.transcribe_file(audio_url, transcript.source_language)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.warning(
|
self.logger.error(
|
||||||
"Transcription via default backend failed, trying local whisper",
|
"Transcription via default backend failed, trying local whisper",
|
||||||
key=key,
|
key=key,
|
||||||
url=audio_url,
|
url=audio_url,
|
||||||
@@ -248,7 +312,7 @@ class PipelineMainMultitrack(PipelineMainBase):
|
|||||||
raise Exception("No transcript captured in fallback")
|
raise Exception("No transcript captured in fallback")
|
||||||
t = result
|
t = result
|
||||||
except Exception as e2:
|
except Exception as e2:
|
||||||
self.logger.warning(
|
self.logger.error(
|
||||||
"Skipping track - transcription failed after fallback",
|
"Skipping track - transcription failed after fallback",
|
||||||
key=key,
|
key=key,
|
||||||
url=audio_url,
|
url=audio_url,
|
||||||
|
|||||||
Reference in New Issue
Block a user