fix: optimize parakeet transcription batching algorithm (#577)

* refactor: optimize transcription batching to accumulate speech segments

- Changed VAD segment generator to return full audio array instead of segments
- Removed segment filtering step
- Modified batch_segments to accumulate maximum speech including silence
- Transcribe larger continuous chunks instead of individual speech segments

* fix: correct transcribe_batch call to use list and fix batch unpacking

* fix: simplify

* fix: remove unused variables

* fix: add typing
This commit is contained in:
2025-08-27 10:32:04 -06:00
committed by GitHub
parent 37f0110892
commit 7030e0f236

View File

@@ -3,7 +3,7 @@ import os
import sys
import threading
import uuid
from typing import Mapping, NewType
from typing import Generator, Mapping, NewType
from urllib.parse import urlparse
import modal
@@ -14,10 +14,7 @@ SAMPLERATE = 16000
UPLOADS_PATH = "/uploads"
CACHE_PATH = "/cache"
VAD_CONFIG = {
"max_segment_duration": 30.0,
"batch_max_files": 10,
"batch_max_duration": 5.0,
"min_segment_duration": 0.02,
"batch_max_duration": 30.0,
"silence_padding": 0.5,
"window_size": 512,
}
@@ -271,7 +268,9 @@ class TranscriberParakeetFile:
audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
return audio_array
def vad_segment_generator(audio_array):
def vad_segment_generator(
audio_array,
) -> Generator[tuple[float, float], None, None]:
"""Generate speech segments using VAD with start/end sample indices"""
vad_iterator = VADIterator(self.vad_model, sampling_rate=SAMPLERATE)
window_size = VAD_CONFIG["window_size"]
@@ -297,76 +296,65 @@ class TranscriberParakeetFile:
start_time = start / float(SAMPLERATE)
end_time = end / float(SAMPLERATE)
# Extract the actual audio segment
audio_segment = audio_array[start:end]
yield (start_time, end_time, audio_segment)
yield (start_time, end_time)
start = None
vad_iterator.reset_states()
def vad_segment_filter(segments):
"""Filter VAD segments by duration and chunk large segments"""
min_dur = VAD_CONFIG["min_segment_duration"]
max_dur = VAD_CONFIG["max_segment_duration"]
def batch_speech_segments(
segments: Generator[tuple[float, float], None, None], max_duration: int
) -> Generator[tuple[float, float], None, None]:
"""
Input segments:
[0-2] [3-5] [6-8] [10-11] [12-15] [17-19] [20-22]
for start_time, end_time, audio_segment in segments:
segment_duration = end_time - start_time
↓ (max_duration=10)
# Skip very small segments
if segment_duration < min_dur:
Output batches:
[0-8] [10-19] [20-22]
Note: silences are kept for better transcription, previous implementation was
passing segments separatly, but the output was less accurate.
"""
batch_start_time = None
batch_end_time = None
for start_time, end_time in segments:
if batch_start_time is None or batch_end_time is None:
batch_start_time = start_time
batch_end_time = end_time
continue
# If segment is within max duration, yield as-is
if segment_duration <= max_dur:
yield (start_time, end_time, audio_segment)
total_duration = end_time - batch_start_time
if total_duration <= max_duration:
batch_end_time = end_time
continue
# Chunk large segments into smaller pieces
chunk_samples = int(max_dur * SAMPLERATE)
current_start = start_time
yield (batch_start_time, batch_end_time)
batch_start_time = start_time
batch_end_time = end_time
for chunk_offset in range(0, len(audio_segment), chunk_samples):
chunk_audio = audio_segment[
chunk_offset : chunk_offset + chunk_samples
]
if len(chunk_audio) == 0:
break
if batch_start_time is None or batch_end_time is None:
return
chunk_duration = len(chunk_audio) / float(SAMPLERATE)
chunk_end = current_start + chunk_duration
yield (batch_start_time, batch_end_time)
# Only yield chunks that meet minimum duration
if chunk_duration >= min_dur:
yield (current_start, chunk_end, chunk_audio)
def batch_segment_to_audio_segment(segments, audio_array):
for start_time, end_time in segments:
start_sample = int(start_time * SAMPLERATE)
end_sample = int(end_time * SAMPLERATE)
audio_segment = audio_array[start_sample:end_sample]
current_start = chunk_end
def batch_segments(segments, max_files=10, max_duration=5.0):
batch = []
batch_duration = 0.0
for start_time, end_time, audio_segment in segments:
segment_duration = end_time - start_time
if segment_duration < VAD_CONFIG["silence_padding"]:
if end_time - start_time < VAD_CONFIG["silence_padding"]:
silence_samples = int(
(VAD_CONFIG["silence_padding"] - segment_duration) * SAMPLERATE
(VAD_CONFIG["silence_padding"] - (end_time - start_time))
* SAMPLERATE
)
padding = np.zeros(silence_samples, dtype=np.float32)
audio_segment = np.concatenate([audio_segment, padding])
segment_duration = VAD_CONFIG["silence_padding"]
batch.append((start_time, end_time, audio_segment))
batch_duration += segment_duration
if len(batch) >= max_files or batch_duration >= max_duration:
yield batch
batch = []
batch_duration = 0.0
if batch:
yield batch
yield start_time, end_time, audio_segment
def transcribe_batch(model, audio_segments):
with NoStdStreams():
@@ -376,8 +364,6 @@ class TranscriberParakeetFile:
def emit_results(
results,
segments_info,
batch_index,
total_batches,
):
"""Yield transcribed text and word timings from model output, adjusting timestamps to absolute positions."""
for i, (output, (start_time, end_time, _)) in enumerate(
@@ -413,35 +399,26 @@ class TranscriberParakeetFile:
all_words = []
raw_segments = vad_segment_generator(audio_array)
filtered_segments = vad_segment_filter(raw_segments)
batches = batch_segments(
filtered_segments,
VAD_CONFIG["batch_max_files"],
speech_segments = batch_speech_segments(
raw_segments,
VAD_CONFIG["batch_max_duration"],
)
audio_segments = batch_segment_to_audio_segment(speech_segments, audio_array)
batch_index = 0
total_batches = max(
1, int(total_duration / VAD_CONFIG["batch_max_duration"]) + 1
)
for batch in batches:
batch_index += 1
audio_segments = [seg[2] for seg in batch]
results = transcribe_batch(self.model, audio_segments)
for batch in audio_segments:
_, _, audio_segment = batch
results = transcribe_batch(self.model, [audio_segment])
for text, words in emit_results(
results,
batch,
batch_index,
total_batches,
[batch],
):
if not text:
continue
all_text_parts.append(text)
all_words.extend(words)
processed_duration += sum(len(seg[2]) / float(SAMPLERATE) for seg in batch)
processed_duration += len(audio_segment) / float(SAMPLERATE)
combined_text = " ".join(all_text_parts)
return {"text": combined_text, "words": all_words}