mirror of https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00

Compare commits: mathieu/pa ... mathieu/fi (3 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | 5aed513c47 |  |
|  | 9265d201b5 |  |
|  | 52f9f533d7 |  |
@@ -1,5 +1,13 @@
 # Changelog
 
+## [0.7.2](https://github.com/Monadical-SAS/reflector/compare/v0.7.1...v0.7.2) (2025-08-21)
+
+
+### Bug Fixes
+
+* docker image not loading libgomp.so.1 for torch ([#560](https://github.com/Monadical-SAS/reflector/issues/560)) ([773fccd](https://github.com/Monadical-SAS/reflector/commit/773fccd93e887c3493abc2e4a4864dddce610177))
+* include shared rooms to search ([#558](https://github.com/Monadical-SAS/reflector/issues/558)) ([499eced](https://github.com/Monadical-SAS/reflector/commit/499eced3360b84fb3a90e1c8a3b554290d21adc2))
+
 ## [0.7.1](https://github.com/Monadical-SAS/reflector/compare/v0.7.0...v0.7.1) (2025-08-21)
@@ -172,7 +172,7 @@ class TranscriberParakeetLive:
         text = output.text.strip()
         words = [
             {
-                "word": word_info["word"],
+                "word": word_info["word"] + " ",
                 "start": round(word_info["start"], 2),
                 "end": round(word_info["end"], 2),
             }
@@ -213,7 +213,7 @@ class TranscriberParakeetLive:
 
             words = [
                 {
-                    "word": word_info["word"],
+                    "word": word_info["word"] + " ",
                     "start": round(word_info["start"], 2),
                     "end": round(word_info["end"], 2),
                 }
@@ -386,7 +386,7 @@ class TranscriberParakeetFile:
                 text = output.text.strip()
                 words = [
                     {
-                        "word": word_info["word"],
+                        "word": word_info["word"] + " ",
                         "start": round(
                             word_info["start"] + start_time + timestamp_offset, 2
                         ),
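Note: the only functional change in these hunks is the trailing space appended to each word entry, so that word fragments concatenated downstream no longer run together. A minimal illustration of the effect (not repository code):

```python
# Hypothetical illustration of why the trailing space matters when the
# per-word entries are later joined into a transcript string.
words_without_space = [{"word": "hello"}, {"word": "world"}]
words_with_space = [{"word": "hello "}, {"word": "world "}]

print("".join(w["word"] for w in words_without_space))          # "helloworld"
print("".join(w["word"] for w in words_with_space).strip())     # "hello world"
```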
@@ -1,622 +0,0 @@
Entire file removed; its former content:

```python
import logging
import os
import sys
import threading
import uuid
from typing import Mapping, NewType
from urllib.parse import urlparse

import modal

MODEL_NAME = "nvidia/parakeet-tdt-0.6b-v3"
SUPPORTED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
SAMPLERATE = 16000
UPLOADS_PATH = "/uploads"
CACHE_PATH = "/cache"
VAD_CONFIG = {
    "max_segment_duration": 30.0,
    "batch_max_files": 10,
    "batch_max_duration": 5.0,
    "min_segment_duration": 0.02,
    "silence_padding": 0.5,
    "window_size": 512,
}

ParakeetUniqFilename = NewType("ParakeetUniqFilename", str)
AudioFileExtension = NewType("AudioFileExtension", str)

app = modal.App("reflector-transcriber-parakeet-v3")

# Volume for caching model weights
model_cache = modal.Volume.from_name("parakeet-model-cache", create_if_missing=True)
# Volume for temporary file uploads
upload_volume = modal.Volume.from_name("parakeet-uploads", create_if_missing=True)

image = (
    modal.Image.from_registry(
        "nvidia/cuda:12.8.0-cudnn-devel-ubuntu22.04", add_python="3.12"
    )
    .env(
        {
            "HF_HUB_ENABLE_HF_TRANSFER": "1",
            "HF_HOME": "/cache",
            "DEBIAN_FRONTEND": "noninteractive",
            "CXX": "g++",
            "CC": "g++",
        }
    )
    .apt_install("ffmpeg")
    .pip_install(
        "hf_transfer==0.1.9",
        "huggingface_hub[hf-xet]==0.31.2",
        "nemo_toolkit[asr]==2.3.0",
        "cuda-python==12.8.0",
        "fastapi==0.115.12",
        "numpy<2",
        "librosa==0.10.1",
        "requests",
        "silero-vad==5.1.0",
        "torch",
    )
    .entrypoint([])  # silence chatty logs by container on start
)


def detect_audio_format(url: str, headers: Mapping[str, str]) -> AudioFileExtension:
    parsed_url = urlparse(url)
    url_path = parsed_url.path

    for ext in SUPPORTED_FILE_EXTENSIONS:
        if url_path.lower().endswith(f".{ext}"):
            return AudioFileExtension(ext)

    content_type = headers.get("content-type", "").lower()
    if "audio/mpeg" in content_type or "audio/mp3" in content_type:
        return AudioFileExtension("mp3")
    if "audio/wav" in content_type:
        return AudioFileExtension("wav")
    if "audio/mp4" in content_type:
        return AudioFileExtension("mp4")

    raise ValueError(
        f"Unsupported audio format for URL: {url}. "
        f"Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
    )


def download_audio_to_volume(
    audio_file_url: str,
) -> tuple[ParakeetUniqFilename, AudioFileExtension]:
    import requests
    from fastapi import HTTPException

    response = requests.head(audio_file_url, allow_redirects=True)
    if response.status_code == 404:
        raise HTTPException(status_code=404, detail="Audio file not found")

    response = requests.get(audio_file_url, allow_redirects=True)
    response.raise_for_status()

    audio_suffix = detect_audio_format(audio_file_url, response.headers)
    unique_filename = ParakeetUniqFilename(f"{uuid.uuid4()}.{audio_suffix}")
    file_path = f"{UPLOADS_PATH}/{unique_filename}"

    with open(file_path, "wb") as f:
        f.write(response.content)

    upload_volume.commit()
    return unique_filename, audio_suffix


def pad_audio(audio_array, sample_rate: int = SAMPLERATE):
    """Add 0.5 seconds of silence if audio is less than 500ms.

    This is a workaround for a Parakeet bug where very short audio (<500ms) causes:
    ValueError: `char_offsets`: [] and `processed_tokens`: [157, 834, 834, 841]
    have to be of the same length

    See: https://github.com/NVIDIA/NeMo/issues/8451
    """
    import numpy as np

    audio_duration = len(audio_array) / sample_rate
    if audio_duration < 0.5:
        silence_samples = int(sample_rate * 0.5)
        silence = np.zeros(silence_samples, dtype=np.float32)
        return np.concatenate([audio_array, silence])
    return audio_array


@app.cls(
    gpu="A10G",
    timeout=600,
    scaledown_window=300,
    image=image,
    volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
    enable_memory_snapshot=True,
    experimental_options={"enable_gpu_snapshot": True},
)
@modal.concurrent(max_inputs=10)
class TranscriberParakeetLive:
    @modal.enter(snap=True)
    def enter(self):
        import nemo.collections.asr as nemo_asr

        logging.getLogger("nemo_logger").setLevel(logging.CRITICAL)

        self.lock = threading.Lock()
        self.model = nemo_asr.models.ASRModel.from_pretrained(model_name=MODEL_NAME)
        device = next(self.model.parameters()).device
        print(f"Model is on device: {device}")

    @modal.method()
    def transcribe_segment(
        self,
        filename: str,
    ):
        import librosa

        upload_volume.reload()

        file_path = f"{UPLOADS_PATH}/{filename}"
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
        padded_audio = pad_audio(audio_array, sample_rate)

        with self.lock:
            with NoStdStreams():
                (output,) = self.model.transcribe([padded_audio], timestamps=True)

        text = output.text.strip()
        words = [
            {
                "word": word_info["word"] + " ",
                "start": round(word_info["start"], 2),
                "end": round(word_info["end"], 2),
            }
            for word_info in output.timestamp["word"]
        ]

        return {"text": text, "words": words}

    @modal.method()
    def transcribe_batch(
        self,
        filenames: list[str],
    ):
        import librosa

        upload_volume.reload()

        results = []
        audio_arrays = []

        # Load all audio files with padding
        for filename in filenames:
            file_path = f"{UPLOADS_PATH}/{filename}"
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"Batch file not found: {file_path}")

            audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
            padded_audio = pad_audio(audio_array, sample_rate)
            audio_arrays.append(padded_audio)

        with self.lock:
            with NoStdStreams():
                outputs = self.model.transcribe(audio_arrays, timestamps=True)

        # Process results for each file
        for i, (filename, output) in enumerate(zip(filenames, outputs)):
            text = output.text.strip()

            words = [
                {
                    "word": word_info["word"] + " ",
                    "start": round(word_info["start"], 2),
                    "end": round(word_info["end"], 2),
                }
                for word_info in output.timestamp["word"]
            ]

            results.append(
                {
                    "filename": filename,
                    "text": text,
                    "words": words,
                }
            )

        return results


# L40S class for file transcription (bigger files)
@app.cls(
    gpu="L40S",
    timeout=900,
    image=image,
    volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
    enable_memory_snapshot=True,
    experimental_options={"enable_gpu_snapshot": True},
)
class TranscriberParakeetFile:
    @modal.enter(snap=True)
    def enter(self):
        import nemo.collections.asr as nemo_asr
        import torch
        from silero_vad import load_silero_vad

        logging.getLogger("nemo_logger").setLevel(logging.CRITICAL)

        self.model = nemo_asr.models.ASRModel.from_pretrained(model_name=MODEL_NAME)
        device = next(self.model.parameters()).device
        print(f"Model is on device: {device}")

        torch.set_num_threads(1)
        self.vad_model = load_silero_vad(onnx=False)
        print("Silero VAD initialized")

    @modal.method()
    def transcribe_segment(
        self,
        filename: str,
        timestamp_offset: float = 0.0,
    ):
        import librosa
        import numpy as np
        from silero_vad import VADIterator

        def load_and_convert_audio(file_path):
            audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
            return audio_array

        def vad_segment_generator(audio_array):
            """Generate speech segments using VAD with start/end sample indices"""
            vad_iterator = VADIterator(self.vad_model, sampling_rate=SAMPLERATE)
            window_size = VAD_CONFIG["window_size"]
            start = None

            for i in range(0, len(audio_array), window_size):
                chunk = audio_array[i : i + window_size]
                if len(chunk) < window_size:
                    chunk = np.pad(
                        chunk, (0, window_size - len(chunk)), mode="constant"
                    )

                speech_dict = vad_iterator(chunk)
                if not speech_dict:
                    continue

                if "start" in speech_dict:
                    start = speech_dict["start"]
                    continue

                if "end" in speech_dict and start is not None:
                    end = speech_dict["end"]
                    start_time = start / float(SAMPLERATE)
                    end_time = end / float(SAMPLERATE)

                    # Extract the actual audio segment
                    audio_segment = audio_array[start:end]

                    yield (start_time, end_time, audio_segment)
                    start = None

            vad_iterator.reset_states()

        def vad_segment_filter(segments):
            """Filter VAD segments by duration and chunk large segments"""
            min_dur = VAD_CONFIG["min_segment_duration"]
            max_dur = VAD_CONFIG["max_segment_duration"]

            for start_time, end_time, audio_segment in segments:
                segment_duration = end_time - start_time

                # Skip very small segments
                if segment_duration < min_dur:
                    continue

                # If segment is within max duration, yield as-is
                if segment_duration <= max_dur:
                    yield (start_time, end_time, audio_segment)
                    continue

                # Chunk large segments into smaller pieces
                chunk_samples = int(max_dur * SAMPLERATE)
                current_start = start_time

                for chunk_offset in range(0, len(audio_segment), chunk_samples):
                    chunk_audio = audio_segment[
                        chunk_offset : chunk_offset + chunk_samples
                    ]
                    if len(chunk_audio) == 0:
                        break

                    chunk_duration = len(chunk_audio) / float(SAMPLERATE)
                    chunk_end = current_start + chunk_duration

                    # Only yield chunks that meet minimum duration
                    if chunk_duration >= min_dur:
                        yield (current_start, chunk_end, chunk_audio)

                    current_start = chunk_end

        def batch_segments(segments, max_files=10, max_duration=5.0):
            batch = []
            batch_duration = 0.0

            for start_time, end_time, audio_segment in segments:
                segment_duration = end_time - start_time

                if segment_duration < VAD_CONFIG["silence_padding"]:
                    silence_samples = int(
                        (VAD_CONFIG["silence_padding"] - segment_duration) * SAMPLERATE
                    )
                    padding = np.zeros(silence_samples, dtype=np.float32)
                    audio_segment = np.concatenate([audio_segment, padding])
                    segment_duration = VAD_CONFIG["silence_padding"]

                batch.append((start_time, end_time, audio_segment))
                batch_duration += segment_duration

                if len(batch) >= max_files or batch_duration >= max_duration:
                    yield batch
                    batch = []
                    batch_duration = 0.0

            if batch:
                yield batch

        def transcribe_batch(model, audio_segments):
            with NoStdStreams():
                outputs = model.transcribe(audio_segments, timestamps=True)
            return outputs

        def emit_results(
            results,
            segments_info,
            batch_index,
            total_batches,
        ):
            """Yield transcribed text and word timings from model output, adjusting timestamps to absolute positions."""
            for i, (output, (start_time, end_time, _)) in enumerate(
                zip(results, segments_info)
            ):
                text = output.text.strip()
                words = [
                    {
                        "word": word_info["word"],
                        "start": round(
                            word_info["start"] + start_time + timestamp_offset, 2
                        ),
                        "end": round(
                            word_info["end"] + start_time + timestamp_offset, 2
                        ),
                    }
                    for word_info in output.timestamp["word"]
                ]

                yield text, words

        upload_volume.reload()

        file_path = f"{UPLOADS_PATH}/{filename}"
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        audio_array = load_and_convert_audio(file_path)
        total_duration = len(audio_array) / float(SAMPLERATE)
        processed_duration = 0.0

        all_text_parts = []
        all_words = []

        raw_segments = vad_segment_generator(audio_array)
        filtered_segments = vad_segment_filter(raw_segments)
        batches = batch_segments(
            filtered_segments,
            VAD_CONFIG["batch_max_files"],
            VAD_CONFIG["batch_max_duration"],
        )

        batch_index = 0
        total_batches = max(
            1, int(total_duration / VAD_CONFIG["batch_max_duration"]) + 1
        )

        for batch in batches:
            batch_index += 1
            audio_segments = [seg[2] for seg in batch]
            results = transcribe_batch(self.model, audio_segments)

            for text, words in emit_results(
                results,
                batch,
                batch_index,
                total_batches,
            ):
                if not text:
                    continue
                all_text_parts.append(text)
                all_words.extend(words)

            processed_duration += sum(len(seg[2]) / float(SAMPLERATE) for seg in batch)

        combined_text = " ".join(all_text_parts)
        return {"text": combined_text, "words": all_words}


@app.function(
    scaledown_window=60,
    timeout=600,
    secrets=[
        modal.Secret.from_name("reflector-gpu"),
    ],
    volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
    image=image,
)
@modal.concurrent(max_inputs=40)
@modal.asgi_app()
def web():
    import os
    import uuid

    from fastapi import (
        Body,
        Depends,
        FastAPI,
        Form,
        HTTPException,
        UploadFile,
        status,
    )
    from fastapi.security import OAuth2PasswordBearer
    from pydantic import BaseModel

    transcriber_live = TranscriberParakeetLive()
    transcriber_file = TranscriberParakeetFile()

    app = FastAPI()

    oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

    def apikey_auth(apikey: str = Depends(oauth2_scheme)):
        if apikey == os.environ["REFLECTOR_GPU_APIKEY"]:
            return
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
            headers={"WWW-Authenticate": "Bearer"},
        )

    class TranscriptResponse(BaseModel):
        result: dict

    @app.post("/v1/audio/transcriptions", dependencies=[Depends(apikey_auth)])
    def transcribe(
        file: UploadFile = None,
        files: list[UploadFile] | None = None,
        model: str = Form(MODEL_NAME),
        language: str = Form("en"),
        batch: bool = Form(False),
    ):
        # Parakeet only supports English
        if language != "en":
            raise HTTPException(
                status_code=400,
                detail=f"Parakeet model only supports English. Got language='{language}'",
            )
        # Handle both single file and multiple files
        if not file and not files:
            raise HTTPException(
                status_code=400, detail="Either 'file' or 'files' parameter is required"
            )
        if batch and not files:
            raise HTTPException(
                status_code=400, detail="Batch transcription requires 'files'"
            )

        upload_files = [file] if file else files

        # Upload files to volume
        uploaded_filenames = []
        for upload_file in upload_files:
            audio_suffix = upload_file.filename.split(".")[-1]
            assert audio_suffix in SUPPORTED_FILE_EXTENSIONS

            # Generate unique filename
            unique_filename = f"{uuid.uuid4()}.{audio_suffix}"
            file_path = f"{UPLOADS_PATH}/{unique_filename}"

            print(f"Writing file to: {file_path}")
            with open(file_path, "wb") as f:
                content = upload_file.file.read()
                f.write(content)

            uploaded_filenames.append(unique_filename)

        upload_volume.commit()

        try:
            # Use A10G live transcriber for per-file transcription
            if batch and len(upload_files) > 1:
                # Use batch transcription
                func = transcriber_live.transcribe_batch.spawn(
                    filenames=uploaded_filenames,
                )
                results = func.get()
                return {"results": results}

            # Per-file transcription
            results = []
            for filename in uploaded_filenames:
                func = transcriber_live.transcribe_segment.spawn(
                    filename=filename,
                )
                result = func.get()
                result["filename"] = filename
                results.append(result)

            return {"results": results} if len(results) > 1 else results[0]

        finally:
            for filename in uploaded_filenames:
                try:
                    file_path = f"{UPLOADS_PATH}/{filename}"
                    print(f"Deleting file: {file_path}")
                    os.remove(file_path)
                except Exception as e:
                    print(f"Error deleting {filename}: {e}")

            upload_volume.commit()

    @app.post("/v1/audio/transcriptions-from-url", dependencies=[Depends(apikey_auth)])
    def transcribe_from_url(
        audio_file_url: str = Body(
            ..., description="URL of the audio file to transcribe"
        ),
        model: str = Body(MODEL_NAME),
        language: str = Body("en", description="Language code (only 'en' supported)"),
        timestamp_offset: float = Body(0.0),
    ):
        # Parakeet only supports English
        if language != "en":
            raise HTTPException(
                status_code=400,
                detail=f"Parakeet model only supports English. Got language='{language}'",
            )
        unique_filename, audio_suffix = download_audio_to_volume(audio_file_url)

        try:
            func = transcriber_file.transcribe_segment.spawn(
                filename=unique_filename,
                timestamp_offset=timestamp_offset,
            )
            result = func.get()
            return result
        finally:
            try:
                file_path = f"{UPLOADS_PATH}/{unique_filename}"
                print(f"Deleting file: {file_path}")
                os.remove(file_path)
                upload_volume.commit()
            except Exception as e:
                print(f"Error cleaning up {unique_filename}: {e}")

    return app


class NoStdStreams:
    def __init__(self):
        self.devnull = open(os.devnull, "w")

    def __enter__(self):
        self._stdout, self._stderr = sys.stdout, sys.stderr
        self._stdout.flush()
        self._stderr.flush()
        sys.stdout, sys.stderr = self.devnull, self.devnull

    def __exit__(self, exc_type, exc_value, traceback):
        sys.stdout, sys.stderr = self._stdout, self._stderr
        self.devnull.close()
```
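Note: the `pad_audio` workaround in the removed file is easier to see with concrete numbers. The sketch below is illustrative only; the 500 ms threshold and 16 kHz sample rate come from that file:

```python
# Illustrative numbers for the pad_audio workaround (not repository code).
SAMPLERATE = 16000
audio_samples = 4800                       # a 0.3 s clip: 4800 / 16000 = 0.3 s < 0.5 s
silence_samples = int(SAMPLERATE * 0.5)    # 8000 samples of appended silence
padded_duration = (audio_samples + silence_samples) / SAMPLERATE
print(padded_duration)                     # 0.8 s, safely above the 500 ms threshold
```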
@@ -40,8 +40,9 @@ from reflector.db.transcripts import (
 from reflector.logger import logger
 from reflector.pipelines.runner import PipelineMessage, PipelineRunner
 from reflector.processors import (
-    AudioChunkerProcessor,
+    AudioChunkerAutoProcessor,
     AudioDiarizationAutoProcessor,
+    AudioDownscaleProcessor,
     AudioFileWriterProcessor,
     AudioMergeProcessor,
     AudioTranscriptAutoProcessor,
@@ -365,7 +366,8 @@ class PipelineMainLive(PipelineMainBase):
                 path=transcript.audio_wav_filename,
                 on_duration=self.on_duration,
             ),
-            AudioChunkerProcessor(),
+            AudioDownscaleProcessor(),
+            AudioChunkerAutoProcessor(),
             AudioMergeProcessor(),
             AudioTranscriptAutoProcessor.as_threaded(),
             TranscriptLinerProcessor(),
@@ -1,5 +1,7 @@
 from .audio_chunker import AudioChunkerProcessor  # noqa: F401
+from .audio_chunker_auto import AudioChunkerAutoProcessor  # noqa: F401
 from .audio_diarization_auto import AudioDiarizationAutoProcessor  # noqa: F401
+from .audio_downscale import AudioDownscaleProcessor  # noqa: F401
 from .audio_file_writer import AudioFileWriterProcessor  # noqa: F401
 from .audio_merge import AudioMergeProcessor  # noqa: F401
 from .audio_transcript import AudioTranscriptProcessor  # noqa: F401
@@ -1,340 +1,78 @@
Removed (previous single-class VAD implementation):

```python
from typing import Optional

import av
import numpy as np
import torch
from silero_vad import VADIterator, load_silero_vad

from reflector.processors.base import Processor


class AudioChunkerProcessor(Processor):
    """
    Assemble audio frames into chunks with VAD-based speech detection
    """

    INPUT_TYPE = av.AudioFrame
    OUTPUT_TYPE = list[av.AudioFrame]

    def __init__(
        self,
        block_frames=256,
        max_frames=1024,
        vad_threshold=0.5,
        use_onnx=False,
        min_frames=2,
    ):
        super().__init__()
        self.frames: list[av.AudioFrame] = []
        self.block_frames = block_frames
        self.max_frames = max_frames
        self.vad_threshold = vad_threshold
        self.min_frames = min_frames

        # Initialize Silero VAD
        self._init_vad(use_onnx)

    def _init_vad(self, use_onnx=False):
        """Initialize Silero VAD model"""
        try:
            torch.set_num_threads(1)
            self.vad_model = load_silero_vad(onnx=use_onnx)
            self.vad_iterator = VADIterator(self.vad_model, sampling_rate=16000)
            self.logger.info("Silero VAD initialized successfully")

        except Exception as e:
            self.logger.error(f"Failed to initialize Silero VAD: {e}")
            self.vad_model = None
            self.vad_iterator = None

    async def _push(self, data: av.AudioFrame):
        self.frames.append(data)
        # print("timestamp", data.pts * data.time_base * 1000)

        # Check for speech segments every 32 frames (~1 second)
        if len(self.frames) >= 32 and len(self.frames) % 32 == 0:
            await self._process_block()

        # Safety fallback - emit if we hit max frames
        elif len(self.frames) >= self.max_frames:
            self.logger.warning(
                f"AudioChunkerProcessor: Reached max frames ({self.max_frames}), "
                f"emitting first {self.max_frames // 2} frames"
            )
            frames_to_emit = self.frames[: self.max_frames // 2]
            self.frames = self.frames[self.max_frames // 2 :]
            if len(frames_to_emit) >= self.min_frames:
                await self.emit(frames_to_emit)
            else:
                self.logger.debug(
                    f"Ignoring fallback segment with {len(frames_to_emit)} frames "
                    f"(< {self.min_frames} minimum)"
                )

    async def _process_block(self):
        # Need at least 32 frames for VAD detection (~1 second)
        if len(self.frames) < 32 or self.vad_iterator is None:
            return

        # Processing block with current buffer size
        # print(f"Processing block: {len(self.frames)} frames in buffer")

        try:
            # Convert frames to numpy array for VAD
            audio_array = self._frames_to_numpy(self.frames)

            if audio_array is None:
                # Fallback: emit all frames if conversion failed
                frames_to_emit = self.frames[:]
                self.frames = []
                if len(frames_to_emit) >= self.min_frames:
                    await self.emit(frames_to_emit)
                else:
                    self.logger.debug(
                        f"Ignoring conversion-failed segment with {len(frames_to_emit)} frames "
                        f"(< {self.min_frames} minimum)"
                    )
                return

            # Find complete speech segments in the buffer
            speech_end_frame = self._find_speech_segment_end(audio_array)

            if speech_end_frame is None or speech_end_frame <= 0:
                # No speech found but buffer is getting large
                if len(self.frames) > 512:
                    # Check if it's all silence and can be discarded
                    # No speech segment found, buffer at {len(self.frames)} frames

                    # Could emit silence or discard old frames here
                    # For now, keep first 256 frames and discard older silence
                    if len(self.frames) > 768:
                        self.logger.debug(
                            f"Discarding {len(self.frames) - 256} old frames (likely silence)"
                        )
                        self.frames = self.frames[-256:]
                return

            # Calculate segment timing information
            frames_to_emit = self.frames[:speech_end_frame]

            # Get timing from av.AudioFrame
            if frames_to_emit:
                first_frame = frames_to_emit[0]
                last_frame = frames_to_emit[-1]
                sample_rate = first_frame.sample_rate

                # Calculate duration
                total_samples = sum(f.samples for f in frames_to_emit)
                duration_seconds = total_samples / sample_rate if sample_rate > 0 else 0

                # Get timestamps if available
                start_time = (
                    first_frame.pts * first_frame.time_base if first_frame.pts else 0
                )
                end_time = (
                    last_frame.pts * last_frame.time_base if last_frame.pts else 0
                )

                # Convert to HH:MM:SS format for logging
                def format_time(seconds):
                    if not seconds:
                        return "00:00:00"
                    total_seconds = int(float(seconds))
                    hours = total_seconds // 3600
                    minutes = (total_seconds % 3600) // 60
                    secs = total_seconds % 60
                    return f"{hours:02d}:{minutes:02d}:{secs:02d}"

                start_formatted = format_time(start_time)
                end_formatted = format_time(end_time)

                # Keep remaining frames for next processing
                remaining_after = len(self.frames) - speech_end_frame

                # Single structured log line
                self.logger.info(
                    "Speech segment found",
                    start=start_formatted,
                    end=end_formatted,
                    frames=speech_end_frame,
                    duration=round(duration_seconds, 2),
                    buffer_before=len(self.frames),
                    remaining=remaining_after,
                )

            # Keep remaining frames for next processing
            self.frames = self.frames[speech_end_frame:]

            # Filter out segments with too few frames
            if len(frames_to_emit) >= self.min_frames:
                await self.emit(frames_to_emit)
            else:
                self.logger.debug(
                    f"Ignoring segment with {len(frames_to_emit)} frames "
                    f"(< {self.min_frames} minimum)"
                )

        except Exception as e:
            self.logger.error(f"Error in VAD processing: {e}")
            # Fallback to simple chunking
            if len(self.frames) >= self.block_frames:
                frames_to_emit = self.frames[: self.block_frames]
                self.frames = self.frames[self.block_frames :]
                if len(frames_to_emit) >= self.min_frames:
                    await self.emit(frames_to_emit)
                else:
                    self.logger.debug(
                        f"Ignoring exception-fallback segment with {len(frames_to_emit)} frames "
                        f"(< {self.min_frames} minimum)"
                    )

    def _frames_to_numpy(self, frames: list[av.AudioFrame]) -> Optional[np.ndarray]:
        """Convert av.AudioFrame list to numpy array for VAD processing"""
        if not frames:
            return None

        try:
            first_frame = frames[0]
            original_sample_rate = first_frame.sample_rate

            audio_data = []
            for frame in frames:
                frame_array = frame.to_ndarray()

                # Handle stereo -> mono conversion
                if len(frame_array.shape) == 2 and frame_array.shape[0] > 1:
                    frame_array = np.mean(frame_array, axis=0)
                elif len(frame_array.shape) == 2:
                    frame_array = frame_array.flatten()

                audio_data.append(frame_array)

            if not audio_data:
                return None

            combined_audio = np.concatenate(audio_data)

            # Resample from 48kHz to 16kHz if needed
            if original_sample_rate != 16000:
                combined_audio = self._resample_audio(
                    combined_audio, original_sample_rate, 16000
                )

            # Ensure float32 format
            if combined_audio.dtype == np.int16:
                # Normalize int16 audio to float32 in range [-1.0, 1.0]
                combined_audio = combined_audio.astype(np.float32) / 32768.0
            elif combined_audio.dtype != np.float32:
                combined_audio = combined_audio.astype(np.float32)

            return combined_audio

        except Exception as e:
            self.logger.error(f"Error converting frames to numpy: {e}")
            return None

    def _resample_audio(
        self, audio: np.ndarray, from_sr: int, to_sr: int
    ) -> np.ndarray:
        """Simple linear resampling from from_sr to to_sr"""
        if from_sr == to_sr:
            return audio

        try:
            # Simple linear interpolation resampling
            ratio = to_sr / from_sr
            new_length = int(len(audio) * ratio)

            # Create indices for interpolation
            old_indices = np.linspace(0, len(audio) - 1, new_length)
            resampled = np.interp(old_indices, np.arange(len(audio)), audio)

            return resampled.astype(np.float32)

        except Exception as e:
            self.logger.error("Resampling error", exc_info=e)
            # Fallback: simple decimation/repetition
            if from_sr > to_sr:
                # Downsample by taking every nth sample
                step = from_sr // to_sr
                return audio[::step]
            else:
                # Upsample by repeating samples
                repeat = to_sr // from_sr
                return np.repeat(audio, repeat)

    def _find_speech_segment_end(self, audio_array: np.ndarray) -> Optional[int]:
        """Find complete speech segments and return frame index at segment end"""
        if self.vad_iterator is None or len(audio_array) == 0:
            return None

        try:
            # Process audio in 512-sample windows for VAD
            window_size = 512
            min_silence_windows = 3  # Require 3 windows of silence after speech

            # Track speech state
            in_speech = False
            speech_start = None
            speech_end = None
            silence_count = 0

            for i in range(0, len(audio_array), window_size):
                chunk = audio_array[i : i + window_size]
                if len(chunk) < window_size:
                    chunk = np.pad(chunk, (0, window_size - len(chunk)))

                # Detect if this window has speech
                speech_dict = self.vad_iterator(chunk, return_seconds=True)

                # VADIterator returns dict with 'start' and 'end' when speech segments are detected
                if speech_dict:
                    if not in_speech:
                        # Speech started
                        speech_start = i
                        in_speech = True
                        # Debug: print(f"Speech START at sample {i}, VAD: {speech_dict}")
                    silence_count = 0  # Reset silence counter
                    continue

                if not in_speech:
                    continue

                # We're in speech but found silence
                silence_count += 1
                if silence_count < min_silence_windows:
                    continue

                # Found end of speech segment
                speech_end = i - (min_silence_windows - 1) * window_size
                # Debug: print(f"Speech END at sample {speech_end}")

                # Convert sample position to frame index
                samples_per_frame = self.frames[0].samples if self.frames else 1024
                # Account for resampling: we process at 16kHz but frames might be 48kHz
                resample_ratio = 48000 / 16000  # 3x
                actual_sample_pos = int(speech_end * resample_ratio)
                frame_index = actual_sample_pos // samples_per_frame

                # Ensure we don't exceed buffer
                frame_index = min(frame_index, len(self.frames))
                return frame_index

            return None

        except Exception as e:
            self.logger.error(f"Error finding speech segment: {e}")
            return None

    async def _flush(self):
        frames = self.frames[:]
        self.frames = []
        if frames:
            if len(frames) >= self.min_frames:
                await self.emit(frames)
            else:
                self.logger.debug(
                    f"Ignoring flush segment with {len(frames)} frames "
                    f"(< {self.min_frames} minimum)"
                )
```

Added (new base-class implementation with Prometheus metrics):

```python
from typing import Optional

import av
from prometheus_client import Counter, Histogram

from reflector.processors.base import Processor


class AudioChunkerProcessor(Processor):
    """
    Base class for assembling audio frames into chunks
    """

    INPUT_TYPE = av.AudioFrame
    OUTPUT_TYPE = list[av.AudioFrame]

    m_chunk = Histogram(
        "audio_chunker",
        "Time spent in AudioChunker.chunk",
        ["backend"],
    )
    m_chunk_call = Counter(
        "audio_chunker_call",
        "Number of calls to AudioChunker.chunk",
        ["backend"],
    )
    m_chunk_success = Counter(
        "audio_chunker_success",
        "Number of successful calls to AudioChunker.chunk",
        ["backend"],
    )
    m_chunk_failure = Counter(
        "audio_chunker_failure",
        "Number of failed calls to AudioChunker.chunk",
        ["backend"],
    )

    def __init__(self, *args, **kwargs):
        name = self.__class__.__name__
        self.m_chunk = self.m_chunk.labels(name)
        self.m_chunk_call = self.m_chunk_call.labels(name)
        self.m_chunk_success = self.m_chunk_success.labels(name)
        self.m_chunk_failure = self.m_chunk_failure.labels(name)
        super().__init__(*args, **kwargs)
        self.frames: list[av.AudioFrame] = []

    async def _push(self, data: av.AudioFrame):
        """Process incoming audio frame"""
        # Validate audio format on first frame
        if len(self.frames) == 0:
            if data.sample_rate != 16000 or len(data.layout.channels) != 1:
                raise ValueError(
                    f"AudioChunkerProcessor expects 16kHz mono audio, got {data.sample_rate}Hz "
                    f"with {len(data.layout.channels)} channel(s). "
                    f"Use AudioDownscaleProcessor before this processor."
                )

        try:
            self.m_chunk_call.inc()
            with self.m_chunk.time():
                result = await self._chunk(data)
            self.m_chunk_success.inc()
            if result:
                await self.emit(result)
        except Exception:
            self.m_chunk_failure.inc()
            raise

    async def _chunk(self, data: av.AudioFrame) -> Optional[list[av.AudioFrame]]:
        """
        Process audio frame and return chunk when ready.
        Subclasses should implement their chunking logic here.
        """
        raise NotImplementedError

    async def _flush(self):
        """Flush any remaining frames when processing ends"""
        raise NotImplementedError
```
server/reflector/processors/audio_chunker_auto.py (new file, 32 lines)
@@ -0,0 +1,32 @@

```python
import importlib

from reflector.processors.audio_chunker import AudioChunkerProcessor
from reflector.settings import settings


class AudioChunkerAutoProcessor(AudioChunkerProcessor):
    _registry = {}

    @classmethod
    def register(cls, name, kclass):
        cls._registry[name] = kclass

    def __new__(cls, name: str | None = None, **kwargs):
        if name is None:
            name = settings.AUDIO_CHUNKER_BACKEND
        if name not in cls._registry:
            module_name = f"reflector.processors.audio_chunker_{name}"
            importlib.import_module(module_name)

        # gather specific configuration for the processor
        # search `AUDIO_CHUNKER_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
        config = {}
        name_upper = name.upper()
        settings_prefix = "AUDIO_CHUNKER_"
        config_prefix = f"{settings_prefix}{name_upper}_"
        for key, value in settings:
            if key.startswith(config_prefix):
                config_name = key[len(settings_prefix) :].lower()
                config[config_name] = value

        return cls._registry[name](**config | kwargs)
```
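Note: a rough sketch of how this registry resolves a backend from settings. The setting names beyond `AUDIO_CHUNKER_BACKEND` are hypothetical examples, not values confirmed by the repository:

```python
# Hypothetical settings, assuming reflector's settings object iterates key/value pairs:
#   AUDIO_CHUNKER_BACKEND = "silero"
#   AUDIO_CHUNKER_SILERO_MAX_FRAMES = 512   # hypothetical per-backend override
#
# AudioChunkerAutoProcessor() would then:
#   1. import reflector.processors.audio_chunker_silero, which registers "silero",
#   2. collect keys starting with "AUDIO_CHUNKER_SILERO_",
#   3. call AudioChunkerSileroProcessor(silero_max_frames=512).
#
# Only "AUDIO_CHUNKER_" is stripped from the key, so the backend name stays in the
# resulting keyword argument ("silero_max_frames").
chunker = AudioChunkerAutoProcessor()  # backend chosen by settings.AUDIO_CHUNKER_BACKEND
```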
server/reflector/processors/audio_chunker_frames.py (new file, 34 lines)
@@ -0,0 +1,34 @@

```python
from typing import Optional

import av

from reflector.processors.audio_chunker import AudioChunkerProcessor
from reflector.processors.audio_chunker_auto import AudioChunkerAutoProcessor


class AudioChunkerFramesProcessor(AudioChunkerProcessor):
    """
    Simple frame-based audio chunker that emits chunks after a fixed number of frames
    """

    def __init__(self, max_frames=256, **kwargs):
        super().__init__(**kwargs)
        self.max_frames = max_frames

    async def _chunk(self, data: av.AudioFrame) -> Optional[list[av.AudioFrame]]:
        self.frames.append(data)
        if len(self.frames) >= self.max_frames:
            frames_to_emit = self.frames[:]
            self.frames = []
            return frames_to_emit

        return None

    async def _flush(self):
        frames = self.frames[:]
        self.frames = []
        if frames:
            await self.emit(frames)


AudioChunkerAutoProcessor.register("frames", AudioChunkerFramesProcessor)
```
server/reflector/processors/audio_chunker_silero.py (new file, 293 lines)
@@ -0,0 +1,293 @@

```python
from typing import Optional

import av
import numpy as np
import torch
from silero_vad import VADIterator, load_silero_vad

from reflector.processors.audio_chunker import AudioChunkerProcessor
from reflector.processors.audio_chunker_auto import AudioChunkerAutoProcessor


class AudioChunkerSileroProcessor(AudioChunkerProcessor):
    """
    Assemble audio frames into chunks with VAD-based speech detection using Silero VAD.

    Expects input audio to be already downscaled to 16kHz mono s16 format
    (handled by AudioDownscaleProcessor in the pipeline).
    """

    def __init__(
        self,
        block_frames=256,
        max_frames=1024,
        use_onnx=True,
        min_frames=2,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.block_frames = block_frames
        self.max_frames = max_frames
        self.min_frames = min_frames

        # Initialize Silero VAD
        self._init_vad(use_onnx)

    def _init_vad(self, use_onnx=False):
        """Initialize Silero VAD model for 16kHz audio"""
        try:
            torch.set_num_threads(1)
            self.vad_model = load_silero_vad(onnx=use_onnx)
            # VAD expects 16kHz audio (guaranteed by AudioDownscaleProcessor)
            self.vad_iterator = VADIterator(self.vad_model, sampling_rate=16000)
            self.logger.info("Silero VAD initialized for 16kHz audio")

        except Exception as e:
            self.logger.error(f"Failed to initialize Silero VAD: {e}")
            self.vad_model = None
            self.vad_iterator = None

    async def _chunk(self, data: av.AudioFrame) -> Optional[list[av.AudioFrame]]:
        """Process audio frame and return chunk when ready"""
        self.frames.append(data)

        # Check for speech segments every 32 frames (~1 second)
        if len(self.frames) >= 32 and len(self.frames) % 32 == 0:
            return await self._process_block()

        # Safety fallback - emit if we hit max frames
        elif len(self.frames) >= self.max_frames:
            self.logger.warning(
                f"AudioChunkerSileroProcessor: Reached max frames ({self.max_frames}), "
                f"emitting first {self.max_frames // 2} frames"
            )
            frames_to_emit = self.frames[: self.max_frames // 2]
            self.frames = self.frames[self.max_frames // 2 :]
            if len(frames_to_emit) >= self.min_frames:
                return frames_to_emit
            else:
                self.logger.debug(
                    f"Ignoring fallback segment with {len(frames_to_emit)} frames "
                    f"(< {self.min_frames} minimum)"
                )

        return None

    async def _process_block(self) -> Optional[list[av.AudioFrame]]:
        # Need at least 32 frames for VAD detection (~1 second)
        if len(self.frames) < 32 or self.vad_iterator is None:
            return None

        # Processing block with current buffer size
        # print(f"Processing block: {len(self.frames)} frames in buffer")

        try:
            # Convert frames to numpy array for VAD
            audio_array = self._frames_to_numpy(self.frames)

            if audio_array is None:
                # Fallback: emit all frames if conversion failed
                frames_to_emit = self.frames[:]
                self.frames = []
                if len(frames_to_emit) >= self.min_frames:
                    return frames_to_emit
                else:
                    self.logger.debug(
                        f"Ignoring conversion-failed segment with {len(frames_to_emit)} frames "
                        f"(< {self.min_frames} minimum)"
                    )
                return None

            # Find complete speech segments in the buffer
            speech_end_frame = self._find_speech_segment_end(audio_array)

            if speech_end_frame is None or speech_end_frame <= 0:
                # No speech found but buffer is getting large
                if len(self.frames) > 512:
                    # Check if it's all silence and can be discarded
                    # No speech segment found, buffer at {len(self.frames)} frames

                    # Could emit silence or discard old frames here
                    # For now, keep first 256 frames and discard older silence
                    if len(self.frames) > 768:
                        self.logger.debug(
                            f"Discarding {len(self.frames) - 256} old frames (likely silence)"
                        )
                        self.frames = self.frames[-256:]
                return None

            # Calculate segment timing information
            frames_to_emit = self.frames[:speech_end_frame]

            # Get timing from av.AudioFrame
            if frames_to_emit:
                first_frame = frames_to_emit[0]
                last_frame = frames_to_emit[-1]
                sample_rate = first_frame.sample_rate

                # Calculate duration
                total_samples = sum(f.samples for f in frames_to_emit)
                duration_seconds = total_samples / sample_rate if sample_rate > 0 else 0

                # Get timestamps if available
                start_time = (
                    first_frame.pts * first_frame.time_base if first_frame.pts else 0
                )
                end_time = (
                    last_frame.pts * last_frame.time_base if last_frame.pts else 0
                )

                # Convert to HH:MM:SS format for logging
                def format_time(seconds):
                    if not seconds:
                        return "00:00:00"
                    total_seconds = int(float(seconds))
                    hours = total_seconds // 3600
                    minutes = (total_seconds % 3600) // 60
                    secs = total_seconds % 60
                    return f"{hours:02d}:{minutes:02d}:{secs:02d}"

                start_formatted = format_time(start_time)
                end_formatted = format_time(end_time)

                # Keep remaining frames for next processing
                remaining_after = len(self.frames) - speech_end_frame

                # Single structured log line
                self.logger.info(
                    "Speech segment found",
                    start=start_formatted,
                    end=end_formatted,
                    frames=speech_end_frame,
                    duration=round(duration_seconds, 2),
                    buffer_before=len(self.frames),
                    remaining=remaining_after,
                )

            # Keep remaining frames for next processing
            self.frames = self.frames[speech_end_frame:]

            # Filter out segments with too few frames
            if len(frames_to_emit) >= self.min_frames:
                return frames_to_emit
            else:
                self.logger.debug(
                    f"Ignoring segment with {len(frames_to_emit)} frames "
                    f"(< {self.min_frames} minimum)"
                )

        except Exception as e:
            self.logger.error(f"Error in VAD processing: {e}")
            # Fallback to simple chunking
            if len(self.frames) >= self.block_frames:
                frames_to_emit = self.frames[: self.block_frames]
                self.frames = self.frames[self.block_frames :]
                if len(frames_to_emit) >= self.min_frames:
                    return frames_to_emit
                else:
                    self.logger.debug(
                        f"Ignoring exception-fallback segment with {len(frames_to_emit)} frames "
                        f"(< {self.min_frames} minimum)"
                    )

        return None

    def _frames_to_numpy(self, frames: list[av.AudioFrame]) -> Optional[np.ndarray]:
        """Convert av.AudioFrame list to numpy array for VAD processing

        Input frames are already 16kHz mono s16 format from AudioDownscaleProcessor.
        Only need to convert s16 to float32 for Silero VAD.
        """
        if not frames:
            return None

        try:
            # Concatenate all frame arrays
            audio_arrays = [frame.to_ndarray().flatten() for frame in frames]
            if not audio_arrays:
                return None

            combined_audio = np.concatenate(audio_arrays)

            # Convert s16 to float32 (Silero VAD requires float32 in range [-1.0, 1.0])
            # Input is guaranteed to be s16 from AudioDownscaleProcessor
            return combined_audio.astype(np.float32) / 32768.0

        except Exception as e:
            self.logger.error(f"Error converting frames to numpy: {e}")
            return None

    def _find_speech_segment_end(self, audio_array: np.ndarray) -> Optional[int]:
        """Find complete speech segments and return frame index at segment end"""
        if self.vad_iterator is None or len(audio_array) == 0:
            return None

        try:
            # Process audio in 512-sample windows for VAD
            window_size = 512
            min_silence_windows = 3  # Require 3 windows of silence after speech

            # Track speech state
            in_speech = False
            speech_start = None
            speech_end = None
            silence_count = 0

            for i in range(0, len(audio_array), window_size):
                chunk = audio_array[i : i + window_size]
                if len(chunk) < window_size:
                    chunk = np.pad(chunk, (0, window_size - len(chunk)))

                # Detect if this window has speech
                speech_dict = self.vad_iterator(chunk, return_seconds=True)

                # VADIterator returns dict with 'start' and 'end' when speech segments are detected
                if speech_dict:
                    if not in_speech:
                        # Speech started
                        speech_start = i
                        in_speech = True
                        # Debug: print(f"Speech START at sample {i}, VAD: {speech_dict}")
                    silence_count = 0  # Reset silence counter
                    continue

                if not in_speech:
                    continue

                # We're in speech but found silence
                silence_count += 1
                if silence_count < min_silence_windows:
                    continue

                # Found end of speech segment
                speech_end = i - (min_silence_windows - 1) * window_size
                # Debug: print(f"Speech END at sample {speech_end}")

                # Convert sample position to frame index
                samples_per_frame = self.frames[0].samples if self.frames else 1024
                frame_index = speech_end // samples_per_frame

                # Ensure we don't exceed buffer
                frame_index = min(frame_index, len(self.frames))
                return frame_index

            return None

        except Exception as e:
            self.logger.error(f"Error finding speech segment: {e}")
            return None

    async def _flush(self):
        frames = self.frames[:]
        self.frames = []
        if frames:
            if len(frames) >= self.min_frames:
                await self.emit(frames)
            else:
                self.logger.debug(
                    f"Ignoring flush segment with {len(frames)} frames "
                    f"(< {self.min_frames} minimum)"
                )


AudioChunkerAutoProcessor.register("silero", AudioChunkerSileroProcessor)
```
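Note: `_find_speech_segment_end` only closes a segment after `min_silence_windows` consecutive silent VAD windows. A quick check of what that means in time, assuming the 16 kHz input this processor requires:

```python
# Back-of-the-envelope timing for the end-of-speech condition (illustrative only).
SAMPLE_RATE = 16000
WINDOW_SIZE = 512             # samples per VAD window, as in the processor
MIN_SILENCE_WINDOWS = 3

window_ms = WINDOW_SIZE / SAMPLE_RATE * 1000            # 32 ms per window
required_silence_ms = MIN_SILENCE_WINDOWS * window_ms   # 96 ms of silence closes a segment
print(window_ms, required_silence_ms)
```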
60
server/reflector/processors/audio_downscale.py
Normal file
60
server/reflector/processors/audio_downscale.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import av
|
||||||
|
from av.audio.resampler import AudioResampler
|
||||||
|
|
||||||
|
from reflector.processors.base import Processor
|
||||||
|
|
||||||
|
|
||||||
|
def copy_frame(frame: av.AudioFrame) -> av.AudioFrame:
|
||||||
|
frame_copy = frame.from_ndarray(
|
||||||
|
frame.to_ndarray(),
|
||||||
|
format=frame.format.name,
|
||||||
|
layout=frame.layout.name,
|
||||||
|
)
|
||||||
|
frame_copy.sample_rate = frame.sample_rate
|
||||||
|
frame_copy.pts = frame.pts
|
||||||
|
frame_copy.time_base = frame.time_base
|
||||||
|
return frame_copy
|
||||||
|
|
||||||
|
|
||||||
|
class AudioDownscaleProcessor(Processor):
|
||||||
|
"""
|
||||||
|
Downscale audio frames to 16kHz mono format
|
||||||
|
"""
|
||||||
|
|
||||||
|
INPUT_TYPE = av.AudioFrame
|
||||||
|
OUTPUT_TYPE = av.AudioFrame
|
||||||
|
|
||||||
|
def __init__(self, target_rate: int = 16000, target_layout: str = "mono", **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.target_rate = target_rate
|
||||||
|
self.target_layout = target_layout
|
||||||
|
self.resampler: Optional[AudioResampler] = None
|
||||||
|
self.needs_resampling: Optional[bool] = None
|
||||||
|
|
||||||
|
async def _push(self, data: av.AudioFrame):
|
||||||
|
if self.needs_resampling is None:
|
||||||
|
self.needs_resampling = (
|
||||||
|
data.sample_rate != self.target_rate
|
||||||
|
or data.layout.name != self.target_layout
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.needs_resampling:
|
||||||
|
self.resampler = AudioResampler(
|
||||||
|
format="s16", layout=self.target_layout, rate=self.target_rate
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.needs_resampling or not self.resampler:
|
||||||
|
await self.emit(data)
|
||||||
|
return
|
||||||
|
|
||||||
|
resampled_frames = self.resampler.resample(copy_frame(data))
|
||||||
|
for resampled_frame in resampled_frames:
|
||||||
|
await self.emit(resampled_frame)
|
||||||
|
|
||||||
|
async def _flush(self):
|
||||||
|
if self.needs_resampling and self.resampler:
|
||||||
|
final_frames = self.resampler.resample(None)
|
||||||
|
for frame in final_frames:
|
||||||
|
await self.emit(frame)
|
||||||
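The processor above is essentially a thin wrapper around PyAV's AudioResampler. A minimal standalone sketch of the same 16 kHz mono s16 downscale, including the flush call that drains samples buffered inside the resampler (illustrative only; the input filename is hypothetical and nothing here is part of the change):

    # Minimal sketch of the same PyAV resampling approach.
    import av
    from av.audio.resampler import AudioResampler

    resampler = AudioResampler(format="s16", layout="mono", rate=16000)

    with av.open("input.m4a") as container:  # hypothetical input file
        stream = container.streams.audio[0]
        for frame in container.decode(stream):
            for resampled in resampler.resample(frame):
                pass  # each `resampled` frame is now 16 kHz mono s16

    # Flush any samples still buffered inside the resampler
    for resampled in resampler.resample(None):
        pass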
@@ -3,24 +3,11 @@ from time import monotonic_ns
 from uuid import uuid4
 
 import av
-from av.audio.resampler import AudioResampler
 
 from reflector.processors.base import Processor
 from reflector.processors.types import AudioFile
 
 
-def copy_frame(frame: av.AudioFrame) -> av.AudioFrame:
-    frame_copy = frame.from_ndarray(
-        frame.to_ndarray(),
-        format=frame.format.name,
-        layout=frame.layout.name,
-    )
-    frame_copy.sample_rate = frame.sample_rate
-    frame_copy.pts = frame.pts
-    frame_copy.time_base = frame.time_base
-    return frame_copy
-
-
 class AudioMergeProcessor(Processor):
     """
     Merge audio frame into a single file
@@ -29,9 +16,8 @@ class AudioMergeProcessor(Processor):
     INPUT_TYPE = list[av.AudioFrame]
     OUTPUT_TYPE = AudioFile
 
-    def __init__(self, downsample_to_16k_mono: bool = True, **kwargs):
+    def __init__(self, **kwargs):
         super().__init__(**kwargs)
-        self.downsample_to_16k_mono = downsample_to_16k_mono
 
     async def _push(self, data: list[av.AudioFrame]):
         if not data:
@@ -39,72 +25,27 @@ class AudioMergeProcessor(Processor):
 
         # get audio information from first frame
         frame = data[0]
-        original_channels = len(frame.layout.channels)
-        original_sample_rate = frame.sample_rate
-        original_sample_width = frame.format.bytes
-
-        # determine if we need processing
-        needs_processing = self.downsample_to_16k_mono and (
-            original_sample_rate != 16000 or original_channels != 1
-        )
-
-        # determine output parameters
-        if self.downsample_to_16k_mono:
-            output_sample_rate = 16000
-            output_channels = 1
-            output_sample_width = 2  # 16-bit = 2 bytes
-        else:
-            output_sample_rate = original_sample_rate
-            output_channels = original_channels
-            output_sample_width = original_sample_width
+        output_channels = len(frame.layout.channels)
+        output_sample_rate = frame.sample_rate
+        output_sample_width = frame.format.bytes
 
         # create audio file
         uu = uuid4().hex
         fd = io.BytesIO()
 
-        if needs_processing:
-            # Process with PyAV resampler
-            out_container = av.open(fd, "w", format="wav")
-            out_stream = out_container.add_stream("pcm_s16le", rate=16000)
-            out_stream.layout = "mono"
-
-            # Create resampler if needed
-            resampler = None
-            if original_sample_rate != 16000 or original_channels != 1:
-                resampler = AudioResampler(format="s16", layout="mono", rate=16000)
-
-            for frame in data:
-                if resampler:
-                    # Resample and convert to mono
-                    # XXX for an unknown reason, if we don't use a copy of the frame, we get
-                    # Invalid Argumment from resample. Debugging indicate that when a previous processor
-                    # already used the frame (like AudioFileWriter), it make it invalid argument here.
-                    resampled_frames = resampler.resample(copy_frame(frame))
-                    for resampled_frame in resampled_frames:
-                        for packet in out_stream.encode(resampled_frame):
-                            out_container.mux(packet)
-                else:
-                    # Direct encoding without resampling
-                    for packet in out_stream.encode(frame):
-                        out_container.mux(packet)
-
-            # Flush the encoder
-            for packet in out_stream.encode(None):
-                out_container.mux(packet)
-            out_container.close()
-        else:
-            # Use PyAV for original frames (no processing needed)
-            out_container = av.open(fd, "w", format="wav")
-            out_stream = out_container.add_stream("pcm_s16le", rate=output_sample_rate)
-            out_stream.layout = "mono" if output_channels == 1 else frame.layout
-
-            for frame in data:
-                for packet in out_stream.encode(frame):
-                    out_container.mux(packet)
-
-            for packet in out_stream.encode(None):
-                out_container.mux(packet)
-            out_container.close()
+        # Use PyAV to write frames
+        out_container = av.open(fd, "w", format="wav")
+        out_stream = out_container.add_stream("pcm_s16le", rate=output_sample_rate)
+        out_stream.layout = frame.layout.name
+
+        for frame in data:
+            for packet in out_stream.encode(frame):
+                out_container.mux(packet)
+
+        # Flush the encoder
+        for packet in out_stream.encode(None):
+            out_container.mux(packet)
+        out_container.close()
 
         fd.seek(0)
 
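After this change AudioMergeProcessor only concatenates frames that the new AudioDownscaleProcessor has already normalized to 16 kHz mono s16. A rough sketch of that write path in isolation (the helper name is hypothetical, not part of the change):

    # Sketch: encode a list of already-downscaled frames into an in-memory WAV.
    import io
    import av

    def merge_to_wav(frames: list[av.AudioFrame]) -> io.BytesIO:
        fd = io.BytesIO()
        out_container = av.open(fd, "w", format="wav")
        out_stream = out_container.add_stream("pcm_s16le", rate=frames[0].sample_rate)
        out_stream.layout = frames[0].layout.name
        for frame in frames:
            for packet in out_stream.encode(frame):
                out_container.mux(packet)
        for packet in out_stream.encode(None):  # flush the encoder
            out_container.mux(packet)
        out_container.close()
        fd.seek(0)
        return fd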
@@ -12,9 +12,6 @@ API will be a POST request to TRANSCRIPT_URL:
 """
 
-from typing import List
-
-import aiohttp
 from openai import AsyncOpenAI
 
 from reflector.processors.audio_transcript import AudioTranscriptProcessor
@@ -25,7 +22,9 @@ from reflector.settings import settings
 
 class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
     def __init__(
-        self, modal_api_key: str | None = None, batch_enabled: bool = True, **kwargs
+        self,
+        modal_api_key: str | None = None,
+        **kwargs,
     ):
         super().__init__()
         if not settings.TRANSCRIPT_URL:
@@ -35,126 +34,6 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
         self.transcript_url = settings.TRANSCRIPT_URL + "/v1"
         self.timeout = settings.TRANSCRIPT_TIMEOUT
         self.modal_api_key = modal_api_key
-        self.max_batch_duration = 10.0
-        self.max_batch_files = 15
-        self.batch_enabled = batch_enabled
-        self.pending_files: List[AudioFile] = []  # Files waiting to be processed
-
-    @classmethod
-    def _calculate_duration(cls, audio_file: AudioFile) -> float:
-        """Calculate audio duration in seconds from AudioFile metadata"""
-        # Duration = total_samples / sample_rate
-        # We need to estimate total samples from the file data
-        import wave
-
-        try:
-            # Try to read as WAV file to get duration
-            audio_file.fd.seek(0)
-            with wave.open(audio_file.fd, "rb") as wav_file:
-                frames = wav_file.getnframes()
-                sample_rate = wav_file.getframerate()
-                duration = frames / sample_rate
-                return duration
-        except Exception:
-            # Fallback: estimate from file size and audio parameters
-            audio_file.fd.seek(0, 2)  # Seek to end
-            file_size = audio_file.fd.tell()
-            audio_file.fd.seek(0)  # Reset to beginning
-
-            # Estimate: file_size / (sample_rate * channels * sample_width)
-            bytes_per_second = (
-                audio_file.sample_rate
-                * audio_file.channels
-                * (audio_file.sample_width // 8)
-            )
-            estimated_duration = (
-                file_size / bytes_per_second if bytes_per_second > 0 else 0
-            )
-            return max(0, estimated_duration)
-
-    def _create_batches(self, audio_files: List[AudioFile]) -> List[List[AudioFile]]:
-        """Group audio files into batches with maximum 30s total duration"""
-        batches = []
-        current_batch = []
-        current_duration = 0.0
-
-        for audio_file in audio_files:
-            duration = self._calculate_duration(audio_file)
-
-            # If adding this file exceeds max duration, start a new batch
-            if current_duration + duration > self.max_batch_duration and current_batch:
-                batches.append(current_batch)
-                current_batch = [audio_file]
-                current_duration = duration
-            else:
-                current_batch.append(audio_file)
-                current_duration += duration
-
-        # Add the last batch if not empty
-        if current_batch:
-            batches.append(current_batch)
-
-        return batches
-
-    async def _transcript_batch(self, audio_files: List[AudioFile]) -> List[Transcript]:
-        """Transcribe a batch of audio files using the parakeet backend"""
-        if not audio_files:
-            return []
-
-        self.logger.debug(f"Batch transcribing {len(audio_files)} files")
-
-        # Prepare form data for batch request
-        data = aiohttp.FormData()
-        data.add_field("language", self.get_pref("audio:source_language", "en"))
-        data.add_field("batch", "true")
-
-        for i, audio_file in enumerate(audio_files):
-            audio_file.fd.seek(0)
-            data.add_field(
-                "files",
-                audio_file.fd,
-                filename=f"{audio_file.name}",
-                content_type="audio/wav",
-            )
-
-        # Make batch request
-        headers = {"Authorization": f"Bearer {self.modal_api_key}"}
-
-        async with aiohttp.ClientSession(
-            timeout=aiohttp.ClientTimeout(total=self.timeout)
-        ) as session:
-            async with session.post(
-                f"{self.transcript_url}/audio/transcriptions",
-                data=data,
-                headers=headers,
-            ) as response:
-                if response.status != 200:
-                    error_text = await response.text()
-                    raise Exception(
-                        f"Batch transcription failed: {response.status} {error_text}"
-                    )
-
-                result = await response.json()
-
-        # Process batch results
-        transcripts = []
-        results = result.get("results", [])
-
-        for i, (audio_file, file_result) in enumerate(zip(audio_files, results)):
-            transcript = Transcript(
-                words=[
-                    Word(
-                        text=word_info["word"],
-                        start=word_info["start"],
-                        end=word_info["end"],
-                    )
-                    for word_info in file_result.get("words", [])
-                ]
-            )
-            transcript.add_offset(audio_file.timestamp)
-            transcripts.append(transcript)
-
-        return transcripts
 
     async def _transcript(self, data: AudioFile):
         async with AsyncOpenAI(
@@ -187,96 +66,5 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
 
         return transcript
 
-    async def transcript_multiple(
-        self, audio_files: List[AudioFile]
-    ) -> List[Transcript]:
-        """Transcribe multiple audio files using batching"""
-        if len(audio_files) == 1:
-            # Single file, use existing method
-            return [await self._transcript(audio_files[0])]
-
-        # Create batches with max 30s duration each
-        batches = self._create_batches(audio_files)
-
-        self.logger.debug(
-            f"Processing {len(audio_files)} files in {len(batches)} batches"
-        )
-
-        # Process all batches concurrently
-        all_transcripts = []
-
-        for batch in batches:
-            batch_transcripts = await self._transcript_batch(batch)
-            all_transcripts.extend(batch_transcripts)
-
-        return all_transcripts
-
-    async def _push(self, data: AudioFile):
-        """Override _push to support batching"""
-        if not self.batch_enabled:
-            # Use parent implementation for single file processing
-            return await super()._push(data)
-
-        # Add file to pending batch
-        self.pending_files.append(data)
-        self.logger.debug(
-            f"Added file to batch: {data.name}, batch size: {len(self.pending_files)}"
-        )
-
-        # Calculate total duration of pending files
-        total_duration = sum(self._calculate_duration(f) for f in self.pending_files)
-
-        # Process batch if it reaches max duration or has multiple files ready for optimization
-        should_process_batch = (
-            total_duration >= self.max_batch_duration
-            or len(self.pending_files) >= self.max_batch_files
-        )
-
-        if should_process_batch:
-            await self._process_pending_batch()
-
-    async def _process_pending_batch(self):
-        """Process all pending files as batches"""
-        if not self.pending_files:
-            return
-
-        self.logger.debug(f"Processing batch of {len(self.pending_files)} files")
-
-        try:
-            # Create batches respecting duration limit
-            batches = self._create_batches(self.pending_files)
-
-            # Process each batch
-            for batch in batches:
-                self.m_transcript_call.inc()
-                try:
-                    with self.m_transcript.time():
-                        # Use batch transcription
-                        transcripts = await self._transcript_batch(batch)
-
-                    self.m_transcript_success.inc()
-
-                    # Emit each transcript
-                    for transcript in transcripts:
-                        if transcript:
-                            await self.emit(transcript)
-
-                except Exception:
-                    self.m_transcript_failure.inc()
-                    raise
-                finally:
-                    # Release audio files
-                    for audio_file in batch:
-                        audio_file.release()
-
-        finally:
-            # Clear pending files
-            self.pending_files.clear()
-
-    async def _flush(self):
-        """Process any remaining files when flushing"""
-        await self._process_pending_batch()
-        await super()._flush()
-
-
 AudioTranscriptAutoProcessor.register("modal", AudioTranscriptModalProcessor)
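With the batching path removed, each merged WAV now goes through the OpenAI-compatible client one file at a time. A rough sketch of such a call; the model name, response format, and field values below are assumptions for illustration, not taken from this change:

    # Rough sketch of the single-file transcription call (assumed parameters).
    import io
    from openai import AsyncOpenAI

    async def transcribe(fd: io.BytesIO, transcript_url: str, api_key: str, language: str = "en"):
        async with AsyncOpenAI(base_url=transcript_url, api_key=api_key) as client:
            return await client.audio.transcriptions.create(
                model="parakeet",                # hypothetical model identifier
                file=("audio.wav", fd, "audio/wav"),
                language=language,
                response_format="verbose_json",  # assumed, to get word timestamps back
            )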
@@ -21,6 +21,10 @@ class Settings(BaseSettings):
     # local data directory
     DATA_DIR: str = "./data"
 
+    # Audio Chunking
+    # backends: silero, frames
+    AUDIO_CHUNKER_BACKEND: str = "frames"
+
     # Audio Transcription
     # backends: whisper, modal
     TRANSCRIPT_BACKEND: str = "whisper"
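Assuming the Settings fields map one-to-one to environment variables (not verified here), the new chunker backend could be selected alongside the existing transcription backend like this:

    AUDIO_CHUNKER_BACKEND=silero
    TRANSCRIPT_BACKEND=modal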
@@ -16,7 +16,8 @@ import av
 
 from reflector.logger import logger
 from reflector.processors import (
-    AudioChunkerProcessor,
+    AudioChunkerAutoProcessor,
+    AudioDownscaleProcessor,
     AudioFileWriterProcessor,
     AudioMergeProcessor,
     AudioTranscriptAutoProcessor,
@@ -95,7 +96,8 @@ async def process_audio_file(
 
     # Add the rest of the processors
     processors += [
-        AudioChunkerProcessor(),
+        AudioDownscaleProcessor(),
+        AudioChunkerAutoProcessor(),
         AudioMergeProcessor(),
         AudioTranscriptAutoProcessor.as_threaded(),
         TranscriptLinerProcessor(),
@@ -322,7 +324,8 @@ if __name__ == "__main__":
 
     # Ignore internal processors
     if processor in (
-        "AudioChunkerProcessor",
+        "AudioDownscaleProcessor",
+        "AudioChunkerAutoProcessor",
         "AudioMergeProcessor",
         "AudioFileWriterProcessor",
         "TopicCollectorProcessor",
@@ -17,7 +17,8 @@ import av
 
 from reflector.logger import logger
 from reflector.processors import (
-    AudioChunkerProcessor,
+    AudioChunkerAutoProcessor,
+    AudioDownscaleProcessor,
     AudioFileWriterProcessor,
     AudioMergeProcessor,
     AudioTranscriptAutoProcessor,
@@ -96,7 +97,8 @@ async def process_audio_file_with_diarization(
 
     # Add the rest of the processors
     processors += [
-        AudioChunkerProcessor(),
+        AudioDownscaleProcessor(),
+        AudioChunkerAutoProcessor(),
         AudioMergeProcessor(),
         AudioTranscriptAutoProcessor.as_threaded(),
     ]
@@ -276,7 +278,8 @@ if __name__ == "__main__":
 
     # Ignore internal processors
     if processor in (
-        "AudioChunkerProcessor",
+        "AudioDownscaleProcessor",
+        "AudioChunkerAutoProcessor",
         "AudioMergeProcessor",
         "AudioFileWriterProcessor",
         "TopicCollectorProcessor",
@@ -53,7 +53,7 @@ async def run_single_processor(args):
     async def event_callback(event: PipelineEvent):
         processor = event.processor
         # ignore some processor
-        if processor in ("AudioChunkerProcessor", "AudioMergeProcessor"):
+        if processor in ("AudioChunkerAutoProcessor", "AudioMergeProcessor"):
             return
         print(f"Event: {event}")
         if output_fd: