fix: optimize parakeet transcription batching algorithm (#577)

* refactor: optimize transcription batching to accumulate speech segments

- Changed VAD segment generator to return full audio array instead of segments
- Removed segment filtering step
- Modified batch_segments to accumulate maximum speech including silence
- Transcribe larger continuous chunks instead of individual speech segments (see the sketch below)
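
In essence, the new batching greedily extends a time window over consecutive speech spans until adding the next span would push the window past `batch_max_duration`, then flushes the window, silence included. A minimal standalone sketch of that accumulation logic (the real function, `batch_speech_segments` in the diff below, is a closure inside `TranscriberParakeetFile`; this copy exists only to make the behavior runnable in isolation):

```python
from typing import Generator, Iterable

def batch_speech_segments(
    segments: Iterable[tuple[float, float]], max_duration: float
) -> Generator[tuple[float, float], None, None]:
    """Merge consecutive (start_s, end_s) speech spans into windows of at
    most max_duration seconds, keeping the silence between them."""
    batch_start, batch_end = None, None
    for start, end in segments:
        if batch_start is None:
            batch_start, batch_end = start, end
        elif end - batch_start <= max_duration:
            batch_end = end  # next span still fits: extend the window
        else:
            yield (batch_start, batch_end)  # window full: flush it
            batch_start, batch_end = start, end
    if batch_start is not None:
        yield (batch_start, batch_end)

# The example from the docstring in the diff:
spans = [(0, 2), (3, 5), (6, 8), (10, 11), (12, 15), (17, 19), (20, 22)]
print(list(batch_speech_segments(spans, max_duration=10)))
# -> [(0, 8), (10, 19), (20, 22)]
```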

* fix: correct transcribe_batch call to use list and fix batch unpacking

* fix: simplify

* fix: remove unused variables

* fix: add typing
Commit 7030e0f236 (parent 37f0110892), committed by GitHub on 2025-08-27 10:32:04 -06:00


@@ -3,7 +3,7 @@ import os
 import sys
 import threading
 import uuid
-from typing import Mapping, NewType
+from typing import Generator, Mapping, NewType
 from urllib.parse import urlparse
 
 import modal
@@ -14,10 +14,7 @@ SAMPLERATE = 16000
 UPLOADS_PATH = "/uploads"
 CACHE_PATH = "/cache"
 VAD_CONFIG = {
-    "max_segment_duration": 30.0,
-    "batch_max_files": 10,
-    "batch_max_duration": 5.0,
-    "min_segment_duration": 0.02,
+    "batch_max_duration": 30.0,
     "silence_padding": 0.5,
     "window_size": 512,
 }
@@ -271,7 +268,9 @@ class TranscriberParakeetFile:
         audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
         return audio_array
 
-        def vad_segment_generator(audio_array):
+        def vad_segment_generator(
+            audio_array,
+        ) -> Generator[tuple[float, float], None, None]:
             """Generate speech segments using VAD with start/end sample indices"""
             vad_iterator = VADIterator(self.vad_model, sampling_rate=SAMPLERATE)
             window_size = VAD_CONFIG["window_size"]
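
The generator's loop body falls outside this hunk. For orientation, here is a plausible standalone sketch of how such a streaming loop looks with silero-vad's `VADIterator`, which emits `{"start": sample}` / `{"end": sample}` events per fixed-size window; the actual body in the file may differ, and this copy loads its own model rather than using `self.vad_model`:

```python
import numpy as np
from silero_vad import VADIterator, load_silero_vad  # assumes the silero-vad pip package

SAMPLERATE = 16000
WINDOW_SIZE = 512  # VAD_CONFIG["window_size"]

def vad_segment_generator(audio_array: np.ndarray):
    """Yield (start_time, end_time) speech spans in seconds."""
    vad_iterator = VADIterator(load_silero_vad(), sampling_rate=SAMPLERATE)
    start = None
    for i in range(0, len(audio_array), WINDOW_SIZE):
        chunk = audio_array[i : i + WINDOW_SIZE]
        if len(chunk) < WINDOW_SIZE:
            break  # VADIterator expects fixed-size windows
        event = vad_iterator(chunk)  # None, {"start": sample} or {"end": sample}
        if event and "start" in event:
            start = event["start"]
        elif event and "end" in event and start is not None:
            yield (start / float(SAMPLERATE), event["end"] / float(SAMPLERATE))
            start = None
    vad_iterator.reset_states()
```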
@@ -297,76 +296,65 @@ class TranscriberParakeetFile:
                     start_time = start / float(SAMPLERATE)
                     end_time = end / float(SAMPLERATE)
-                    # Extract the actual audio segment
-                    audio_segment = audio_array[start:end]
-                    yield (start_time, end_time, audio_segment)
+                    yield (start_time, end_time)
                     start = None
             vad_iterator.reset_states()
 
-        def vad_segment_filter(segments):
-            """Filter VAD segments by duration and chunk large segments"""
-            min_dur = VAD_CONFIG["min_segment_duration"]
-            max_dur = VAD_CONFIG["max_segment_duration"]
-
-            for start_time, end_time, audio_segment in segments:
-                segment_duration = end_time - start_time
-
-                # Skip very small segments
-                if segment_duration < min_dur:
-                    continue
-
-                # If segment is within max duration, yield as-is
-                if segment_duration <= max_dur:
-                    yield (start_time, end_time, audio_segment)
-                    continue
-
-                # Chunk large segments into smaller pieces
-                chunk_samples = int(max_dur * SAMPLERATE)
-                current_start = start_time
-
-                for chunk_offset in range(0, len(audio_segment), chunk_samples):
-                    chunk_audio = audio_segment[
-                        chunk_offset : chunk_offset + chunk_samples
-                    ]
-                    if len(chunk_audio) == 0:
-                        break
-
-                    chunk_duration = len(chunk_audio) / float(SAMPLERATE)
-                    chunk_end = current_start + chunk_duration
-
-                    # Only yield chunks that meet minimum duration
-                    if chunk_duration >= min_dur:
-                        yield (current_start, chunk_end, chunk_audio)
-
-                    current_start = chunk_end
-
-        def batch_segments(segments, max_files=10, max_duration=5.0):
-            batch = []
-            batch_duration = 0.0
-            for start_time, end_time, audio_segment in segments:
-                segment_duration = end_time - start_time
-                if segment_duration < VAD_CONFIG["silence_padding"]:
-                    silence_samples = int(
-                        (VAD_CONFIG["silence_padding"] - segment_duration) * SAMPLERATE
-                    )
-                    padding = np.zeros(silence_samples, dtype=np.float32)
-                    audio_segment = np.concatenate([audio_segment, padding])
-                    segment_duration = VAD_CONFIG["silence_padding"]
-                batch.append((start_time, end_time, audio_segment))
-                batch_duration += segment_duration
-                if len(batch) >= max_files or batch_duration >= max_duration:
-                    yield batch
-                    batch = []
-                    batch_duration = 0.0
-            if batch:
-                yield batch
+        def batch_speech_segments(
+            segments: Generator[tuple[float, float], None, None], max_duration: int
+        ) -> Generator[tuple[float, float], None, None]:
+            """
+            Input segments:
+            [0-2] [3-5] [6-8] [10-11] [12-15] [17-19] [20-22]
+                ↓ (max_duration=10)
+            Output batches:
+            [0-8] [10-19] [20-22]
+
+            Note: silences are kept for better transcription; the previous
+            implementation passed segments separately, but the output was
+            less accurate.
+            """
+            batch_start_time = None
+            batch_end_time = None
+            for start_time, end_time in segments:
+                if batch_start_time is None or batch_end_time is None:
+                    batch_start_time = start_time
+                    batch_end_time = end_time
+                    continue
+
+                total_duration = end_time - batch_start_time
+                if total_duration <= max_duration:
+                    batch_end_time = end_time
+                    continue
+
+                yield (batch_start_time, batch_end_time)
+                batch_start_time = start_time
+                batch_end_time = end_time
+
+            if batch_start_time is None or batch_end_time is None:
+                return
+
+            yield (batch_start_time, batch_end_time)
+
+        def batch_segment_to_audio_segment(segments, audio_array):
+            for start_time, end_time in segments:
+                start_sample = int(start_time * SAMPLERATE)
+                end_sample = int(end_time * SAMPLERATE)
+                audio_segment = audio_array[start_sample:end_sample]
+                if end_time - start_time < VAD_CONFIG["silence_padding"]:
+                    silence_samples = int(
+                        (VAD_CONFIG["silence_padding"] - (end_time - start_time))
+                        * SAMPLERATE
+                    )
+                    padding = np.zeros(silence_samples, dtype=np.float32)
+                    audio_segment = np.concatenate([audio_segment, padding])
+                yield start_time, end_time, audio_segment
 
         def transcribe_batch(model, audio_segments):
             with NoStdStreams():
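
One subtlety carried over from the old `batch_segments`: any window shorter than `silence_padding` (0.5 s) is zero-padded up to that floor before transcription. A worked example of the arithmetic at 16 kHz:

```python
import numpy as np

SAMPLERATE = 16000
SILENCE_PADDING = 0.5  # VAD_CONFIG["silence_padding"]

# A 0.25 s window is below the 0.5 s floor, so zeros are appended.
start_time, end_time = 10.0, 10.25
duration = end_time - start_time                                   # 0.25 s
segment = np.zeros(int(duration * SAMPLERATE), dtype=np.float32)   # 4000 samples
deficit = int((SILENCE_PADDING - duration) * SAMPLERATE)           # 4000 samples
segment = np.concatenate([segment, np.zeros(deficit, dtype=np.float32)])
assert len(segment) == int(SILENCE_PADDING * SAMPLERATE)           # 8000 samples = 0.5 s
```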
@@ -376,8 +364,6 @@ class TranscriberParakeetFile:
         def emit_results(
             results,
             segments_info,
-            batch_index,
-            total_batches,
         ):
             """Yield transcribed text and word timings from model output, adjusting timestamps to absolute positions."""
             for i, (output, (start_time, end_time, _)) in enumerate(
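
`emit_results` no longer needs batch counters; its remaining job, per the docstring, is shifting the model's window-relative word times by each window's absolute `start_time`. Illustratively, with hypothetical numbers (the actual field names in the model output may differ):

```python
# A window spanning 10.0-19.0 s of the file, in which the model reports a
# word at 1.2-1.6 s relative to the audio it was given.
batch_start_time = 10.0
word_rel_start, word_rel_end = 1.2, 1.6
word_abs = (batch_start_time + word_rel_start, batch_start_time + word_rel_end)
print(word_abs)  # (11.2, 11.6) -- position in the original file
```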
@@ -413,35 +399,26 @@ class TranscriberParakeetFile:
         all_words = []
 
         raw_segments = vad_segment_generator(audio_array)
-        filtered_segments = vad_segment_filter(raw_segments)
-        batches = batch_segments(
-            filtered_segments,
-            VAD_CONFIG["batch_max_files"],
-            VAD_CONFIG["batch_max_duration"],
-        )
+        speech_segments = batch_speech_segments(
+            raw_segments,
+            VAD_CONFIG["batch_max_duration"],
+        )
+        audio_segments = batch_segment_to_audio_segment(speech_segments, audio_array)
 
-        batch_index = 0
-        total_batches = max(
-            1, int(total_duration / VAD_CONFIG["batch_max_duration"]) + 1
-        )
-        for batch in batches:
-            batch_index += 1
-            audio_segments = [seg[2] for seg in batch]
-            results = transcribe_batch(self.model, audio_segments)
+        for batch in audio_segments:
+            _, _, audio_segment = batch
+            results = transcribe_batch(self.model, [audio_segment])
 
             for text, words in emit_results(
                 results,
-                batch,
-                batch_index,
-                total_batches,
+                [batch],
             ):
                 if not text:
                     continue
                 all_text_parts.append(text)
                 all_words.extend(words)
 
-            processed_duration += sum(len(seg[2]) / float(SAMPLERATE) for seg in batch)
+            processed_duration += len(audio_segment) / float(SAMPLERATE)
 
         combined_text = " ".join(all_text_parts)
         return {"text": combined_text, "words": all_words}
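
Seen end to end, the new code is a lazy generator pipeline: spans are detected, merged into ≤30 s windows, sliced and padded, and transcribed one window at a time, so at most one window of audio is materialized beyond the source array. A side effect visible in the last hunk is that progress accounting now advances one window at a time rather than per segment. A tiny simulation of that arithmetic, using the windows from the docstring example (hypothetical durations, just to show the accounting):

```python
SAMPLERATE = 16000
windows = [(0.0, 8.0), (10.0, 19.0), (20.0, 22.0)]  # merged windows from the docstring
processed_duration = 0.0
for start_s, end_s in windows:
    n_samples = int((end_s - start_s) * SAMPLERATE)   # samples actually transcribed
    processed_duration += n_samples / float(SAMPLERATE)
print(processed_duration)  # 19.0 s sent to the model, intra-window silences included
```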