style: more type annotations to parakeet transcriber (#581)

* feat: add comprehensive type annotations to Parakeet transcriber

- Add TypedDict for WordTiming with word, start, end fields
- Add NamedTuple for TimeSegment, AudioSegment, and TranscriptResult
- Add type hints to all generator functions (vad_segment_generator, batch_speech_segments, etc.)
- Add enforce_word_timing_constraints function to prevent word timing overlaps
- Refactor batch_segment_to_audio_segment to reuse pad_audio function

* doc: add note about space

Committed by GitHub on 2025-08-28 12:22:07 -06:00
parent 124ce03bf8
commit f5331a2107


@@ -3,7 +3,7 @@ import os
 import sys
 import threading
 import uuid
-from typing import Generator, Mapping, NewType
+from typing import Generator, Mapping, NamedTuple, NewType, TypedDict
 from urllib.parse import urlparse
 
 import modal
@@ -22,6 +22,37 @@ VAD_CONFIG = {
 ParakeetUniqFilename = NewType("ParakeetUniqFilename", str)
 AudioFileExtension = NewType("AudioFileExtension", str)
 
+
+class TimeSegment(NamedTuple):
+    """Represents a time segment with start and end times."""
+
+    start: float
+    end: float
+
+
+class AudioSegment(NamedTuple):
+    """Represents an audio segment with timing and audio data."""
+
+    start: float
+    end: float
+    audio: any
+
+
+class TranscriptResult(NamedTuple):
+    """Represents a transcription result with text and word timings."""
+
+    text: str
+    words: list["WordTiming"]
+
+
+class WordTiming(TypedDict):
+    """Represents a word with its timing information."""
+
+    word: str
+    start: float
+    end: float
+
+
 app = modal.App("reflector-transcriber-parakeet")
 
 # Volume for caching model weights
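
Worth noting for reviewers: NamedTuple and TypedDict keep the new containers compatible with the tuple and dict call sites that already exist. A minimal sketch of the runtime behavior (illustrative only, not code from the commit):

# Illustrative sketch (not commit code): how the new containers behave at runtime.
from typing import NamedTuple, TypedDict


class TimeSegment(NamedTuple):
    start: float
    end: float


class WordTiming(TypedDict):
    word: str
    start: float
    end: float


seg = TimeSegment(start=1.0, end=2.5)
start, end = seg  # a NamedTuple is still a tuple, so tuple unpacking keeps working
assert (start, end) == (1.0, 2.5)

w = WordTiming(word="hello ", start=0.0, end=0.4)
assert isinstance(w, dict)  # a TypedDict constructs a plain dict, which is why
assert w["start"] == 0.0    # the ["key"] access and .copy() seen below still work

This also explains why TranscriptResult spells its field as list["WordTiming"]: the TypedDict is declared after it in the file, so a string forward reference is needed. (And `audio: any` annotates with the builtin `any` function; `typing.Any` is presumably the intent, though the annotation is harmless at runtime.)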
@@ -167,12 +198,14 @@ class TranscriberParakeetLive:
             (output,) = self.model.transcribe([padded_audio], timestamps=True)
 
         text = output.text.strip()
-        words = [
-            {
-                "word": word_info["word"] + " ",
-                "start": round(word_info["start"], 2),
-                "end": round(word_info["end"], 2),
-            }
+        words: list[WordTiming] = [
+            WordTiming(
+                # XXX the space added here is to match the output of whisper
+                # whisper add space to each words, while parakeet don't
+                word=word_info["word"] + " ",
+                start=round(word_info["start"], 2),
+                end=round(word_info["end"], 2),
+            )
             for word_info in output.timestamp["word"]
         ]
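
The new XXX comment is what the `doc: add note about space` message above refers to: Whisper emits each word with a trailing space while Parakeet does not, so the transcriber appends one to stay format-compatible. A small illustration with made-up timings:

# Made-up data in the whisper-compatible shape produced above:
words = [
    {"word": "hello ", "start": 0.00, "end": 0.42},
    {"word": "world ", "start": 0.45, "end": 0.90},
]
# Because every word carries its own trailing space, downstream consumers can
# rebuild readable text by plain concatenation, exactly as with whisper output:
assert "".join(w["word"] for w in words) == "hello world "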
@@ -208,12 +241,12 @@ class TranscriberParakeetLive:
         for i, (filename, output) in enumerate(zip(filenames, outputs)):
             text = output.text.strip()
-            words = [
-                {
-                    "word": word_info["word"] + " ",
-                    "start": round(word_info["start"], 2),
-                    "end": round(word_info["end"], 2),
-                }
+            words: list[WordTiming] = [
+                WordTiming(
+                    word=word_info["word"] + " ",
+                    start=round(word_info["start"], 2),
+                    end=round(word_info["end"], 2),
+                )
                 for word_info in output.timestamp["word"]
             ]
@@ -270,7 +303,7 @@ class TranscriberParakeetFile:
         def vad_segment_generator(
             audio_array,
-        ) -> Generator[tuple[float, float], None, None]:
+        ) -> Generator[TimeSegment, None, None]:
             """Generate speech segments using VAD with start/end sample indices"""
             vad_iterator = VADIterator(self.vad_model, sampling_rate=SAMPLERATE)
             window_size = VAD_CONFIG["window_size"]
@@ -296,14 +329,14 @@ class TranscriberParakeetFile:
                 start_time = start / float(SAMPLERATE)
                 end_time = end / float(SAMPLERATE)
-                yield (start_time, end_time)
+                yield TimeSegment(start_time, end_time)
                 start = None
 
             vad_iterator.reset_states()
 
         def batch_speech_segments(
-            segments: Generator[tuple[float, float], None, None], max_duration: int
-        ) -> Generator[tuple[float, float], None, None]:
+            segments: Generator[TimeSegment, None, None], max_duration: int
+        ) -> Generator[TimeSegment, None, None]:
             """
             Input segments:
             [0-2] [3-5] [6-8] [10-11] [12-15] [17-19] [20-22]
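
The diff context cuts the docstring off before its expected output, and the merge condition itself lives in unchanged lines, so the following is only a rough re-implementation of the batching idea, under the assumption that segments are merged while the running batch stays within max_duration:

# Toy re-implementation (assumption: merge while total span <= max_duration;
# the real condition is in context lines this diff does not show).
from typing import Generator, NamedTuple


class TimeSegment(NamedTuple):
    start: float
    end: float


def batch_speech_segments_sketch(
    segments: Generator[TimeSegment, None, None], max_duration: int
) -> Generator[TimeSegment, None, None]:
    batch_start = batch_end = None
    for seg in segments:
        if batch_start is None:
            batch_start, batch_end = seg.start, seg.end
            continue
        if seg.end - batch_start <= max_duration:
            batch_end = seg.end  # extend the current batch
            continue
        yield TimeSegment(batch_start, batch_end)  # flush, then start a new batch
        batch_start, batch_end = seg.start, seg.end
    if batch_start is not None:
        yield TimeSegment(batch_start, batch_end)


segs = iter([TimeSegment(0, 2), TimeSegment(3, 5), TimeSegment(6, 8), TimeSegment(10, 11)])
print(list(batch_speech_segments_sketch(segs, max_duration=10)))
# -> [TimeSegment(start=0, end=8), TimeSegment(start=10, end=11)]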
@@ -319,7 +352,8 @@ class TranscriberParakeetFile:
             batch_start_time = None
             batch_end_time = None
 
-            for start_time, end_time in segments:
+            for segment in segments:
+                start_time, end_time = segment.start, segment.end
                 if batch_start_time is None or batch_end_time is None:
                     batch_start_time = start_time
                     batch_end_time = end_time
@@ -331,59 +365,85 @@ class TranscriberParakeetFile:
                     batch_end_time = end_time
                     continue
 
-                yield (batch_start_time, batch_end_time)
+                yield TimeSegment(batch_start_time, batch_end_time)
                 batch_start_time = start_time
                 batch_end_time = end_time
 
             if batch_start_time is None or batch_end_time is None:
                 return
-            yield (batch_start_time, batch_end_time)
+            yield TimeSegment(batch_start_time, batch_end_time)
 
-        def batch_segment_to_audio_segment(segments, audio_array):
-            for start_time, end_time in segments:
+        def batch_segment_to_audio_segment(
+            segments: Generator[TimeSegment, None, None],
+            audio_array,
+        ) -> Generator[AudioSegment, None, None]:
+            """Extract audio segments and apply padding for Parakeet compatibility.
+
+            Uses pad_audio to ensure segments are at least 0.5s long, preventing
+            Parakeet crashes. This padding may cause slight timing overlaps between
+            segments, which are corrected by enforce_word_timing_constraints.
+            """
+            for segment in segments:
+                start_time, end_time = segment.start, segment.end
                 start_sample = int(start_time * SAMPLERATE)
                 end_sample = int(end_time * SAMPLERATE)
                 audio_segment = audio_array[start_sample:end_sample]
 
-                if end_time - start_time < VAD_CONFIG["silence_padding"]:
-                    silence_samples = int(
-                        (VAD_CONFIG["silence_padding"] - (end_time - start_time))
-                        * SAMPLERATE
-                    )
-                    padding = np.zeros(silence_samples, dtype=np.float32)
-                    audio_segment = np.concatenate([audio_segment, padding])
+                padded_segment = pad_audio(audio_segment, SAMPLERATE)
 
-                yield start_time, end_time, audio_segment
+                yield AudioSegment(start_time, end_time, padded_segment)
 
-        def transcribe_batch(model, audio_segments):
+        def transcribe_batch(model, audio_segments: list) -> list:
             with NoStdStreams():
                 outputs = model.transcribe(audio_segments, timestamps=True)
             return outputs
 
+        def enforce_word_timing_constraints(
+            words: list[WordTiming],
+        ) -> list[WordTiming]:
+            """Enforce that word end times don't exceed the start time of the next word.
+
+            Due to silence padding added in batch_segment_to_audio_segment for better
+            transcription accuracy, word timings from different segments may overlap.
+            This function ensures there are no overlaps by adjusting end times.
+            """
+            if len(words) <= 1:
+                return words
+
+            enforced_words = []
+            for i, word in enumerate(words):
+                enforced_word = word.copy()
+                if i < len(words) - 1:
+                    next_start = words[i + 1]["start"]
+                    if enforced_word["end"] > next_start:
+                        enforced_word["end"] = next_start
+                enforced_words.append(enforced_word)
+            return enforced_words
+
         def emit_results(
-            results,
-            segments_info,
-        ):
+            results: list,
+            segments_info: list[AudioSegment],
+        ) -> Generator[TranscriptResult, None, None]:
             """Yield transcribed text and word timings from model output, adjusting timestamps to absolute positions."""
-            for i, (output, (start_time, end_time, _)) in enumerate(
-                zip(results, segments_info)
-            ):
+            for i, (output, segment) in enumerate(zip(results, segments_info)):
+                start_time, end_time = segment.start, segment.end
                 text = output.text.strip()
-                words = [
-                    {
-                        "word": word_info["word"] + " ",
-                        "start": round(
-                            word_info["start"] + start_time + timestamp_offset, 2
-                        ),
-                        "end": round(
-                            word_info["end"] + start_time + timestamp_offset, 2
-                        ),
-                    }
+                words: list[WordTiming] = [
+                    WordTiming(
+                        word=word_info["word"] + " ",
+                        start=round(
+                            word_info["start"] + start_time + timestamp_offset, 2
+                        ),
+                        end=round(word_info["end"] + start_time + timestamp_offset, 2),
+                    )
                     for word_info in output.timestamp["word"]
                 ]
-                yield text, words
+                yield TranscriptResult(text, words)
 
         upload_volume.reload()
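
Since enforce_word_timing_constraints is the main behavioral addition, here is what it does on toy input. The function body is copied from the hunk above, with the type annotation dropped so the snippet stands alone (WordTiming is a plain dict at runtime):

def enforce_word_timing_constraints(words):
    if len(words) <= 1:
        return words
    enforced_words = []
    for i, word in enumerate(words):
        enforced_word = word.copy()
        if i < len(words) - 1:
            next_start = words[i + 1]["start"]
            if enforced_word["end"] > next_start:
                enforced_word["end"] = next_start
        enforced_words.append(enforced_word)
    return enforced_words


words = [
    {"word": "one ", "start": 0.0, "end": 1.2},  # padding made this spill over
    {"word": "two ", "start": 1.0, "end": 1.8},
    {"word": "three ", "start": 2.0, "end": 2.5},
]
fixed = enforce_word_timing_constraints(words)
assert fixed[0]["end"] == 1.0  # clamped to the next word's start
assert fixed[1]["end"] == 1.8  # no overlap, left untouched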
@@ -393,10 +453,9 @@ class TranscriberParakeetFile:
         audio_array = load_and_convert_audio(file_path)
         total_duration = len(audio_array) / float(SAMPLERATE)
 
-        processed_duration = 0.0
-        all_text_parts = []
-        all_words = []
+        all_text_parts: list[str] = []
+        all_words: list[WordTiming] = []
 
         raw_segments = vad_segment_generator(audio_array)
         speech_segments = batch_speech_segments(
@@ -406,19 +465,19 @@ class TranscriberParakeetFile:
         audio_segments = batch_segment_to_audio_segment(speech_segments, audio_array)
 
         for batch in audio_segments:
-            _, _, audio_segment = batch
+            audio_segment = batch.audio
             results = transcribe_batch(self.model, [audio_segment])
-            for text, words in emit_results(
+            for result in emit_results(
                 results,
                 [batch],
             ):
-                if not text:
+                if not result.text:
                     continue
-                all_text_parts.append(text)
-                all_words.extend(words)
+                all_text_parts.append(result.text)
+                all_words.extend(result.words)
 
-            processed_duration += len(audio_segment) / float(SAMPLERATE)
+        all_words = enforce_word_timing_constraints(all_words)
 
         combined_text = " ".join(all_text_parts)
         return {"text": combined_text, "words": all_words}