Mirror of https://github.com/Monadical-SAS/reflector.git, synced 2025-12-20 20:29:06 +00:00
style: more type annotations to parakeet transcriber (#581)
* feat: add comprehensive type annotations to Parakeet transcriber
  - Add TypedDict for WordTiming with word, start, end fields
  - Add NamedTuple for TimeSegment, AudioSegment, and TranscriptResult
  - Add type hints to all generator functions (vad_segment_generator, batch_speech_segments, etc.)
  - Add enforce_word_timing_constraints function to prevent word timing overlaps
  - Refactor batch_segment_to_audio_segment to reuse pad_audio function
* doc: add note about space
@@ -3,7 +3,7 @@ import os
 import sys
 import threading
 import uuid
-from typing import Generator, Mapping, NewType
+from typing import Generator, Mapping, NamedTuple, NewType, TypedDict
 from urllib.parse import urlparse
 
 import modal
@@ -22,6 +22,37 @@ VAD_CONFIG = {
 ParakeetUniqFilename = NewType("ParakeetUniqFilename", str)
 AudioFileExtension = NewType("AudioFileExtension", str)
 
+
+class TimeSegment(NamedTuple):
+    """Represents a time segment with start and end times."""
+
+    start: float
+    end: float
+
+
+class AudioSegment(NamedTuple):
+    """Represents an audio segment with timing and audio data."""
+
+    start: float
+    end: float
+    audio: any
+
+
+class TranscriptResult(NamedTuple):
+    """Represents a transcription result with text and word timings."""
+
+    text: str
+    words: list["WordTiming"]
+
+
+class WordTiming(TypedDict):
+    """Represents a word with its timing information."""
+
+    word: str
+    start: float
+    end: float
+
+
 app = modal.App("reflector-transcriber-parakeet")
 
 # Volume for caching model weights
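
At runtime these annotations are cheap: NamedTuple instances are still plain tuples, and a TypedDict class is an ordinary dict constructor. A minimal standalone sketch, mirroring the definitions above, shows why existing call sites keep working:

from typing import NamedTuple, TypedDict


class TimeSegment(NamedTuple):
    start: float
    end: float


class WordTiming(TypedDict):
    word: str
    start: float
    end: float


# NamedTuple instances are still tuples, so existing unpacking code keeps working.
seg = TimeSegment(1.5, 3.25)
start, end = seg
assert (start, end) == (seg.start, seg.end)

# A TypedDict is a plain dict at runtime: WordTiming(...) builds exactly the
# {"word": ..., "start": ..., "end": ...} payload the old dict literals did.
word = WordTiming(word="hello ", start=1.5, end=1.9)
assert word == {"word": "hello ", "start": 1.5, "end": 1.9}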
@@ -167,12 +198,14 @@ class TranscriberParakeetLive:
         (output,) = self.model.transcribe([padded_audio], timestamps=True)
 
         text = output.text.strip()
-        words = [
-            {
-                "word": word_info["word"] + " ",
-                "start": round(word_info["start"], 2),
-                "end": round(word_info["end"], 2),
-            }
+        words: list[WordTiming] = [
+            WordTiming(
+                # XXX the space added here is to match the output of whisper
+                # whisper add space to each words, while parakeet don't
+                word=word_info["word"] + " ",
+                start=round(word_info["start"], 2),
+                end=round(word_info["end"], 2),
+            )
             for word_info in output.timestamp["word"]
         ]
 
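The trailing space noted in the XXX comment is what lets downstream consumers concatenate word entries straight back into Whisper-shaped text. A small illustration, using hypothetical word values shaped like the list built above:

words = [
    {"word": "hello ", "start": 0.0, "end": 0.4},
    {"word": "world ", "start": 0.45, "end": 0.9},
]

# Whisper emits each word token with its own whitespace, so plain concatenation
# reproduces the sentence; without the added " " Parakeet words would run together.
assert "".join(w["word"] for w in words) == "hello world "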
@@ -208,12 +241,12 @@ class TranscriberParakeetLive:
         for i, (filename, output) in enumerate(zip(filenames, outputs)):
             text = output.text.strip()
 
-            words = [
-                {
-                    "word": word_info["word"] + " ",
-                    "start": round(word_info["start"], 2),
-                    "end": round(word_info["end"], 2),
-                }
+            words: list[WordTiming] = [
+                WordTiming(
+                    word=word_info["word"] + " ",
+                    start=round(word_info["start"], 2),
+                    end=round(word_info["end"], 2),
+                )
                 for word_info in output.timestamp["word"]
             ]
 
@@ -270,7 +303,7 @@ class TranscriberParakeetFile:
 
         def vad_segment_generator(
             audio_array,
-        ) -> Generator[tuple[float, float], None, None]:
+        ) -> Generator[TimeSegment, None, None]:
             """Generate speech segments using VAD with start/end sample indices"""
             vad_iterator = VADIterator(self.vad_model, sampling_rate=SAMPLERATE)
             window_size = VAD_CONFIG["window_size"]
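
For readers unfamiliar with the three-parameter form: Generator[Y, S, R] spells out the yield, send, and return types, so Generator[TimeSegment, None, None] marks a generator that only yields TimeSegment values. A self-contained sketch (the function name is illustrative, not from the diff):

from typing import Generator, NamedTuple


class TimeSegment(NamedTuple):
    start: float
    end: float


# Generator[YieldType, SendType, ReturnType]: this generator only yields,
# so the send and return types are both None.
def speech_like() -> Generator[TimeSegment, None, None]:
    yield TimeSegment(0.0, 2.0)
    yield TimeSegment(3.0, 5.0)


for seg in speech_like():
    print(f"{seg.start:.1f}s -> {seg.end:.1f}s")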
@@ -296,14 +329,14 @@ class TranscriberParakeetFile:
                     start_time = start / float(SAMPLERATE)
                     end_time = end / float(SAMPLERATE)
 
-                    yield (start_time, end_time)
+                    yield TimeSegment(start_time, end_time)
                     start = None
 
             vad_iterator.reset_states()
 
         def batch_speech_segments(
-            segments: Generator[tuple[float, float], None, None], max_duration: int
-        ) -> Generator[tuple[float, float], None, None]:
+            segments: Generator[TimeSegment, None, None], max_duration: int
+        ) -> Generator[TimeSegment, None, None]:
             """
             Input segments:
             [0-2] [3-5] [6-8] [10-11] [12-15] [17-19] [20-22]
@@ -319,7 +352,8 @@ class TranscriberParakeetFile:
             batch_start_time = None
             batch_end_time = None
 
-            for start_time, end_time in segments:
+            for segment in segments:
+                start_time, end_time = segment.start, segment.end
                 if batch_start_time is None or batch_end_time is None:
                     batch_start_time = start_time
                     batch_end_time = end_time
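
The merge condition itself sits outside the excerpted hunks, so the following is only a sketch of the batching idea, under the assumption that consecutive segments are merged while the combined span stays within max_duration; it reuses the docstring's style of input:

from typing import Generator, NamedTuple


class TimeSegment(NamedTuple):
    start: float
    end: float


def batch_speech_segments_sketch(
    segments, max_duration: int
) -> Generator[TimeSegment, None, None]:
    batch_start, batch_end = None, None
    for seg in segments:
        if batch_start is None:
            batch_start, batch_end = seg.start, seg.end
            continue
        # Assumed merge rule: keep growing the batch while the combined span fits.
        if seg.end - batch_start <= max_duration:
            batch_end = seg.end
            continue
        yield TimeSegment(batch_start, batch_end)
        batch_start, batch_end = seg.start, seg.end
    if batch_start is not None:
        yield TimeSegment(batch_start, batch_end)


segs = (TimeSegment(s, e) for s, e in [(0, 2), (3, 5), (6, 8), (10, 11)])
print(list(batch_speech_segments_sketch(segs, max_duration=10)))
# [TimeSegment(start=0, end=8), TimeSegment(start=10, end=11)]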
@@ -331,59 +365,85 @@ class TranscriberParakeetFile:
                     batch_end_time = end_time
                     continue
 
-                yield (batch_start_time, batch_end_time)
+                yield TimeSegment(batch_start_time, batch_end_time)
                 batch_start_time = start_time
                 batch_end_time = end_time
 
             if batch_start_time is None or batch_end_time is None:
                 return
 
-            yield (batch_start_time, batch_end_time)
+            yield TimeSegment(batch_start_time, batch_end_time)
 
-        def batch_segment_to_audio_segment(segments, audio_array):
-            for start_time, end_time in segments:
+        def batch_segment_to_audio_segment(
+            segments: Generator[TimeSegment, None, None],
+            audio_array,
+        ) -> Generator[AudioSegment, None, None]:
+            """Extract audio segments and apply padding for Parakeet compatibility.
+
+            Uses pad_audio to ensure segments are at least 0.5s long, preventing
+            Parakeet crashes. This padding may cause slight timing overlaps between
+            segments, which are corrected by enforce_word_timing_constraints.
+            """
+            for segment in segments:
+                start_time, end_time = segment.start, segment.end
                 start_sample = int(start_time * SAMPLERATE)
                 end_sample = int(end_time * SAMPLERATE)
                 audio_segment = audio_array[start_sample:end_sample]
 
-                if end_time - start_time < VAD_CONFIG["silence_padding"]:
-                    silence_samples = int(
-                        (VAD_CONFIG["silence_padding"] - (end_time - start_time))
-                        * SAMPLERATE
-                    )
-                    padding = np.zeros(silence_samples, dtype=np.float32)
-                    audio_segment = np.concatenate([audio_segment, padding])
+                padded_segment = pad_audio(audio_segment, SAMPLERATE)
 
-                yield start_time, end_time, audio_segment
+                yield AudioSegment(start_time, end_time, padded_segment)
 
-        def transcribe_batch(model, audio_segments):
+        def transcribe_batch(model, audio_segments: list) -> list:
             with NoStdStreams():
                 outputs = model.transcribe(audio_segments, timestamps=True)
                 return outputs
 
+        def enforce_word_timing_constraints(
+            words: list[WordTiming],
+        ) -> list[WordTiming]:
+            """Enforce that word end times don't exceed the start time of the next word.
+
+            Due to silence padding added in batch_segment_to_audio_segment for better
+            transcription accuracy, word timings from different segments may overlap.
+            This function ensures there are no overlaps by adjusting end times.
+            """
+            if len(words) <= 1:
+                return words
+
+            enforced_words = []
+            for i, word in enumerate(words):
+                enforced_word = word.copy()
+
+                if i < len(words) - 1:
+                    next_start = words[i + 1]["start"]
+                    if enforced_word["end"] > next_start:
+                        enforced_word["end"] = next_start
+
+                enforced_words.append(enforced_word)
+
+            return enforced_words
+
         def emit_results(
-            results,
-            segments_info,
-        ):
+            results: list,
+            segments_info: list[AudioSegment],
+        ) -> Generator[TranscriptResult, None, None]:
             """Yield transcribed text and word timings from model output, adjusting timestamps to absolute positions."""
-            for i, (output, (start_time, end_time, _)) in enumerate(
-                zip(results, segments_info)
-            ):
+            for i, (output, segment) in enumerate(zip(results, segments_info)):
+                start_time, end_time = segment.start, segment.end
                 text = output.text.strip()
-                words = [
-                    {
-                        "word": word_info["word"] + " ",
-                        "start": round(
+                words: list[WordTiming] = [
+                    WordTiming(
+                        word=word_info["word"] + " ",
+                        start=round(
                             word_info["start"] + start_time + timestamp_offset, 2
                         ),
-                        "end": round(
-                            word_info["end"] + start_time + timestamp_offset, 2
-                        ),
-                    }
+                        end=round(word_info["end"] + start_time + timestamp_offset, 2),
+                    )
                     for word_info in output.timestamp["word"]
                 ]
 
-                yield text, words
+                yield TranscriptResult(text, words)
 
         upload_volume.reload()
 
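Since enforce_word_timing_constraints is a pure function, it is easy to check in isolation. The snippet below copies its logic from the hunk above (type annotations dropped) and feeds it a hypothetical overlap of the kind the silence padding can introduce:

def enforce_word_timing_constraints(words):
    if len(words) <= 1:
        return words

    enforced_words = []
    for i, word in enumerate(words):
        enforced_word = word.copy()

        if i < len(words) - 1:
            next_start = words[i + 1]["start"]
            if enforced_word["end"] > next_start:
                enforced_word["end"] = next_start

        enforced_words.append(enforced_word)

    return enforced_words


# Hypothetical overlap: segment padding pushed the first word's end past
# the next word's start.
words = [
    {"word": "first ", "start": 0.0, "end": 1.2},
    {"word": "second ", "start": 1.0, "end": 1.8},
]
fixed = enforce_word_timing_constraints(words)
assert fixed[0]["end"] == 1.0  # clamped to the next word's start
assert fixed[1]["end"] == 1.8  # the last word is never adjusted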
@@ -393,10 +453,9 @@ class TranscriberParakeetFile:
 
         audio_array = load_and_convert_audio(file_path)
         total_duration = len(audio_array) / float(SAMPLERATE)
-        processed_duration = 0.0
 
-        all_text_parts = []
-        all_words = []
+        all_text_parts: list[str] = []
+        all_words: list[WordTiming] = []
 
         raw_segments = vad_segment_generator(audio_array)
         speech_segments = batch_speech_segments(
@@ -406,19 +465,19 @@ class TranscriberParakeetFile:
         audio_segments = batch_segment_to_audio_segment(speech_segments, audio_array)
 
         for batch in audio_segments:
-            _, _, audio_segment = batch
+            audio_segment = batch.audio
             results = transcribe_batch(self.model, [audio_segment])
 
-            for text, words in emit_results(
+            for result in emit_results(
                 results,
                 [batch],
             ):
-                if not text:
+                if not result.text:
                     continue
-                all_text_parts.append(text)
-                all_words.extend(words)
+                all_text_parts.append(result.text)
+                all_words.extend(result.words)
 
-            processed_duration += len(audio_segment) / float(SAMPLERATE)
+        all_words = enforce_word_timing_constraints(all_words)
 
         combined_text = " ".join(all_text_parts)
         return {"text": combined_text, "words": all_words}