style: more type annotations to parakeet transcriber (#581)

* feat: add comprehensive type annotations to Parakeet transcriber

- Add TypedDict for WordTiming with word, start, end fields
- Add NamedTuple for TimeSegment, AudioSegment, and TranscriptResult
- Add type hints to all generator functions (vad_segment_generator, batch_speech_segments, etc.)
- Add enforce_word_timing_constraints function to prevent word timing overlaps
- Refactor batch_segment_to_audio_segment to reuse pad_audio function

* doc: add note about space
This commit is contained in:
2025-08-28 12:22:07 -06:00
committed by GitHub
parent 124ce03bf8
commit f5331a2107

View File

@@ -3,7 +3,7 @@ import os
import sys
import threading
import uuid
from typing import Generator, Mapping, NewType
from typing import Generator, Mapping, NamedTuple, NewType, TypedDict
from urllib.parse import urlparse
import modal
@@ -22,6 +22,37 @@ VAD_CONFIG = {
ParakeetUniqFilename = NewType("ParakeetUniqFilename", str)
AudioFileExtension = NewType("AudioFileExtension", str)
class TimeSegment(NamedTuple):
"""Represents a time segment with start and end times."""
start: float
end: float
class AudioSegment(NamedTuple):
"""Represents an audio segment with timing and audio data."""
start: float
end: float
audio: any
class TranscriptResult(NamedTuple):
"""Represents a transcription result with text and word timings."""
text: str
words: list["WordTiming"]
class WordTiming(TypedDict):
"""Represents a word with its timing information."""
word: str
start: float
end: float
app = modal.App("reflector-transcriber-parakeet")
# Volume for caching model weights
@@ -167,12 +198,14 @@ class TranscriberParakeetLive:
(output,) = self.model.transcribe([padded_audio], timestamps=True)
text = output.text.strip()
words = [
{
"word": word_info["word"] + " ",
"start": round(word_info["start"], 2),
"end": round(word_info["end"], 2),
}
words: list[WordTiming] = [
WordTiming(
# XXX the space added here is to match the output of whisper
# whisper add space to each words, while parakeet don't
word=word_info["word"] + " ",
start=round(word_info["start"], 2),
end=round(word_info["end"], 2),
)
for word_info in output.timestamp["word"]
]
@@ -208,12 +241,12 @@ class TranscriberParakeetLive:
for i, (filename, output) in enumerate(zip(filenames, outputs)):
text = output.text.strip()
words = [
{
"word": word_info["word"] + " ",
"start": round(word_info["start"], 2),
"end": round(word_info["end"], 2),
}
words: list[WordTiming] = [
WordTiming(
word=word_info["word"] + " ",
start=round(word_info["start"], 2),
end=round(word_info["end"], 2),
)
for word_info in output.timestamp["word"]
]
@@ -270,7 +303,7 @@ class TranscriberParakeetFile:
def vad_segment_generator(
audio_array,
) -> Generator[tuple[float, float], None, None]:
) -> Generator[TimeSegment, None, None]:
"""Generate speech segments using VAD with start/end sample indices"""
vad_iterator = VADIterator(self.vad_model, sampling_rate=SAMPLERATE)
window_size = VAD_CONFIG["window_size"]
@@ -296,14 +329,14 @@ class TranscriberParakeetFile:
start_time = start / float(SAMPLERATE)
end_time = end / float(SAMPLERATE)
yield (start_time, end_time)
yield TimeSegment(start_time, end_time)
start = None
vad_iterator.reset_states()
def batch_speech_segments(
segments: Generator[tuple[float, float], None, None], max_duration: int
) -> Generator[tuple[float, float], None, None]:
segments: Generator[TimeSegment, None, None], max_duration: int
) -> Generator[TimeSegment, None, None]:
"""
Input segments:
[0-2] [3-5] [6-8] [10-11] [12-15] [17-19] [20-22]
@@ -319,7 +352,8 @@ class TranscriberParakeetFile:
batch_start_time = None
batch_end_time = None
for start_time, end_time in segments:
for segment in segments:
start_time, end_time = segment.start, segment.end
if batch_start_time is None or batch_end_time is None:
batch_start_time = start_time
batch_end_time = end_time
@@ -331,59 +365,85 @@ class TranscriberParakeetFile:
batch_end_time = end_time
continue
yield (batch_start_time, batch_end_time)
yield TimeSegment(batch_start_time, batch_end_time)
batch_start_time = start_time
batch_end_time = end_time
if batch_start_time is None or batch_end_time is None:
return
yield (batch_start_time, batch_end_time)
yield TimeSegment(batch_start_time, batch_end_time)
def batch_segment_to_audio_segment(segments, audio_array):
for start_time, end_time in segments:
def batch_segment_to_audio_segment(
segments: Generator[TimeSegment, None, None],
audio_array,
) -> Generator[AudioSegment, None, None]:
"""Extract audio segments and apply padding for Parakeet compatibility.
Uses pad_audio to ensure segments are at least 0.5s long, preventing
Parakeet crashes. This padding may cause slight timing overlaps between
segments, which are corrected by enforce_word_timing_constraints.
"""
for segment in segments:
start_time, end_time = segment.start, segment.end
start_sample = int(start_time * SAMPLERATE)
end_sample = int(end_time * SAMPLERATE)
audio_segment = audio_array[start_sample:end_sample]
if end_time - start_time < VAD_CONFIG["silence_padding"]:
silence_samples = int(
(VAD_CONFIG["silence_padding"] - (end_time - start_time))
* SAMPLERATE
)
padding = np.zeros(silence_samples, dtype=np.float32)
audio_segment = np.concatenate([audio_segment, padding])
padded_segment = pad_audio(audio_segment, SAMPLERATE)
yield start_time, end_time, audio_segment
yield AudioSegment(start_time, end_time, padded_segment)
def transcribe_batch(model, audio_segments):
def transcribe_batch(model, audio_segments: list) -> list:
with NoStdStreams():
outputs = model.transcribe(audio_segments, timestamps=True)
return outputs
def enforce_word_timing_constraints(
words: list[WordTiming],
) -> list[WordTiming]:
"""Enforce that word end times don't exceed the start time of the next word.
Due to silence padding added in batch_segment_to_audio_segment for better
transcription accuracy, word timings from different segments may overlap.
This function ensures there are no overlaps by adjusting end times.
"""
if len(words) <= 1:
return words
enforced_words = []
for i, word in enumerate(words):
enforced_word = word.copy()
if i < len(words) - 1:
next_start = words[i + 1]["start"]
if enforced_word["end"] > next_start:
enforced_word["end"] = next_start
enforced_words.append(enforced_word)
return enforced_words
def emit_results(
results,
segments_info,
):
results: list,
segments_info: list[AudioSegment],
) -> Generator[TranscriptResult, None, None]:
"""Yield transcribed text and word timings from model output, adjusting timestamps to absolute positions."""
for i, (output, (start_time, end_time, _)) in enumerate(
zip(results, segments_info)
):
for i, (output, segment) in enumerate(zip(results, segments_info)):
start_time, end_time = segment.start, segment.end
text = output.text.strip()
words = [
{
"word": word_info["word"] + " ",
"start": round(
words: list[WordTiming] = [
WordTiming(
word=word_info["word"] + " ",
start=round(
word_info["start"] + start_time + timestamp_offset, 2
),
"end": round(
word_info["end"] + start_time + timestamp_offset, 2
),
}
end=round(word_info["end"] + start_time + timestamp_offset, 2),
)
for word_info in output.timestamp["word"]
]
yield text, words
yield TranscriptResult(text, words)
upload_volume.reload()
@@ -393,10 +453,9 @@ class TranscriberParakeetFile:
audio_array = load_and_convert_audio(file_path)
total_duration = len(audio_array) / float(SAMPLERATE)
processed_duration = 0.0
all_text_parts = []
all_words = []
all_text_parts: list[str] = []
all_words: list[WordTiming] = []
raw_segments = vad_segment_generator(audio_array)
speech_segments = batch_speech_segments(
@@ -406,19 +465,19 @@ class TranscriberParakeetFile:
audio_segments = batch_segment_to_audio_segment(speech_segments, audio_array)
for batch in audio_segments:
_, _, audio_segment = batch
audio_segment = batch.audio
results = transcribe_batch(self.model, [audio_segment])
for text, words in emit_results(
for result in emit_results(
results,
[batch],
):
if not text:
if not result.text:
continue
all_text_parts.append(text)
all_words.extend(words)
all_text_parts.append(result.text)
all_words.extend(result.words)
processed_duration += len(audio_segment) / float(SAMPLERATE)
all_words = enforce_word_timing_constraints(all_words)
combined_text = " ".join(all_text_parts)
return {"text": combined_text, "words": all_words}