From 7030e0f23649a8cf6c1eb6d5889684a41ce849ec Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Wed, 27 Aug 2025 10:32:04 -0600
Subject: [PATCH] fix: optimize parakeet transcription batching algorithm
 (#577)

* refactor: optimize transcription batching to accumulate speech segments

- Changed the VAD segment generator to work on the full audio array instead
  of extracting per-segment audio
- Removed the segment filtering step
- Modified batch_segments to accumulate as much speech as possible,
  including silence
- Transcribe larger continuous chunks instead of individual speech segments

* fix: correct transcribe_batch call to use a list and fix batch unpacking

* fix: simplify

* fix: remove unused variables

* fix: add typing
---
 .../reflector_transcriber_parakeet.py         | 129 +++++++-----------
 1 file changed, 53 insertions(+), 76 deletions(-)

diff --git a/server/gpu/modal_deployments/reflector_transcriber_parakeet.py b/server/gpu/modal_deployments/reflector_transcriber_parakeet.py
index 97e150e3..3b6f6ad0 100644
--- a/server/gpu/modal_deployments/reflector_transcriber_parakeet.py
+++ b/server/gpu/modal_deployments/reflector_transcriber_parakeet.py
@@ -3,7 +3,7 @@ import os
 import sys
 import threading
 import uuid
-from typing import Mapping, NewType
+from typing import Generator, Mapping, NewType
 from urllib.parse import urlparse
 
 import modal
@@ -14,10 +14,7 @@ SAMPLERATE = 16000
 UPLOADS_PATH = "/uploads"
 CACHE_PATH = "/cache"
 VAD_CONFIG = {
-    "max_segment_duration": 30.0,
-    "batch_max_files": 10,
-    "batch_max_duration": 5.0,
-    "min_segment_duration": 0.02,
+    "batch_max_duration": 30.0,
     "silence_padding": 0.5,
     "window_size": 512,
 }
@@ -271,7 +268,9 @@ class TranscriberParakeetFile:
             audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
             return audio_array
 
-        def vad_segment_generator(audio_array):
+        def vad_segment_generator(
+            audio_array,
+        ) -> Generator[tuple[float, float], None, None]:
             """Generate speech segments using VAD with start/end sample indices"""
             vad_iterator = VADIterator(self.vad_model, sampling_rate=SAMPLERATE)
             window_size = VAD_CONFIG["window_size"]
@@ -297,76 +296,65 @@ class TranscriberParakeetFile:
                 start_time = start / float(SAMPLERATE)
                 end_time = end / float(SAMPLERATE)
 
-                # Extract the actual audio segment
-                audio_segment = audio_array[start:end]
-
-                yield (start_time, end_time, audio_segment)
+                yield (start_time, end_time)
 
                 start = None
 
             vad_iterator.reset_states()
 
-        def vad_segment_filter(segments):
-            """Filter VAD segments by duration and chunk large segments"""
-            min_dur = VAD_CONFIG["min_segment_duration"]
-            max_dur = VAD_CONFIG["max_segment_duration"]
+        def batch_speech_segments(
+            segments: Generator[tuple[float, float], None, None], max_duration: float
+        ) -> Generator[tuple[float, float], None, None]:
+            """
+            Input segments:
+            [0-2] [3-5] [6-8] [10-11] [12-15] [17-19] [20-22]
 
-            for start_time, end_time, audio_segment in segments:
-                segment_duration = end_time - start_time
+            ↓ (max_duration=10)
 
-                # Skip very small segments
-                if segment_duration < min_dur:
+            Output batches:
+            [0-8] [10-19] [20-22]
+
+            Note: silences are kept for better transcription; the previous
+            implementation passed segments separately, but the output was less accurate.
+ """ + batch_start_time = None + batch_end_time = None + + for start_time, end_time in segments: + if batch_start_time is None or batch_end_time is None: + batch_start_time = start_time + batch_end_time = end_time continue - # If segment is within max duration, yield as-is - if segment_duration <= max_dur: - yield (start_time, end_time, audio_segment) + total_duration = end_time - batch_start_time + + if total_duration <= max_duration: + batch_end_time = end_time continue - # Chunk large segments into smaller pieces - chunk_samples = int(max_dur * SAMPLERATE) - current_start = start_time + yield (batch_start_time, batch_end_time) + batch_start_time = start_time + batch_end_time = end_time - for chunk_offset in range(0, len(audio_segment), chunk_samples): - chunk_audio = audio_segment[ - chunk_offset : chunk_offset + chunk_samples - ] - if len(chunk_audio) == 0: - break + if batch_start_time is None or batch_end_time is None: + return - chunk_duration = len(chunk_audio) / float(SAMPLERATE) - chunk_end = current_start + chunk_duration + yield (batch_start_time, batch_end_time) - # Only yield chunks that meet minimum duration - if chunk_duration >= min_dur: - yield (current_start, chunk_end, chunk_audio) + def batch_segment_to_audio_segment(segments, audio_array): + for start_time, end_time in segments: + start_sample = int(start_time * SAMPLERATE) + end_sample = int(end_time * SAMPLERATE) + audio_segment = audio_array[start_sample:end_sample] - current_start = chunk_end - - def batch_segments(segments, max_files=10, max_duration=5.0): - batch = [] - batch_duration = 0.0 - - for start_time, end_time, audio_segment in segments: - segment_duration = end_time - start_time - - if segment_duration < VAD_CONFIG["silence_padding"]: + if end_time - start_time < VAD_CONFIG["silence_padding"]: silence_samples = int( - (VAD_CONFIG["silence_padding"] - segment_duration) * SAMPLERATE + (VAD_CONFIG["silence_padding"] - (end_time - start_time)) + * SAMPLERATE ) padding = np.zeros(silence_samples, dtype=np.float32) audio_segment = np.concatenate([audio_segment, padding]) - segment_duration = VAD_CONFIG["silence_padding"] - batch.append((start_time, end_time, audio_segment)) - batch_duration += segment_duration - - if len(batch) >= max_files or batch_duration >= max_duration: - yield batch - batch = [] - batch_duration = 0.0 - - if batch: - yield batch + yield start_time, end_time, audio_segment def transcribe_batch(model, audio_segments): with NoStdStreams(): @@ -376,8 +364,6 @@ class TranscriberParakeetFile: def emit_results( results, segments_info, - batch_index, - total_batches, ): """Yield transcribed text and word timings from model output, adjusting timestamps to absolute positions.""" for i, (output, (start_time, end_time, _)) in enumerate( @@ -413,35 +399,26 @@ class TranscriberParakeetFile: all_words = [] raw_segments = vad_segment_generator(audio_array) - filtered_segments = vad_segment_filter(raw_segments) - batches = batch_segments( - filtered_segments, - VAD_CONFIG["batch_max_files"], + speech_segments = batch_speech_segments( + raw_segments, VAD_CONFIG["batch_max_duration"], ) + audio_segments = batch_segment_to_audio_segment(speech_segments, audio_array) - batch_index = 0 - total_batches = max( - 1, int(total_duration / VAD_CONFIG["batch_max_duration"]) + 1 - ) - - for batch in batches: - batch_index += 1 - audio_segments = [seg[2] for seg in batch] - results = transcribe_batch(self.model, audio_segments) + for batch in audio_segments: + _, _, audio_segment = batch + results = 
             for text, words in emit_results(
                 results,
-                batch,
-                batch_index,
-                total_batches,
+                [batch],
             ):
                 if not text:
                     continue
                 all_text_parts.append(text)
                 all_words.extend(words)
-            processed_duration += sum(len(seg[2]) / float(SAMPLERATE) for seg in batch)
+            processed_duration += len(audio_segment) / float(SAMPLERATE)
 
         combined_text = " ".join(all_text_parts)
         return {"text": combined_text, "words": all_words}
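
Reviewer note, not part of the patch: the accumulation behavior documented in
the batch_speech_segments docstring can be sanity-checked standalone. The
sketch below mirrors the logic added in the diff (shortened variable names,
same algorithm); the segment values come from the docstring example.

    from typing import Generator, Iterable

    def batch_speech_segments(
        segments: Iterable[tuple[float, float]], max_duration: float
    ) -> Generator[tuple[float, float], None, None]:
        # Grow the current batch (silence between segments included) until
        # the span from the batch start to the next segment end would exceed
        # max_duration, then emit the batch and restart at the current segment.
        batch_start, batch_end = None, None
        for start, end in segments:
            if batch_start is None:
                batch_start, batch_end = start, end
            elif end - batch_start <= max_duration:
                batch_end = end
            else:
                yield (batch_start, batch_end)
                batch_start, batch_end = start, end
        if batch_start is not None:
            yield (batch_start, batch_end)

    segments = [(0, 2), (3, 5), (6, 8), (10, 11), (12, 15), (17, 19), (20, 22)]
    print(list(batch_speech_segments(segments, max_duration=10)))
    # -> [(0, 8), (10, 19), (20, 22)]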
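Similarly, batch_segment_to_audio_segment pads any batch shorter than
VAD_CONFIG["silence_padding"] (0.5 s) with trailing zeros so the model never
receives a near-empty input. A minimal sketch of just that step, under the
same SAMPLERATE; pad_short_segment is an illustrative name, not from the patch:

    import numpy as np

    SAMPLERATE = 16000
    SILENCE_PADDING = 0.5  # seconds, mirrors VAD_CONFIG["silence_padding"]

    def pad_short_segment(audio_segment: np.ndarray) -> np.ndarray:
        # Extend segments shorter than SILENCE_PADDING with trailing silence.
        duration = len(audio_segment) / float(SAMPLERATE)
        if duration < SILENCE_PADDING:
            silence_samples = int((SILENCE_PADDING - duration) * SAMPLERATE)
            padding = np.zeros(silence_samples, dtype=np.float32)
            audio_segment = np.concatenate([audio_segment, padding])
        return audio_segment

    segment = np.zeros(int(0.3 * SAMPLERATE), dtype=np.float32)  # 0.3 s input
    print(len(pad_short_segment(segment)) / SAMPLERATE)  # -> 0.5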