diff --git a/server/gpu/modal_deployments/reflector_transcriber_parakeet.py b/server/gpu/modal_deployments/reflector_transcriber_parakeet.py
index 3b6f6ad0..0827f0cc 100644
--- a/server/gpu/modal_deployments/reflector_transcriber_parakeet.py
+++ b/server/gpu/modal_deployments/reflector_transcriber_parakeet.py
@@ -3,7 +3,7 @@
 import os
 import sys
 import threading
 import uuid
-from typing import Generator, Mapping, NewType
+from typing import Any, Generator, Mapping, NamedTuple, NewType, TypedDict
 from urllib.parse import urlparse
 
 import modal
@@ -22,6 +22,37 @@ VAD_CONFIG = {
 ParakeetUniqFilename = NewType("ParakeetUniqFilename", str)
 AudioFileExtension = NewType("AudioFileExtension", str)
 
+
+class TimeSegment(NamedTuple):
+    """A time segment with start and end times in seconds."""
+
+    start: float
+    end: float
+
+
+class AudioSegment(NamedTuple):
+    """An audio segment with its timing and audio data."""
+
+    start: float
+    end: float
+    audio: Any
+
+
+class TranscriptResult(NamedTuple):
+    """A transcription result with text and word timings."""
+
+    text: str
+    words: list["WordTiming"]
+
+
+class WordTiming(TypedDict):
+    """A single word with its timing information."""
+
+    word: str
+    start: float
+    end: float
+
+
 app = modal.App("reflector-transcriber-parakeet")
 
 # Volume for caching model weights
@@ -167,12 +198,14 @@ class TranscriberParakeetLive:
         (output,) = self.model.transcribe([padded_audio], timestamps=True)
         text = output.text.strip()
 
-        words = [
-            {
-                "word": word_info["word"] + " ",
-                "start": round(word_info["start"], 2),
-                "end": round(word_info["end"], 2),
-            }
+        words: list[WordTiming] = [
+            WordTiming(
+                # XXX The trailing space matches Whisper's output: Whisper
+                # appends a space to each word, while Parakeet does not.
+                word=word_info["word"] + " ",
+                start=round(word_info["start"], 2),
+                end=round(word_info["end"], 2),
+            )
             for word_info in output.timestamp["word"]
         ]
@@ -208,12 +241,12 @@ class TranscriberParakeetLive:
         for i, (filename, output) in enumerate(zip(filenames, outputs)):
             text = output.text.strip()
 
-            words = [
-                {
-                    "word": word_info["word"] + " ",
-                    "start": round(word_info["start"], 2),
-                    "end": round(word_info["end"], 2),
-                }
+            words: list[WordTiming] = [
+                WordTiming(
+                    word=word_info["word"] + " ",
+                    start=round(word_info["start"], 2),
+                    end=round(word_info["end"], 2),
+                )
                 for word_info in output.timestamp["word"]
             ]
@@ -270,7 +303,7 @@ class TranscriberParakeetFile:
 
         def vad_segment_generator(
            audio_array,
-        ) -> Generator[tuple[float, float], None, None]:
+        ) -> Generator[TimeSegment, None, None]:
             """Generate speech segments using VAD with start/end sample indices"""
             vad_iterator = VADIterator(self.vad_model, sampling_rate=SAMPLERATE)
             window_size = VAD_CONFIG["window_size"]
@@ -296,14 +329,14 @@
                     start_time = start / float(SAMPLERATE)
                     end_time = end / float(SAMPLERATE)
-                    yield (start_time, end_time)
+                    yield TimeSegment(start_time, end_time)
                     start = None
 
             vad_iterator.reset_states()
 
         def batch_speech_segments(
-            segments: Generator[tuple[float, float], None, None], max_duration: int
-        ) -> Generator[tuple[float, float], None, None]:
+            segments: Generator[TimeSegment, None, None], max_duration: int
+        ) -> Generator[TimeSegment, None, None]:
             """
             Input segments:
                 [0-2] [3-5] [6-8] [10-11] [12-15] [17-19] [20-22]
@@ -319,7 +352,8 @@ class TranscriberParakeetFile:
             batch_start_time = None
             batch_end_time = None
 
-            for start_time, end_time in segments:
+            for segment in segments:
+                start_time, end_time = segment.start, segment.end
                 if batch_start_time is None or batch_end_time is None:
                     batch_start_time = start_time
                     batch_end_time = end_time
@@ -331,59 +365,85 @@ class TranscriberParakeetFile:
                     batch_end_time = end_time
                     continue
 
-                yield (batch_start_time, batch_end_time)
+                yield TimeSegment(batch_start_time, batch_end_time)
                 batch_start_time = start_time
                 batch_end_time = end_time
 
             if batch_start_time is None or batch_end_time is None:
                 return
 
-            yield (batch_start_time, batch_end_time)
+            yield TimeSegment(batch_start_time, batch_end_time)
 
-        def batch_segment_to_audio_segment(segments, audio_array):
-            for start_time, end_time in segments:
+        def batch_segment_to_audio_segment(
+            segments: Generator[TimeSegment, None, None],
+            audio_array,
+        ) -> Generator[AudioSegment, None, None]:
+            """Extract audio segments and apply padding for Parakeet compatibility.
+
+            Uses pad_audio to ensure segments are at least 0.5s long, preventing
+            Parakeet crashes. This padding may cause slight timing overlaps between
+            segments, which are corrected by enforce_word_timing_constraints.
+            """
+            for segment in segments:
+                start_time, end_time = segment.start, segment.end
                 start_sample = int(start_time * SAMPLERATE)
                 end_sample = int(end_time * SAMPLERATE)
                 audio_segment = audio_array[start_sample:end_sample]
 
-                if end_time - start_time < VAD_CONFIG["silence_padding"]:
-                    silence_samples = int(
-                        (VAD_CONFIG["silence_padding"] - (end_time - start_time))
-                        * SAMPLERATE
-                    )
-                    padding = np.zeros(silence_samples, dtype=np.float32)
-                    audio_segment = np.concatenate([audio_segment, padding])
+                padded_segment = pad_audio(audio_segment, SAMPLERATE)
 
-                yield start_time, end_time, audio_segment
+                yield AudioSegment(start_time, end_time, padded_segment)
 
-        def transcribe_batch(model, audio_segments):
+        def transcribe_batch(model, audio_segments: list) -> list:
             with NoStdStreams():
                 outputs = model.transcribe(audio_segments, timestamps=True)
             return outputs
 
+        def enforce_word_timing_constraints(
+            words: list[WordTiming],
+        ) -> list[WordTiming]:
+            """Enforce that word end times don't exceed the start time of the next word.
+
+            Due to silence padding added in batch_segment_to_audio_segment for better
+            transcription accuracy, word timings from different segments may overlap.
+            This function ensures there are no overlaps by adjusting end times.
+ """ + if len(words) <= 1: + return words + + enforced_words = [] + for i, word in enumerate(words): + enforced_word = word.copy() + + if i < len(words) - 1: + next_start = words[i + 1]["start"] + if enforced_word["end"] > next_start: + enforced_word["end"] = next_start + + enforced_words.append(enforced_word) + + return enforced_words + def emit_results( - results, - segments_info, - ): + results: list, + segments_info: list[AudioSegment], + ) -> Generator[TranscriptResult, None, None]: """Yield transcribed text and word timings from model output, adjusting timestamps to absolute positions.""" - for i, (output, (start_time, end_time, _)) in enumerate( - zip(results, segments_info) - ): + for i, (output, segment) in enumerate(zip(results, segments_info)): + start_time, end_time = segment.start, segment.end text = output.text.strip() - words = [ - { - "word": word_info["word"] + " ", - "start": round( + words: list[WordTiming] = [ + WordTiming( + word=word_info["word"] + " ", + start=round( word_info["start"] + start_time + timestamp_offset, 2 ), - "end": round( - word_info["end"] + start_time + timestamp_offset, 2 - ), - } + end=round(word_info["end"] + start_time + timestamp_offset, 2), + ) for word_info in output.timestamp["word"] ] - yield text, words + yield TranscriptResult(text, words) upload_volume.reload() @@ -393,10 +453,9 @@ class TranscriberParakeetFile: audio_array = load_and_convert_audio(file_path) total_duration = len(audio_array) / float(SAMPLERATE) - processed_duration = 0.0 - all_text_parts = [] - all_words = [] + all_text_parts: list[str] = [] + all_words: list[WordTiming] = [] raw_segments = vad_segment_generator(audio_array) speech_segments = batch_speech_segments( @@ -406,19 +465,19 @@ class TranscriberParakeetFile: audio_segments = batch_segment_to_audio_segment(speech_segments, audio_array) for batch in audio_segments: - _, _, audio_segment = batch + audio_segment = batch.audio results = transcribe_batch(self.model, [audio_segment]) - for text, words in emit_results( + for result in emit_results( results, [batch], ): - if not text: + if not result.text: continue - all_text_parts.append(text) - all_words.extend(words) + all_text_parts.append(result.text) + all_words.extend(result.words) - processed_duration += len(audio_segment) / float(SAMPLERATE) + all_words = enforce_word_timing_constraints(all_words) combined_text = " ".join(all_text_parts) return {"text": combined_text, "words": all_words}