diff --git a/server/reflector/pipelines/main_multitrack_pipeline.py b/server/reflector/pipelines/main_multitrack_pipeline.py index beb681d2..7cbc1b62 100644 --- a/server/reflector/pipelines/main_multitrack_pipeline.py +++ b/server/reflector/pipelines/main_multitrack_pipeline.py @@ -1,7 +1,11 @@ import asyncio +import audioop +import io +import av import boto3 import structlog +from av.audio.resampler import AudioResampler from celery import chain, shared_task from reflector.asynctask import asynctask @@ -18,6 +22,7 @@ from reflector.pipelines.main_live_pipeline import ( task_pipeline_post_to_zulip, ) from reflector.processors import ( + AudioFileWriterProcessor, TranscriptFinalSummaryProcessor, TranscriptFinalTitleProcessor, TranscriptTopicDetectorProcessor, @@ -102,6 +107,96 @@ class PipelineMainMultitrack(PipelineMainBase): storage = get_transcripts_storage() + # Pre-download bytes for all tracks for mixing and transcription + track_datas: list[bytes] = [] + for key in keys: + try: + obj = s3.get_object(Bucket=bucket_name, Key=key) + track_datas.append(obj["Body"].read()) + except Exception as e: + self.logger.warning( + "Skipping track - cannot read S3 object", key=key, error=str(e) + ) + track_datas.append(b"") + + # Mixdown all available tracks into transcript.audio_mp3_filename at 16kHz mono + try: + mp3_writer = AudioFileWriterProcessor( + path=str(transcript.audio_mp3_filename) + ) + + # Generators for PCM s16 mono 16kHz per track + def pcm_generator(data: bytes): + if not data: + return + container = av.open(io.BytesIO(data)) + resampler = AudioResampler(format="s16", layout="mono", rate=16000) + try: + for frame in container.decode(audio=0): + rframes = resampler.resample(frame) or [] + for rf in rframes: + # Convert audio plane to raw bytes (PyAV plane supports bytes()) + yield bytes(rf.planes[0]) + finally: + container.close() + + gens = [pcm_generator(d) for d in track_datas if d] + buffers = [bytearray() for _ in gens] + active = [True for _ in gens] + + CHUNK_SAMPLES = 16000 # 1 second + CHUNK_BYTES = CHUNK_SAMPLES * 2 # s16 mono + + while any(active) or any(len(b) > 0 for b in buffers): + # Fill buffers up to CHUNK_BYTES + for i, (gen, buf, is_active) in enumerate(zip(gens, buffers, active)): + if not is_active: + continue + while len(buf) < CHUNK_BYTES: + try: + next_bytes = next(gen) + buf.extend(next_bytes) + except StopIteration: + active[i] = False + break + + available_lengths = [len(b) for b in buffers if len(b) > 0] + if not available_lengths and not any(active): + break + if not available_lengths: + continue + chunk_len = min(min(available_lengths), CHUNK_BYTES) + chunk_len -= chunk_len % 2 + if chunk_len == 0: + continue + + # Mix: scale each track by 1/N then sum + num_sources = max(1, sum(1 for b in buffers if len(b) >= chunk_len)) + mixed = bytes(chunk_len) + for buf in buffers: + if len(buf) >= chunk_len: + part = bytes(buf[:chunk_len]) + del buf[:chunk_len] + else: + if len(buf) == 0: + continue + part = bytes(buf) + del buf[:] + part = part + bytes(chunk_len - len(part)) + scaled = audioop.mul(part, 2, 1.0 / num_sources) + mixed = audioop.add(mixed, scaled, 2) + + # Encode mixed frame to MP3 + num_samples = chunk_len // 2 + frame = av.AudioFrame(format="s16", layout="mono", samples=num_samples) + frame.sample_rate = 16000 + frame.planes[0].update(mixed) + await mp3_writer.push(frame) + + await mp3_writer.flush() + except Exception as e: + self.logger.warning("Mixdown failed", error=str(e)) + speaker_transcripts: list[TranscriptType] = [] for idx, key in enumerate(keys): ext = ".mp4"