feat: pipeline improvement with file processing, parakeet, silero-vad (#540)

* feat: improve pipeline threading, and transcriber (parakeet and silero vad)

* refactor: remove whisperx, implement parakeet

* refactor: make audio_chunker more smart and wait for speech, instead of fixed frame

* refactor: make audio merge to always downscale the audio to 16k for transcription

* refactor: make the audio transcript modal accepting batches

* refactor: improve type safety and remove prometheus metrics

- Add DiarizationSegment TypedDict for proper diarization typing
- Replace List/Optional with modern Python list/| None syntax
- Remove all Prometheus metrics from TranscriptDiarizationAssemblerProcessor
- Add comprehensive file processing pipeline with parallel execution
- Update processor imports and type annotations throughout
- Implement optimized file pipeline as default in process.py tool

* refactor: convert FileDiarizationProcessor I/O types to BaseModel

Update FileDiarizationInput and FileDiarizationOutput to inherit from
BaseModel instead of plain classes, following the standard pattern
used by other processors in the codebase.

* test: add tests for file transcript and diarization with pytest-recording

* build: add pytest-recording

* feat: add local pyannote for testing

* fix: replace PyAV AudioResampler with torchaudio for reliable audio processing

- Replace problematic PyAV AudioResampler that was causing ValueError: [Errno 22] Invalid argument
- Use torchaudio.functional.resample for robust sample rate conversion
- Optimize processing: skip conversion for already 16kHz mono audio
- Add direct WAV writing with Python wave module for better performance
- Consolidate duplicate downsample checks for cleaner code
- Maintain list[av.AudioFrame] input interface
- Required for Silero VAD which needs 16kHz mono audio

* fix: replace PyAV AudioResampler with torchaudio solution

- Resolves ValueError: [Errno 22] Invalid argument in AudioMergeProcessor
- Replaces problematic PyAV AudioResampler with torchaudio.functional.resample
- Optimizes processing to skip unnecessary conversions when audio is already 16kHz mono
- Uses direct WAV writing with Python's wave module for better performance
- Fixes test_basic_process to disable diarization (pyannote dependency not installed)
- Updates test expectations to match actual processor behavior
- Removes unused pydub dependency from pyproject.toml
- Adds comprehensive TEST_ANALYSIS.md documenting test suite status

* feat: add parameterized test for both diarization modes

- Adds @pytest.mark.parametrize to test_basic_process with enable_diarization=[False, True]
- Test with diarization=False always passes (tests core AudioMergeProcessor functionality)
- Test with diarization=True gracefully skips when pyannote.audio is not installed
- Provides comprehensive test coverage for both pipeline configurations

* fix: resolve pipeline property naming conflict in AudioDiarizationPyannoteProcessor

- Renames 'pipeline' property to 'diarization_pipeline' to avoid conflict with base Processor.pipeline attribute
- Fixes AttributeError: 'property 'pipeline' object has no setter' when set_pipeline() is called
- Updates property usage in _diarize method to use new name
- Now correctly supports pipeline initialization for diarization processing

* fix: add local for pyannote

* test: add diarization test

* fix: resample on audio merge now working

* fix: correctly restore timestamp

* fix: display exception in a threaded processor if that happen

* Update pyproject.toml

* ci: remove option

* ci: update astral-sh/setup-uv

* test: add monadical url for pytest-recording

* refactor: remove previous version

* build: move faster whisper to local dep

* test: fix missing import

* refactor: improve main_file_pipeline organization and error handling

- Move all imports to the top of the file
- Create unified EmptyPipeline class to replace duplicate mock pipeline code
- Remove timeout and fallback logic - let processors handle their own retries
- Fix error handling to raise any exception from parallel tasks
- Add proper type hints and validation for captured results

* fix: wrong function

* fix: remove task_done

* feat: add configurable file processing timeouts for modal processors

- Add TRANSCRIPT_FILE_TIMEOUT setting (default: 600s) for file transcription
- Add DIARIZATION_FILE_TIMEOUT setting (default: 600s) for file diarization
- Replace hardcoded timeout=600 with configurable settings in modal processors
- Allows customization of timeout values via environment variables

* fix: use logger

* fix: worker process meetings now use file pipeline

* fix: topic not gathered

* refactor: remove prepare(), pipeline now work

* refactor: implement many review from Igor

* test: add test for test_pipeline_main_file

* refactor: remove doc

* doc: add doc

* ci: update build to use native arm64 builder

* fix: merge fixes

* refactor: changes from Igor review + add test (not by default) to test gpu modal part

* ci: update to our own runner linux-amd64

* ci: try using suggested mode=min

* fix: update diarizer for latest modal, and use volume

* fix: modal file extension detection

* fix: put the diarizer as A100
This commit is contained in:
2025-08-20 20:07:19 -06:00
committed by GitHub
parent 009590c080
commit 3ea7f6b7b6
37 changed files with 5086 additions and 198 deletions

View File

@@ -11,6 +11,13 @@ from .base import ( # noqa: F401
Processor,
ThreadedProcessor,
)
from .file_diarization import FileDiarizationProcessor # noqa: F401
from .file_diarization_auto import FileDiarizationAutoProcessor # noqa: F401
from .file_transcript import FileTranscriptProcessor # noqa: F401
from .file_transcript_auto import FileTranscriptAutoProcessor # noqa: F401
from .transcript_diarization_assembler import (
TranscriptDiarizationAssemblerProcessor, # noqa: F401
)
from .transcript_final_summary import TranscriptFinalSummaryProcessor # noqa: F401
from .transcript_final_title import TranscriptFinalTitleProcessor # noqa: F401
from .transcript_liner import TranscriptLinerProcessor # noqa: F401

View File

@@ -1,28 +1,340 @@
from typing import Optional
import av
import numpy as np
import torch
from silero_vad import VADIterator, load_silero_vad
from reflector.processors.base import Processor
class AudioChunkerProcessor(Processor):
"""
Assemble audio frames into chunks
Assemble audio frames into chunks with VAD-based speech detection
"""
INPUT_TYPE = av.AudioFrame
OUTPUT_TYPE = list[av.AudioFrame]
def __init__(self, max_frames=256):
def __init__(
self,
block_frames=256,
max_frames=1024,
vad_threshold=0.5,
use_onnx=False,
min_frames=2,
):
super().__init__()
self.frames: list[av.AudioFrame] = []
self.block_frames = block_frames
self.max_frames = max_frames
self.vad_threshold = vad_threshold
self.min_frames = min_frames
# Initialize Silero VAD
self._init_vad(use_onnx)
def _init_vad(self, use_onnx=False):
"""Initialize Silero VAD model"""
try:
torch.set_num_threads(1)
self.vad_model = load_silero_vad(onnx=use_onnx)
self.vad_iterator = VADIterator(self.vad_model, sampling_rate=16000)
self.logger.info("Silero VAD initialized successfully")
except Exception as e:
self.logger.error(f"Failed to initialize Silero VAD: {e}")
self.vad_model = None
self.vad_iterator = None
async def _push(self, data: av.AudioFrame):
self.frames.append(data)
if len(self.frames) >= self.max_frames:
await self.flush()
# print("timestamp", data.pts * data.time_base * 1000)
# Check for speech segments every 32 frames (~1 second)
if len(self.frames) >= 32 and len(self.frames) % 32 == 0:
await self._process_block()
# Safety fallback - emit if we hit max frames
elif len(self.frames) >= self.max_frames:
self.logger.warning(
f"AudioChunkerProcessor: Reached max frames ({self.max_frames}), "
f"emitting first {self.max_frames // 2} frames"
)
frames_to_emit = self.frames[: self.max_frames // 2]
self.frames = self.frames[self.max_frames // 2 :]
if len(frames_to_emit) >= self.min_frames:
await self.emit(frames_to_emit)
else:
self.logger.debug(
f"Ignoring fallback segment with {len(frames_to_emit)} frames "
f"(< {self.min_frames} minimum)"
)
async def _process_block(self):
# Need at least 32 frames for VAD detection (~1 second)
if len(self.frames) < 32 or self.vad_iterator is None:
return
# Processing block with current buffer size
# print(f"Processing block: {len(self.frames)} frames in buffer")
try:
# Convert frames to numpy array for VAD
audio_array = self._frames_to_numpy(self.frames)
if audio_array is None:
# Fallback: emit all frames if conversion failed
frames_to_emit = self.frames[:]
self.frames = []
if len(frames_to_emit) >= self.min_frames:
await self.emit(frames_to_emit)
else:
self.logger.debug(
f"Ignoring conversion-failed segment with {len(frames_to_emit)} frames "
f"(< {self.min_frames} minimum)"
)
return
# Find complete speech segments in the buffer
speech_end_frame = self._find_speech_segment_end(audio_array)
if speech_end_frame is None or speech_end_frame <= 0:
# No speech found but buffer is getting large
if len(self.frames) > 512:
# Check if it's all silence and can be discarded
# No speech segment found, buffer at {len(self.frames)} frames
# Could emit silence or discard old frames here
# For now, keep first 256 frames and discard older silence
if len(self.frames) > 768:
self.logger.debug(
f"Discarding {len(self.frames) - 256} old frames (likely silence)"
)
self.frames = self.frames[-256:]
return
# Calculate segment timing information
frames_to_emit = self.frames[:speech_end_frame]
# Get timing from av.AudioFrame
if frames_to_emit:
first_frame = frames_to_emit[0]
last_frame = frames_to_emit[-1]
sample_rate = first_frame.sample_rate
# Calculate duration
total_samples = sum(f.samples for f in frames_to_emit)
duration_seconds = total_samples / sample_rate if sample_rate > 0 else 0
# Get timestamps if available
start_time = (
first_frame.pts * first_frame.time_base if first_frame.pts else 0
)
end_time = (
last_frame.pts * last_frame.time_base if last_frame.pts else 0
)
# Convert to HH:MM:SS format for logging
def format_time(seconds):
if not seconds:
return "00:00:00"
total_seconds = int(float(seconds))
hours = total_seconds // 3600
minutes = (total_seconds % 3600) // 60
secs = total_seconds % 60
return f"{hours:02d}:{minutes:02d}:{secs:02d}"
start_formatted = format_time(start_time)
end_formatted = format_time(end_time)
# Keep remaining frames for next processing
remaining_after = len(self.frames) - speech_end_frame
# Single structured log line
self.logger.info(
"Speech segment found",
start=start_formatted,
end=end_formatted,
frames=speech_end_frame,
duration=round(duration_seconds, 2),
buffer_before=len(self.frames),
remaining=remaining_after,
)
# Keep remaining frames for next processing
self.frames = self.frames[speech_end_frame:]
# Filter out segments with too few frames
if len(frames_to_emit) >= self.min_frames:
await self.emit(frames_to_emit)
else:
self.logger.debug(
f"Ignoring segment with {len(frames_to_emit)} frames "
f"(< {self.min_frames} minimum)"
)
except Exception as e:
self.logger.error(f"Error in VAD processing: {e}")
# Fallback to simple chunking
if len(self.frames) >= self.block_frames:
frames_to_emit = self.frames[: self.block_frames]
self.frames = self.frames[self.block_frames :]
if len(frames_to_emit) >= self.min_frames:
await self.emit(frames_to_emit)
else:
self.logger.debug(
f"Ignoring exception-fallback segment with {len(frames_to_emit)} frames "
f"(< {self.min_frames} minimum)"
)
def _frames_to_numpy(self, frames: list[av.AudioFrame]) -> Optional[np.ndarray]:
"""Convert av.AudioFrame list to numpy array for VAD processing"""
if not frames:
return None
try:
first_frame = frames[0]
original_sample_rate = first_frame.sample_rate
audio_data = []
for frame in frames:
frame_array = frame.to_ndarray()
# Handle stereo -> mono conversion
if len(frame_array.shape) == 2 and frame_array.shape[0] > 1:
frame_array = np.mean(frame_array, axis=0)
elif len(frame_array.shape) == 2:
frame_array = frame_array.flatten()
audio_data.append(frame_array)
if not audio_data:
return None
combined_audio = np.concatenate(audio_data)
# Resample from 48kHz to 16kHz if needed
if original_sample_rate != 16000:
combined_audio = self._resample_audio(
combined_audio, original_sample_rate, 16000
)
# Ensure float32 format
if combined_audio.dtype == np.int16:
# Normalize int16 audio to float32 in range [-1.0, 1.0]
combined_audio = combined_audio.astype(np.float32) / 32768.0
elif combined_audio.dtype != np.float32:
combined_audio = combined_audio.astype(np.float32)
return combined_audio
except Exception as e:
self.logger.error(f"Error converting frames to numpy: {e}")
return None
def _resample_audio(
self, audio: np.ndarray, from_sr: int, to_sr: int
) -> np.ndarray:
"""Simple linear resampling from from_sr to to_sr"""
if from_sr == to_sr:
return audio
try:
# Simple linear interpolation resampling
ratio = to_sr / from_sr
new_length = int(len(audio) * ratio)
# Create indices for interpolation
old_indices = np.linspace(0, len(audio) - 1, new_length)
resampled = np.interp(old_indices, np.arange(len(audio)), audio)
return resampled.astype(np.float32)
except Exception as e:
self.logger.error("Resampling error", exc_info=e)
# Fallback: simple decimation/repetition
if from_sr > to_sr:
# Downsample by taking every nth sample
step = from_sr // to_sr
return audio[::step]
else:
# Upsample by repeating samples
repeat = to_sr // from_sr
return np.repeat(audio, repeat)
def _find_speech_segment_end(self, audio_array: np.ndarray) -> Optional[int]:
"""Find complete speech segments and return frame index at segment end"""
if self.vad_iterator is None or len(audio_array) == 0:
return None
try:
# Process audio in 512-sample windows for VAD
window_size = 512
min_silence_windows = 3 # Require 3 windows of silence after speech
# Track speech state
in_speech = False
speech_start = None
speech_end = None
silence_count = 0
for i in range(0, len(audio_array), window_size):
chunk = audio_array[i : i + window_size]
if len(chunk) < window_size:
chunk = np.pad(chunk, (0, window_size - len(chunk)))
# Detect if this window has speech
speech_dict = self.vad_iterator(chunk, return_seconds=True)
# VADIterator returns dict with 'start' and 'end' when speech segments are detected
if speech_dict:
if not in_speech:
# Speech started
speech_start = i
in_speech = True
# Debug: print(f"Speech START at sample {i}, VAD: {speech_dict}")
silence_count = 0 # Reset silence counter
continue
if not in_speech:
continue
# We're in speech but found silence
silence_count += 1
if silence_count < min_silence_windows:
continue
# Found end of speech segment
speech_end = i - (min_silence_windows - 1) * window_size
# Debug: print(f"Speech END at sample {speech_end}")
# Convert sample position to frame index
samples_per_frame = self.frames[0].samples if self.frames else 1024
# Account for resampling: we process at 16kHz but frames might be 48kHz
resample_ratio = 48000 / 16000 # 3x
actual_sample_pos = int(speech_end * resample_ratio)
frame_index = actual_sample_pos // samples_per_frame
# Ensure we don't exceed buffer
frame_index = min(frame_index, len(self.frames))
return frame_index
return None
except Exception as e:
self.logger.error(f"Error finding speech segment: {e}")
return None
async def _flush(self):
frames = self.frames[:]
self.frames = []
if frames:
await self.emit(frames)
if len(frames) >= self.min_frames:
await self.emit(frames)
else:
self.logger.debug(
f"Ignoring flush segment with {len(frames)} frames "
f"(< {self.min_frames} minimum)"
)

View File

@@ -1,6 +1,7 @@
from reflector.processors.base import Processor
from reflector.processors.types import (
AudioDiarizationInput,
DiarizationSegment,
TitleSummary,
Word,
)
@@ -38,7 +39,7 @@ class AudioDiarizationProcessor(Processor):
raise NotImplementedError
@classmethod
def assign_speaker(cls, words: list[Word], diarization: list[dict]):
def assign_speaker(cls, words: list[Word], diarization: list[DiarizationSegment]):
cls._diarization_remove_overlap(diarization)
cls._diarization_remove_segment_without_words(words, diarization)
cls._diarization_merge_same_speaker(diarization)
@@ -65,7 +66,7 @@ class AudioDiarizationProcessor(Processor):
return True
@staticmethod
def _diarization_remove_overlap(diarization: list[dict]):
def _diarization_remove_overlap(diarization: list[DiarizationSegment]):
"""
Remove overlap in diarization results
@@ -92,7 +93,7 @@ class AudioDiarizationProcessor(Processor):
@staticmethod
def _diarization_remove_segment_without_words(
words: list[Word], diarization: list[dict]
words: list[Word], diarization: list[DiarizationSegment]
):
"""
Remove diarization segments without words
@@ -122,7 +123,7 @@ class AudioDiarizationProcessor(Processor):
diarization_idx += 1
@staticmethod
def _diarization_merge_same_speaker(diarization: list[dict]):
def _diarization_merge_same_speaker(diarization: list[DiarizationSegment]):
"""
Merge diarization contigous segments with the same speaker
@@ -140,7 +141,9 @@ class AudioDiarizationProcessor(Processor):
diarization_idx += 1
@classmethod
def _diarization_assign_speaker(cls, words: list[Word], diarization: list[dict]):
def _diarization_assign_speaker(
cls, words: list[Word], diarization: list[DiarizationSegment]
):
"""
Assign speaker to words based on diarization
@@ -148,7 +151,7 @@ class AudioDiarizationProcessor(Processor):
"""
word_idx = 0
last_speaker = None
last_speaker = 0
for d in diarization:
start = d["start"]
end = d["end"]

View File

@@ -0,0 +1,74 @@
import os
import torch
import torchaudio
from pyannote.audio import Pipeline
from reflector.processors.audio_diarization import AudioDiarizationProcessor
from reflector.processors.audio_diarization_auto import AudioDiarizationAutoProcessor
from reflector.processors.types import AudioDiarizationInput, DiarizationSegment
class AudioDiarizationPyannoteProcessor(AudioDiarizationProcessor):
"""Local diarization processor using pyannote.audio library"""
def __init__(
self,
model_name: str = "pyannote/speaker-diarization-3.1",
pyannote_auth_token: str | None = None,
device: str | None = None,
**kwargs,
):
super().__init__(**kwargs)
self.model_name = model_name
self.auth_token = pyannote_auth_token or os.environ.get("HF_TOKEN")
self.device = device
if device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.logger.info(f"Loading pyannote diarization model: {self.model_name}")
self.diarization_pipeline = Pipeline.from_pretrained(
self.model_name, use_auth_token=self.auth_token
)
self.diarization_pipeline.to(torch.device(self.device))
self.logger.info(f"Diarization model loaded on device: {self.device}")
async def _diarize(self, data: AudioDiarizationInput) -> list[DiarizationSegment]:
try:
# Load audio file (audio_url is assumed to be a local file path)
self.logger.info(f"Loading local audio file: {data.audio_url}")
waveform, sample_rate = torchaudio.load(data.audio_url)
audio_input = {"waveform": waveform, "sample_rate": sample_rate}
self.logger.info("Running speaker diarization")
diarization = self.diarization_pipeline(audio_input)
# Convert pyannote diarization output to our format
segments = []
for segment, _, speaker in diarization.itertracks(yield_label=True):
# Extract speaker number from label (e.g., "SPEAKER_00" -> 0)
speaker_id = 0
if speaker.startswith("SPEAKER_"):
try:
speaker_id = int(speaker.split("_")[-1])
except (ValueError, IndexError):
# Fallback to hash-based ID if parsing fails
speaker_id = hash(speaker) % 1000
segments.append(
{
"start": round(segment.start, 3),
"end": round(segment.end, 3),
"speaker": speaker_id,
}
)
self.logger.info(f"Diarization completed with {len(segments)} segments")
return segments
except Exception as e:
self.logger.exception(f"Diarization failed: {e}")
raise
AudioDiarizationAutoProcessor.register("pyannote", AudioDiarizationPyannoteProcessor)

View File

@@ -3,11 +3,24 @@ from time import monotonic_ns
from uuid import uuid4
import av
from av.audio.resampler import AudioResampler
from reflector.processors.base import Processor
from reflector.processors.types import AudioFile
def copy_frame(frame: av.AudioFrame) -> av.AudioFrame:
frame_copy = frame.from_ndarray(
frame.to_ndarray(),
format=frame.format.name,
layout=frame.layout.name,
)
frame_copy.sample_rate = frame.sample_rate
frame_copy.pts = frame.pts
frame_copy.time_base = frame.time_base
return frame_copy
class AudioMergeProcessor(Processor):
"""
Merge audio frame into a single file
@@ -16,37 +29,92 @@ class AudioMergeProcessor(Processor):
INPUT_TYPE = list[av.AudioFrame]
OUTPUT_TYPE = AudioFile
def __init__(self, downsample_to_16k_mono: bool = True, **kwargs):
super().__init__(**kwargs)
self.downsample_to_16k_mono = downsample_to_16k_mono
async def _push(self, data: list[av.AudioFrame]):
if not data:
return
# get audio information from first frame
frame = data[0]
channels = len(frame.layout.channels)
sample_rate = frame.sample_rate
sample_width = frame.format.bytes
original_channels = len(frame.layout.channels)
original_sample_rate = frame.sample_rate
original_sample_width = frame.format.bytes
# determine if we need processing
needs_processing = self.downsample_to_16k_mono and (
original_sample_rate != 16000 or original_channels != 1
)
# determine output parameters
if self.downsample_to_16k_mono:
output_sample_rate = 16000
output_channels = 1
output_sample_width = 2 # 16-bit = 2 bytes
else:
output_sample_rate = original_sample_rate
output_channels = original_channels
output_sample_width = original_sample_width
# create audio file
uu = uuid4().hex
fd = io.BytesIO()
out_container = av.open(fd, "w", format="wav")
out_stream = out_container.add_stream("pcm_s16le", rate=sample_rate)
for frame in data:
for packet in out_stream.encode(frame):
if needs_processing:
# Process with PyAV resampler
out_container = av.open(fd, "w", format="wav")
out_stream = out_container.add_stream("pcm_s16le", rate=16000)
out_stream.layout = "mono"
# Create resampler if needed
resampler = None
if original_sample_rate != 16000 or original_channels != 1:
resampler = AudioResampler(format="s16", layout="mono", rate=16000)
for frame in data:
if resampler:
# Resample and convert to mono
# XXX for an unknown reason, if we don't use a copy of the frame, we get
# Invalid Argumment from resample. Debugging indicate that when a previous processor
# already used the frame (like AudioFileWriter), it make it invalid argument here.
resampled_frames = resampler.resample(copy_frame(frame))
for resampled_frame in resampled_frames:
for packet in out_stream.encode(resampled_frame):
out_container.mux(packet)
else:
# Direct encoding without resampling
for packet in out_stream.encode(frame):
out_container.mux(packet)
# Flush the encoder
for packet in out_stream.encode(None):
out_container.mux(packet)
for packet in out_stream.encode(None):
out_container.mux(packet)
out_container.close()
out_container.close()
else:
# Use PyAV for original frames (no processing needed)
out_container = av.open(fd, "w", format="wav")
out_stream = out_container.add_stream("pcm_s16le", rate=output_sample_rate)
out_stream.layout = "mono" if output_channels == 1 else frame.layout
for frame in data:
for packet in out_stream.encode(frame):
out_container.mux(packet)
for packet in out_stream.encode(None):
out_container.mux(packet)
out_container.close()
fd.seek(0)
# emit audio file
audiofile = AudioFile(
name=f"{monotonic_ns()}-{uu}.wav",
fd=fd,
sample_rate=sample_rate,
channels=channels,
sample_width=sample_width,
sample_rate=output_sample_rate,
channels=output_channels,
sample_width=output_sample_width,
timestamp=data[0].pts * data[0].time_base,
)

View File

@@ -12,6 +12,9 @@ API will be a POST request to TRANSCRIPT_URL:
"""
from typing import List
import aiohttp
from openai import AsyncOpenAI
from reflector.processors.audio_transcript import AudioTranscriptProcessor
@@ -21,7 +24,9 @@ from reflector.settings import settings
class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
def __init__(self, modal_api_key: str | None = None, **kwargs):
def __init__(
self, modal_api_key: str | None = None, batch_enabled: bool = True, **kwargs
):
super().__init__()
if not settings.TRANSCRIPT_URL:
raise Exception(
@@ -30,6 +35,126 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
self.transcript_url = settings.TRANSCRIPT_URL + "/v1"
self.timeout = settings.TRANSCRIPT_TIMEOUT
self.modal_api_key = modal_api_key
self.max_batch_duration = 10.0
self.max_batch_files = 15
self.batch_enabled = batch_enabled
self.pending_files: List[AudioFile] = [] # Files waiting to be processed
@classmethod
def _calculate_duration(cls, audio_file: AudioFile) -> float:
"""Calculate audio duration in seconds from AudioFile metadata"""
# Duration = total_samples / sample_rate
# We need to estimate total samples from the file data
import wave
try:
# Try to read as WAV file to get duration
audio_file.fd.seek(0)
with wave.open(audio_file.fd, "rb") as wav_file:
frames = wav_file.getnframes()
sample_rate = wav_file.getframerate()
duration = frames / sample_rate
return duration
except Exception:
# Fallback: estimate from file size and audio parameters
audio_file.fd.seek(0, 2) # Seek to end
file_size = audio_file.fd.tell()
audio_file.fd.seek(0) # Reset to beginning
# Estimate: file_size / (sample_rate * channels * sample_width)
bytes_per_second = (
audio_file.sample_rate
* audio_file.channels
* (audio_file.sample_width // 8)
)
estimated_duration = (
file_size / bytes_per_second if bytes_per_second > 0 else 0
)
return max(0, estimated_duration)
def _create_batches(self, audio_files: List[AudioFile]) -> List[List[AudioFile]]:
"""Group audio files into batches with maximum 30s total duration"""
batches = []
current_batch = []
current_duration = 0.0
for audio_file in audio_files:
duration = self._calculate_duration(audio_file)
# If adding this file exceeds max duration, start a new batch
if current_duration + duration > self.max_batch_duration and current_batch:
batches.append(current_batch)
current_batch = [audio_file]
current_duration = duration
else:
current_batch.append(audio_file)
current_duration += duration
# Add the last batch if not empty
if current_batch:
batches.append(current_batch)
return batches
async def _transcript_batch(self, audio_files: List[AudioFile]) -> List[Transcript]:
"""Transcribe a batch of audio files using the parakeet backend"""
if not audio_files:
return []
self.logger.debug(f"Batch transcribing {len(audio_files)} files")
# Prepare form data for batch request
data = aiohttp.FormData()
data.add_field("language", self.get_pref("audio:source_language", "en"))
data.add_field("batch", "true")
for i, audio_file in enumerate(audio_files):
audio_file.fd.seek(0)
data.add_field(
"files",
audio_file.fd,
filename=f"{audio_file.name}",
content_type="audio/wav",
)
# Make batch request
headers = {"Authorization": f"Bearer {self.modal_api_key}"}
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=self.timeout)
) as session:
async with session.post(
f"{self.transcript_url}/audio/transcriptions",
data=data,
headers=headers,
) as response:
if response.status != 200:
error_text = await response.text()
raise Exception(
f"Batch transcription failed: {response.status} {error_text}"
)
result = await response.json()
# Process batch results
transcripts = []
results = result.get("results", [])
for i, (audio_file, file_result) in enumerate(zip(audio_files, results)):
transcript = Transcript(
words=[
Word(
text=word_info["word"],
start=word_info["start"],
end=word_info["end"],
)
for word_info in file_result.get("words", [])
]
)
transcript.add_offset(audio_file.timestamp)
transcripts.append(transcript)
return transcripts
async def _transcript(self, data: AudioFile):
async with AsyncOpenAI(
@@ -62,5 +187,96 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
return transcript
async def transcript_multiple(
self, audio_files: List[AudioFile]
) -> List[Transcript]:
"""Transcribe multiple audio files using batching"""
if len(audio_files) == 1:
# Single file, use existing method
return [await self._transcript(audio_files[0])]
# Create batches with max 30s duration each
batches = self._create_batches(audio_files)
self.logger.debug(
f"Processing {len(audio_files)} files in {len(batches)} batches"
)
# Process all batches concurrently
all_transcripts = []
for batch in batches:
batch_transcripts = await self._transcript_batch(batch)
all_transcripts.extend(batch_transcripts)
return all_transcripts
async def _push(self, data: AudioFile):
"""Override _push to support batching"""
if not self.batch_enabled:
# Use parent implementation for single file processing
return await super()._push(data)
# Add file to pending batch
self.pending_files.append(data)
self.logger.debug(
f"Added file to batch: {data.name}, batch size: {len(self.pending_files)}"
)
# Calculate total duration of pending files
total_duration = sum(self._calculate_duration(f) for f in self.pending_files)
# Process batch if it reaches max duration or has multiple files ready for optimization
should_process_batch = (
total_duration >= self.max_batch_duration
or len(self.pending_files) >= self.max_batch_files
)
if should_process_batch:
await self._process_pending_batch()
async def _process_pending_batch(self):
"""Process all pending files as batches"""
if not self.pending_files:
return
self.logger.debug(f"Processing batch of {len(self.pending_files)} files")
try:
# Create batches respecting duration limit
batches = self._create_batches(self.pending_files)
# Process each batch
for batch in batches:
self.m_transcript_call.inc()
try:
with self.m_transcript.time():
# Use batch transcription
transcripts = await self._transcript_batch(batch)
self.m_transcript_success.inc()
# Emit each transcript
for transcript in transcripts:
if transcript:
await self.emit(transcript)
except Exception:
self.m_transcript_failure.inc()
raise
finally:
# Release audio files
for audio_file in batch:
audio_file.release()
finally:
# Clear pending files
self.pending_files.clear()
async def _flush(self):
"""Process any remaining files when flushing"""
await self._process_pending_batch()
await super()._flush()
AudioTranscriptAutoProcessor.register("modal", AudioTranscriptModalProcessor)

View File

@@ -241,33 +241,45 @@ class ThreadedProcessor(Processor):
self.INPUT_TYPE = processor.INPUT_TYPE
self.OUTPUT_TYPE = processor.OUTPUT_TYPE
self.executor = ThreadPoolExecutor(max_workers=max_workers)
self.queue = asyncio.Queue()
self.task = asyncio.get_running_loop().create_task(self.loop())
self.queue = asyncio.Queue(maxsize=50)
self.task: asyncio.Task | None = None
def set_pipeline(self, pipeline: "Pipeline"):
super().set_pipeline(pipeline)
self.processor.set_pipeline(pipeline)
async def loop(self):
while True:
data = await self.queue.get()
self.m_processor_queue.set(self.queue.qsize())
with self.m_processor_queue_in_progress.track_inprogress():
try:
if data is None:
await self.processor.flush()
break
try:
while True:
data = await self.queue.get()
self.m_processor_queue.set(self.queue.qsize())
with self.m_processor_queue_in_progress.track_inprogress():
try:
await self.processor.push(data)
except Exception:
self.logger.error(
f"Error in push {self.processor.__class__.__name__}"
", continue"
)
finally:
self.queue.task_done()
if data is None:
await self.processor.flush()
break
try:
await self.processor.push(data)
except Exception:
self.logger.error(
f"Error in push {self.processor.__class__.__name__}"
", continue"
)
finally:
self.queue.task_done()
except Exception as e:
logger.error(f"Crash in {self.__class__.__name__}: {e}", exc_info=e)
async def _ensure_task(self):
if self.task is None:
self.task = asyncio.get_running_loop().create_task(self.loop())
# XXX not doing a sleep here make the whole pipeline prior the thread
# to be running without having a chance to work on the task here.
await asyncio.sleep(0)
async def _push(self, data):
await self._ensure_task()
await self.queue.put(data)
async def _flush(self):

View File

@@ -0,0 +1,33 @@
from pydantic import BaseModel
from reflector.processors.base import Processor
from reflector.processors.types import DiarizationSegment
class FileDiarizationInput(BaseModel):
"""Input for file diarization containing audio URL"""
audio_url: str
class FileDiarizationOutput(BaseModel):
"""Output for file diarization containing speaker segments"""
diarization: list[DiarizationSegment]
class FileDiarizationProcessor(Processor):
"""
Diarize complete audio files from URL
"""
INPUT_TYPE = FileDiarizationInput
OUTPUT_TYPE = FileDiarizationOutput
async def _push(self, data: FileDiarizationInput):
result = await self._diarize(data)
if result:
await self.emit(result)
async def _diarize(self, data: FileDiarizationInput):
raise NotImplementedError

View File

@@ -0,0 +1,33 @@
import importlib
from reflector.processors.file_diarization import FileDiarizationProcessor
from reflector.settings import settings
class FileDiarizationAutoProcessor(FileDiarizationProcessor):
_registry = {}
@classmethod
def register(cls, name, kclass):
cls._registry[name] = kclass
def __new__(cls, name: str | None = None, **kwargs):
if name is None:
name = settings.DIARIZATION_BACKEND
if name not in cls._registry:
module_name = f"reflector.processors.file_diarization_{name}"
importlib.import_module(module_name)
# gather specific configuration for the processor
# search `DIARIZATION_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
config = {}
name_upper = name.upper()
settings_prefix = "DIARIZATION_"
config_prefix = f"{settings_prefix}{name_upper}_"
for key, value in settings:
if key.startswith(config_prefix):
config_name = key[len(settings_prefix) :].lower()
config[config_name] = value
return cls._registry[name](**config | kwargs)

View File

@@ -0,0 +1,57 @@
"""
File diarization implementation using the GPU service from modal.com
API will be a POST request to DIARIZATION_URL:
```
POST /diarize?audio_file_url=...&timestamp=0
Authorization: Bearer <modal_api_key>
```
"""
import httpx
from reflector.processors.file_diarization import (
FileDiarizationInput,
FileDiarizationOutput,
FileDiarizationProcessor,
)
from reflector.processors.file_diarization_auto import FileDiarizationAutoProcessor
from reflector.settings import settings
class FileDiarizationModalProcessor(FileDiarizationProcessor):
def __init__(self, modal_api_key: str | None = None, **kwargs):
super().__init__(**kwargs)
if not settings.DIARIZATION_URL:
raise Exception(
"DIARIZATION_URL required to use FileDiarizationModalProcessor"
)
self.diarization_url = settings.DIARIZATION_URL + "/diarize"
self.file_timeout = settings.DIARIZATION_FILE_TIMEOUT
self.modal_api_key = modal_api_key
async def _diarize(self, data: FileDiarizationInput):
"""Get speaker diarization for file"""
self.logger.info(f"Starting diarization from {data.audio_url}")
headers = {}
if self.modal_api_key:
headers["Authorization"] = f"Bearer {self.modal_api_key}"
async with httpx.AsyncClient(timeout=self.file_timeout) as client:
response = await client.post(
self.diarization_url,
headers=headers,
params={
"audio_file_url": data.audio_url,
"timestamp": 0,
},
)
response.raise_for_status()
diarization_data = response.json()["diarization"]
return FileDiarizationOutput(diarization=diarization_data)
FileDiarizationAutoProcessor.register("modal", FileDiarizationModalProcessor)

View File

@@ -0,0 +1,65 @@
from prometheus_client import Counter, Histogram
from reflector.processors.base import Processor
from reflector.processors.types import Transcript
class FileTranscriptInput:
"""Input for file transcription containing audio URL and language settings"""
def __init__(self, audio_url: str, language: str = "en"):
self.audio_url = audio_url
self.language = language
class FileTranscriptProcessor(Processor):
"""
Transcript complete audio files from URL
"""
INPUT_TYPE = FileTranscriptInput
OUTPUT_TYPE = Transcript
m_transcript = Histogram(
"file_transcript",
"Time spent in FileTranscript.transcript",
["backend"],
)
m_transcript_call = Counter(
"file_transcript_call",
"Number of calls to FileTranscript.transcript",
["backend"],
)
m_transcript_success = Counter(
"file_transcript_success",
"Number of successful calls to FileTranscript.transcript",
["backend"],
)
m_transcript_failure = Counter(
"file_transcript_failure",
"Number of failed calls to FileTranscript.transcript",
["backend"],
)
def __init__(self, *args, **kwargs):
name = self.__class__.__name__
self.m_transcript = self.m_transcript.labels(name)
self.m_transcript_call = self.m_transcript_call.labels(name)
self.m_transcript_success = self.m_transcript_success.labels(name)
self.m_transcript_failure = self.m_transcript_failure.labels(name)
super().__init__(*args, **kwargs)
async def _push(self, data: FileTranscriptInput):
try:
self.m_transcript_call.inc()
with self.m_transcript.time():
result = await self._transcript(data)
self.m_transcript_success.inc()
if result:
await self.emit(result)
except Exception:
self.m_transcript_failure.inc()
raise
async def _transcript(self, data: FileTranscriptInput):
raise NotImplementedError

View File

@@ -0,0 +1,32 @@
import importlib
from reflector.processors.file_transcript import FileTranscriptProcessor
from reflector.settings import settings
class FileTranscriptAutoProcessor(FileTranscriptProcessor):
_registry = {}
@classmethod
def register(cls, name, kclass):
cls._registry[name] = kclass
def __new__(cls, name: str | None = None, **kwargs):
if name is None:
name = settings.TRANSCRIPT_BACKEND
if name not in cls._registry:
module_name = f"reflector.processors.file_transcript_{name}"
importlib.import_module(module_name)
# gather specific configuration for the processor
# search `TRANSCRIPT_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
config = {}
name_upper = name.upper()
settings_prefix = "TRANSCRIPT_"
config_prefix = f"{settings_prefix}{name_upper}_"
for key, value in settings:
if key.startswith(config_prefix):
config_name = key[len(settings_prefix) :].lower()
config[config_name] = value
return cls._registry[name](**config | kwargs)

View File

@@ -0,0 +1,74 @@
"""
File transcription implementation using the GPU service from modal.com
API will be a POST request to TRANSCRIPT_URL:
```json
{
"audio_file_url": "https://...",
"language": "en",
"model": "parakeet-tdt-0.6b-v2",
"batch": true
}
```
"""
import httpx
from reflector.processors.file_transcript import (
FileTranscriptInput,
FileTranscriptProcessor,
)
from reflector.processors.file_transcript_auto import FileTranscriptAutoProcessor
from reflector.processors.types import Transcript, Word
from reflector.settings import settings
class FileTranscriptModalProcessor(FileTranscriptProcessor):
def __init__(self, modal_api_key: str | None = None, **kwargs):
super().__init__(**kwargs)
if not settings.TRANSCRIPT_URL:
raise Exception(
"TRANSCRIPT_URL required to use FileTranscriptModalProcessor"
)
self.transcript_url = settings.TRANSCRIPT_URL
self.file_timeout = settings.TRANSCRIPT_FILE_TIMEOUT
self.modal_api_key = modal_api_key
async def _transcript(self, data: FileTranscriptInput):
"""Send full file to Modal for transcription"""
url = f"{self.transcript_url}/v1/audio/transcriptions-from-url"
self.logger.info(f"Starting file transcription from {data.audio_url}")
headers = {}
if self.modal_api_key:
headers["Authorization"] = f"Bearer {self.modal_api_key}"
async with httpx.AsyncClient(timeout=self.file_timeout) as client:
response = await client.post(
url,
headers=headers,
json={
"audio_file_url": data.audio_url,
"language": data.language,
"batch": True,
},
)
response.raise_for_status()
result = response.json()
words = [
Word(
text=word_info["word"],
start=word_info["start"],
end=word_info["end"],
)
for word_info in result.get("words", [])
]
return Transcript(words=words)
# Register with the auto processor
FileTranscriptAutoProcessor.register("modal", FileTranscriptModalProcessor)

View File

@@ -0,0 +1,45 @@
"""
Processor to assemble transcript with diarization results
"""
from reflector.processors.audio_diarization import AudioDiarizationProcessor
from reflector.processors.base import Processor
from reflector.processors.types import DiarizationSegment, Transcript
class TranscriptDiarizationAssemblerInput:
"""Input containing transcript and diarization data"""
def __init__(self, transcript: Transcript, diarization: list[DiarizationSegment]):
self.transcript = transcript
self.diarization = diarization
class TranscriptDiarizationAssemblerProcessor(Processor):
"""
Assemble transcript with diarization results by applying speaker assignments
"""
INPUT_TYPE = TranscriptDiarizationAssemblerInput
OUTPUT_TYPE = Transcript
async def _push(self, data: TranscriptDiarizationAssemblerInput):
result = await self._assemble(data)
if result:
await self.emit(result)
async def _assemble(self, data: TranscriptDiarizationAssemblerInput):
"""Apply diarization to transcript words"""
if not data.diarization:
self.logger.info(
"No diarization data provided, returning original transcript"
)
return data.transcript
# Reuse logic from AudioDiarizationProcessor
processor = AudioDiarizationProcessor()
words = data.transcript.words
processor.assign_speaker(words, data.diarization)
self.logger.info(f"Applied diarization to {len(words)} words")
return data.transcript

View File

@@ -2,13 +2,22 @@ import io
import re
import tempfile
from pathlib import Path
from typing import Annotated
from typing import Annotated, TypedDict
from profanityfilter import ProfanityFilter
from pydantic import BaseModel, Field, PrivateAttr
from reflector.redis_cache import redis_cache
class DiarizationSegment(TypedDict):
"""Type definition for diarization segment containing speaker information"""
start: float
end: float
speaker: int
PUNC_RE = re.compile(r"[.;:?!…]")
profanity_filter = ProfanityFilter()