mirror of https://github.com/Monadical-SAS/reflector.git (synced 2025-12-20 12:19:06 +00:00)
294 lines · 11 KiB · Python
from typing import Optional

import av
import numpy as np
import torch
from silero_vad import VADIterator, load_silero_vad

from reflector.processors.audio_chunker import AudioChunkerProcessor
from reflector.processors.audio_chunker_auto import AudioChunkerAutoProcessor


class AudioChunkerSileroProcessor(AudioChunkerProcessor):
    """
    Assemble audio frames into chunks with VAD-based speech detection using Silero VAD.

    Expects input audio to be already downscaled to 16kHz mono s16 format
    (handled by AudioDownscaleProcessor in the pipeline).
    """

    def __init__(
        self,
        block_frames=256,
        max_frames=1024,
        use_onnx=True,
        min_frames=2,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.block_frames = block_frames
        self.max_frames = max_frames
        self.min_frames = min_frames

        # Initialize Silero VAD
        self._init_vad(use_onnx)

    def _init_vad(self, use_onnx=False):
        """Initialize Silero VAD model for 16kHz audio"""
        try:
            torch.set_num_threads(1)
            self.vad_model = load_silero_vad(onnx=use_onnx)
            # VAD expects 16kHz audio (guaranteed by AudioDownscaleProcessor)
            self.vad_iterator = VADIterator(self.vad_model, sampling_rate=16000)
            self.logger.info("Silero VAD initialized for 16kHz audio")

        except Exception as e:
            self.logger.error(f"Failed to initialize Silero VAD: {e}")
            self.vad_model = None
            self.vad_iterator = None

    async def _chunk(self, data: av.AudioFrame) -> Optional[list[av.AudioFrame]]:
        """Process audio frame and return chunk when ready"""
        self.frames.append(data)

        # Check for speech segments every 32 frames (~1 second)
        if len(self.frames) >= 32 and len(self.frames) % 32 == 0:
            return await self._process_block()

        # Safety fallback - emit if we hit max frames
        elif len(self.frames) >= self.max_frames:
            self.logger.warning(
                f"AudioChunkerSileroProcessor: Reached max frames ({self.max_frames}), "
                f"emitting first {self.max_frames // 2} frames"
            )
            frames_to_emit = self.frames[: self.max_frames // 2]
            self.frames = self.frames[self.max_frames // 2 :]
            if len(frames_to_emit) >= self.min_frames:
                return frames_to_emit
            else:
                self.logger.debug(
                    f"Ignoring fallback segment with {len(frames_to_emit)} frames "
                    f"(< {self.min_frames} minimum)"
                )

        return None

    async def _process_block(self) -> Optional[list[av.AudioFrame]]:
        # Need at least 32 frames for VAD detection (~1 second)
        if len(self.frames) < 32 or self.vad_iterator is None:
            return None

        # Processing block with current buffer size
        # print(f"Processing block: {len(self.frames)} frames in buffer")

        try:
            # Convert frames to numpy array for VAD
            audio_array = self._frames_to_numpy(self.frames)

            if audio_array is None:
                # Fallback: emit all frames if conversion failed
                frames_to_emit = self.frames[:]
                self.frames = []
                if len(frames_to_emit) >= self.min_frames:
                    return frames_to_emit
                else:
                    self.logger.debug(
                        f"Ignoring conversion-failed segment with {len(frames_to_emit)} frames "
                        f"(< {self.min_frames} minimum)"
                    )
                return None

            # Find complete speech segments in the buffer
            speech_end_frame = self._find_speech_segment_end(audio_array)

            if speech_end_frame is None or speech_end_frame <= 0:
                # No speech found but the buffer is getting large
                if len(self.frames) > 512:
                    # Could emit silence or discard old frames here; for now,
                    # keep the most recent 256 frames and drop older, likely
                    # silent frames.
                    if len(self.frames) > 768:
                        self.logger.debug(
                            f"Discarding {len(self.frames) - 256} old frames (likely silence)"
                        )
                        self.frames = self.frames[-256:]
                return None

            # Calculate segment timing information
            frames_to_emit = self.frames[:speech_end_frame]

            # Get timing from av.AudioFrame
            if frames_to_emit:
                first_frame = frames_to_emit[0]
                last_frame = frames_to_emit[-1]
                sample_rate = first_frame.sample_rate

                # Calculate duration
                total_samples = sum(f.samples for f in frames_to_emit)
                duration_seconds = total_samples / sample_rate if sample_rate > 0 else 0

                # Get timestamps if available
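                # In PyAV, frame.pts counts in units of frame.time_base, so
                # pts * time_base yields the position in seconds (a Fraction).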
                start_time = (
                    first_frame.pts * first_frame.time_base if first_frame.pts else 0
                )
                end_time = (
                    last_frame.pts * last_frame.time_base if last_frame.pts else 0
                )

                # Convert to HH:MM:SS format for logging
                def format_time(seconds):
                    if not seconds:
                        return "00:00:00"
                    total_seconds = int(float(seconds))
                    hours = total_seconds // 3600
                    minutes = (total_seconds % 3600) // 60
                    secs = total_seconds % 60
                    return f"{hours:02d}:{minutes:02d}:{secs:02d}"

                start_formatted = format_time(start_time)
                end_formatted = format_time(end_time)

                # Frames left in the buffer once this segment is removed
                remaining_after = len(self.frames) - speech_end_frame

                # Single structured log line
                self.logger.info(
                    "Speech segment found",
                    start=start_formatted,
                    end=end_formatted,
                    frames=speech_end_frame,
                    duration=round(duration_seconds, 2),
                    buffer_before=len(self.frames),
                    remaining=remaining_after,
                )

            # Keep remaining frames for next processing
            self.frames = self.frames[speech_end_frame:]

            # Filter out segments with too few frames
            if len(frames_to_emit) >= self.min_frames:
                return frames_to_emit
            else:
                self.logger.debug(
                    f"Ignoring segment with {len(frames_to_emit)} frames "
                    f"(< {self.min_frames} minimum)"
                )

        except Exception as e:
            self.logger.error(f"Error in VAD processing: {e}")
            # Fallback to simple chunking
            if len(self.frames) >= self.block_frames:
                frames_to_emit = self.frames[: self.block_frames]
                self.frames = self.frames[self.block_frames :]
                if len(frames_to_emit) >= self.min_frames:
                    return frames_to_emit
                else:
                    self.logger.debug(
                        f"Ignoring exception-fallback segment with {len(frames_to_emit)} frames "
                        f"(< {self.min_frames} minimum)"
                    )

        return None

    def _frames_to_numpy(self, frames: list[av.AudioFrame]) -> Optional[np.ndarray]:
        """Convert av.AudioFrame list to numpy array for VAD processing

        Input frames are already 16kHz mono s16 format from AudioDownscaleProcessor.
        Only need to convert s16 to float32 for Silero VAD.
        """
        if not frames:
            return None

        try:
            # Concatenate all frame arrays
            audio_arrays = [frame.to_ndarray().flatten() for frame in frames]
            if not audio_arrays:
                return None

            combined_audio = np.concatenate(audio_arrays)

            # Convert s16 to float32 (Silero VAD requires float32 in range [-1.0, 1.0])
            # Input is guaranteed to be s16 from AudioDownscaleProcessor
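            # 32768 = 2**15 is the int16 full-scale magnitude, so dividing
            # maps s16 samples onto [-1.0, 1.0) as float32.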
            return combined_audio.astype(np.float32) / 32768.0

        except Exception as e:
            self.logger.error(f"Error converting frames to numpy: {e}")
            return None

    def _find_speech_segment_end(self, audio_array: np.ndarray) -> Optional[int]:
        """Find complete speech segments and return frame index at segment end"""
        if self.vad_iterator is None or len(audio_array) == 0:
            return None

        try:
            # Process audio in 512-sample windows for VAD
            window_size = 512
            min_silence_windows = 3  # Require 3 windows of silence after speech
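            # Note: Silero VAD expects fixed 512-sample windows for 16 kHz
            # input (32 ms each), so min_silence_windows=3 corresponds to
            # roughly 96 ms of trailing silence before a segment is closed.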

            # Track speech state
            in_speech = False
            speech_start = None
            speech_end = None
            silence_count = 0

            for i in range(0, len(audio_array), window_size):
                chunk = audio_array[i : i + window_size]
                if len(chunk) < window_size:
                    chunk = np.pad(chunk, (0, window_size - len(chunk)))

                # Detect if this window has speech
                speech_dict = self.vad_iterator(chunk, return_seconds=True)

                # VADIterator returns dict with 'start' and 'end' when speech segments are detected
                if speech_dict:
                    if not in_speech:
                        # Speech started
                        speech_start = i
                        in_speech = True
                        # Debug: print(f"Speech START at sample {i}, VAD: {speech_dict}")
                    silence_count = 0  # Reset silence counter
                    continue

                if not in_speech:
                    continue

                # We're in speech but found silence
                silence_count += 1
                if silence_count < min_silence_windows:
                    continue

                # Found end of speech segment
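                # silence_count has just reached min_silence_windows at the
                # window starting at sample i, so the silent run began
                # (min_silence_windows - 1) windows earlier; that boundary is
                # treated as the end of speech.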
                speech_end = i - (min_silence_windows - 1) * window_size
                # Debug: print(f"Speech END at sample {speech_end}")

                # Convert sample position to frame index
                samples_per_frame = self.frames[0].samples if self.frames else 1024
                frame_index = speech_end // samples_per_frame

                # Ensure we don't exceed buffer
                frame_index = min(frame_index, len(self.frames))
                return frame_index

            return None

        except Exception as e:
            self.logger.error(f"Error finding speech segment: {e}")
            return None

    async def _flush(self):
        frames = self.frames[:]
        self.frames = []
        if frames:
            if len(frames) >= self.min_frames:
                await self.emit(frames)
            else:
                self.logger.debug(
                    f"Ignoring flush segment with {len(frames)} frames "
                    f"(< {self.min_frames} minimum)"
                )


AudioChunkerAutoProcessor.register("silero", AudioChunkerSileroProcessor)