mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-04-24 22:25:19 +00:00
feat: 3-mode selfhosted refactoring (--gpu, --cpu, --hosted) + audio token auth fallback (#896)
* fix: local processing instead of http server for cpu * add fallback token if service worker doesn't work * chore: rename processors to keep processor pattern up to date and allow other processors to be created and used with env vars
This commit is contained in:
committed by
GitHub
parent
4235ab4293
commit
a682846645
@@ -14,6 +14,7 @@ current_user = auth_module.current_user
|
||||
current_user_optional = auth_module.current_user_optional
|
||||
parse_ws_bearer_token = auth_module.parse_ws_bearer_token
|
||||
current_user_ws_optional = auth_module.current_user_ws_optional
|
||||
verify_raw_token = auth_module.verify_raw_token
|
||||
|
||||
# Optional router (e.g. for /auth/login in password backend)
|
||||
router = getattr(auth_module, "router", None)
|
||||
|
||||
@@ -144,3 +144,8 @@ async def current_user_ws_optional(websocket: "WebSocket") -> Optional[UserInfo]
|
||||
if not token:
|
||||
return None
|
||||
return await _authenticate_user(token, None, JWTAuth())
|
||||
|
||||
|
||||
async def verify_raw_token(token: str) -> Optional[UserInfo]:
    """Verify a raw JWT token string (used for query-param auth fallback).

    Delegates to ``_authenticate_user`` with the token, ``None`` as the
    second argument and a fresh ``JWTAuth`` instance, and returns whatever
    that call returns.
    """
    return await _authenticate_user(token, None, JWTAuth())
|
||||
|
||||
@@ -27,3 +27,8 @@ def parse_ws_bearer_token(websocket):
|
||||
|
||||
async def current_user_ws_optional(websocket):
    """Return the WebSocket user for this backend: always ``None``.

    The ``websocket`` argument is ignored.
    """
    return None
|
||||
|
||||
|
||||
async def verify_raw_token(token):
    """Verify a raw JWT token string (used for query-param auth fallback).

    This implementation performs no verification: ``token`` is ignored and
    ``None`` is always returned.
    """
    return None
|
||||
|
||||
@@ -168,6 +168,11 @@ async def current_user_ws_optional(websocket: "WebSocket") -> Optional[UserInfo]
|
||||
return await _authenticate_user(token, None)
|
||||
|
||||
|
||||
async def verify_raw_token(token: str) -> Optional[UserInfo]:
    """Verify a raw JWT token string (used for query-param auth fallback).

    Delegates to ``_authenticate_user`` with the token and ``None`` as the
    second argument, and returns whatever that call returns.
    """
    return await _authenticate_user(token, None)
|
||||
|
||||
|
||||
# --- Login router ---
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ from .audio_diarization_auto import AudioDiarizationAutoProcessor # noqa: F401
|
||||
from .audio_downscale import AudioDownscaleProcessor # noqa: F401
|
||||
from .audio_file_writer import AudioFileWriterProcessor # noqa: F401
|
||||
from .audio_merge import AudioMergeProcessor # noqa: F401
|
||||
from .audio_padding import AudioPaddingProcessor # noqa: F401
|
||||
from .audio_padding_auto import AudioPaddingAutoProcessor # noqa: F401
|
||||
from .audio_transcript import AudioTranscriptProcessor # noqa: F401
|
||||
from .audio_transcript_auto import AudioTranscriptAutoProcessor # noqa: F401
|
||||
from .base import ( # noqa: F401
|
||||
|
||||
86
server/reflector/processors/_audio_download.py
Normal file
86
server/reflector/processors/_audio_download.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""
|
||||
Shared audio download utility for local processors.
|
||||
|
||||
Downloads audio from a URL to a temporary file for in-process ML inference.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
|
||||
from reflector.logger import logger
|
||||
|
||||
S3_TIMEOUT = 60
|
||||
|
||||
|
||||
async def download_audio_to_temp(url: str) -> Path:
    """Download audio from URL to a temporary file.

    The blocking HTTP download is run in the event loop's default executor
    so the coroutine does not block the loop while bytes are transferred.

    The caller is responsible for deleting the temp file after use.

    Args:
        url: Presigned URL or public URL to download audio from.

    Returns:
        Path to the downloaded temporary file.
    """
    # Inside a coroutine a loop is always running; get_running_loop() is the
    # preferred way to obtain it (get_event_loop() has legacy fallback
    # behaviour that is not wanted here).
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _download_blocking, url)
|
||||
|
||||
|
||||
def _download_blocking(url: str) -> Path:
    """Blocking download implementation.

    Streams the response body in 8 KiB chunks into a file created with
    ``tempfile.mkstemp``. The file suffix is chosen by ``_detect_extension``
    from the URL and the response ``content-type`` header.

    Args:
        url: URL to fetch.

    Returns:
        Path to the temporary file holding the downloaded bytes.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        Exception: Any error raised while writing; the partially written
            temp file is removed before the error is re-raised.
    """
    # Only the first 80 characters of the URL are bound to the log context.
    log = logger.bind(url=url[:80])
    log.info("Downloading audio to temp file")

    # With stream=True the connection stays open until the body is consumed
    # or the response is closed; the context manager closes it on every path.
    with requests.get(url, stream=True, timeout=S3_TIMEOUT) as response:
        response.raise_for_status()

        # Determine extension from content-type or URL
        ext = _detect_extension(url, response.headers.get("content-type", ""))

        fd, tmp_path = tempfile.mkstemp(suffix=ext)
        try:
            total_bytes = 0
            with os.fdopen(fd, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
                        total_bytes += len(chunk)
            log.info("Audio downloaded", bytes=total_bytes, path=tmp_path)
            return Path(tmp_path)
        except Exception:
            # Clean up on failure (best effort), then propagate the error
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
            raise
|
||||
|
||||
|
||||
def _detect_extension(url: str, content_type: str) -> str:
|
||||
"""Detect audio file extension from URL or content-type."""
|
||||
# Try URL path first
|
||||
path = url.split("?")[0] # Strip query params
|
||||
for ext in (".wav", ".mp3", ".mp4", ".m4a", ".webm", ".ogg", ".flac"):
|
||||
if path.lower().endswith(ext):
|
||||
return ext
|
||||
|
||||
# Try content-type
|
||||
ct_map = {
|
||||
"audio/wav": ".wav",
|
||||
"audio/x-wav": ".wav",
|
||||
"audio/mpeg": ".mp3",
|
||||
"audio/mp4": ".m4a",
|
||||
"audio/webm": ".webm",
|
||||
"audio/ogg": ".ogg",
|
||||
"audio/flac": ".flac",
|
||||
}
|
||||
for ct, ext in ct_map.items():
|
||||
if ct in content_type.lower():
|
||||
return ext
|
||||
|
||||
return ".audio"
|
||||
76
server/reflector/processors/_marian_translator_service.py
Normal file
76
server/reflector/processors/_marian_translator_service.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""
|
||||
MarianMT translation service.
|
||||
|
||||
Singleton service that loads HuggingFace MarianMT translation models
|
||||
and reuses them across all MarianMT translator processor instances.
|
||||
|
||||
Ported from gpu/self_hosted/app/services/translator.py for in-process use.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
|
||||
from transformers import MarianMTModel, MarianTokenizer, pipeline
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MarianTranslatorService:
    """MarianMT text translation service for in-process use.

    Holds a single HuggingFace translation pipeline at a time, together with
    the (source, target) language pair it was loaded for. A lock serialises
    model (re)loading and inference so concurrent callers never observe a
    pipeline that belongs to a different language pair.
    """

    def __init__(self):
        self._pipeline = None  # loaded lazily by translate()/load()
        self._current_pair = None  # (src, tgt) lower-cased, or None
        self._lock = threading.Lock()

    def load(self, source_language: str = "en", target_language: str = "fr"):
        """Load the translation model for a specific language pair.

        Replaces any previously loaded pipeline. This method does not take
        the lock itself; ``translate`` calls it while holding the lock.
        """
        model_name = self._resolve_model_name(source_language, target_language)
        logger.info(
            "Loading MarianMT model: %s (%s -> %s)",
            model_name,
            source_language,
            target_language,
        )
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
        self._pipeline = pipeline("translation", model=model, tokenizer=tokenizer)
        self._current_pair = (source_language.lower(), target_language.lower())

    def _resolve_model_name(self, src: str, tgt: str) -> str:
        """Resolve language pair to MarianMT model name.

        Unmapped pairs fall back to the en -> fr model; a warning is logged
        so the mismatch is visible instead of silent.
        """
        pair = (src.lower(), tgt.lower())
        mapping = {
            ("en", "fr"): "Helsinki-NLP/opus-mt-en-fr",
            ("fr", "en"): "Helsinki-NLP/opus-mt-fr-en",
            ("en", "es"): "Helsinki-NLP/opus-mt-en-es",
            ("es", "en"): "Helsinki-NLP/opus-mt-es-en",
            ("en", "de"): "Helsinki-NLP/opus-mt-en-de",
            ("de", "en"): "Helsinki-NLP/opus-mt-de-en",
        }
        model_name = mapping.get(pair)
        if model_name is None:
            # NOTE(review): this also covers same-language pairs such as
            # ("en", "en"), which will be run through the en -> fr model.
            logger.warning(
                "No MarianMT model mapped for %s -> %s; falling back to en -> fr",
                pair[0],
                pair[1],
            )
            return "Helsinki-NLP/opus-mt-en-fr"
        return model_name

    def translate(self, text: str, source_language: str, target_language: str) -> dict:
        """Translate text between languages.

        Args:
            text: Text to translate.
            source_language: Source language code (e.g. "en").
            target_language: Target language code (e.g. "fr").

        Returns:
            dict with "text" key containing {source_language: original, target_language: translated}.
        """
        pair = (source_language.lower(), target_language.lower())
        # The pair check, the (re)load and the inference all happen under the
        # same lock: otherwise another thread could swap the pipeline to a
        # different language pair between our load and our use of it.
        with self._lock:
            if self._pipeline is None or self._current_pair != pair:
                self.load(source_language, target_language)
            results = self._pipeline(
                text, src_lang=source_language, tgt_lang=target_language
            )
        translated = results[0]["translation_text"] if results else ""
        return {"text": {source_language: text, target_language: translated}}
|
||||
|
||||
|
||||
# Module-level singleton — shared across all MarianMT translator processors
|
||||
translator_service = MarianTranslatorService()
|
||||
133
server/reflector/processors/_pyannote_diarization_service.py
Normal file
133
server/reflector/processors/_pyannote_diarization_service.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
Pyannote diarization service using pyannote.audio.
|
||||
|
||||
Singleton service that loads the pyannote speaker diarization model once
|
||||
and reuses it across all pyannote diarization processor instances.
|
||||
|
||||
Ported from gpu/self_hosted/app/services/diarizer.py for in-process use.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import tarfile
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from urllib.request import urlopen
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
import yaml
|
||||
from pyannote.audio import Pipeline
|
||||
|
||||
from reflector.settings import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
S3_BUNDLE_URL = "https://reflector-public.s3.us-east-1.amazonaws.com/pyannote-speaker-diarization-3.1.tar.gz"
|
||||
BUNDLE_CACHE_DIR = Path.home() / ".cache" / "pyannote-bundle"
|
||||
|
||||
|
||||
def _ensure_model(cache_dir: Path) -> str:
    """Download and extract S3 model bundle if not cached.

    The bundle is considered cached when ``config.yaml`` exists inside the
    ``pyannote-speaker-diarization-3.1`` directory under ``cache_dir``.

    Args:
        cache_dir: Directory the bundle is extracted into.

    Returns:
        Path (as a string) of the extracted model directory.
    """
    model_dir = cache_dir / "pyannote-speaker-diarization-3.1"
    config_path = model_dir / "config.yaml"

    # NOTE(review): config.yaml appears on disk at extraction time, before
    # _patch_config rewrites it; an interruption between those two steps
    # would leave an unpatched config that this check treats as cached.
    if config_path.exists():
        logger.info("Using cached model bundle at %s", model_dir)
        return str(model_dir)

    cache_dir.mkdir(parents=True, exist_ok=True)
    tarball_path = cache_dir / "model.tar.gz"

    try:
        logger.info("Downloading model bundle from %s", S3_BUNDLE_URL)
        # The timeout bounds each blocking socket operation so a stalled
        # connection raises instead of hanging the worker forever.
        with urlopen(S3_BUNDLE_URL, timeout=60) as response, open(
            tarball_path, "wb"
        ) as f:
            while chunk := response.read(8192):
                f.write(chunk)

        logger.info("Extracting model bundle")
        with tarfile.open(tarball_path, "r:gz") as tar:
            # filter="data" rejects unsafe members (absolute paths, links
            # escaping the destination, device files).
            tar.extractall(path=cache_dir, filter="data")
    finally:
        # Remove the tarball on success and on failure, so a truncated
        # download never lingers in the cache directory.
        tarball_path.unlink(missing_ok=True)

    _patch_config(model_dir, cache_dir)
    return str(model_dir)
|
||||
|
||||
|
||||
def _patch_config(model_dir: Path, cache_dir: Path) -> None:
    """Point the bundle's config.yaml at the locally extracted weight files.

    Loads ``config.yaml`` from ``model_dir``, overwrites the ``segmentation``
    and ``embedding`` pipeline parameters with paths to the
    ``pytorch_model.bin`` files under ``cache_dir``, and writes the file back.
    """
    config_file = model_dir / "config.yaml"
    with open(config_file) as src:
        cfg = yaml.safe_load(src)

    # Pipeline parameter -> directory (under cache_dir) holding its weights.
    weight_dirs = {
        "segmentation": "pyannote-segmentation-3.0",
        "embedding": "pyannote-wespeaker-voxceleb-resnet34-LM",
    }
    pipeline_params = cfg["pipeline"]["params"]
    for param_name, dir_name in weight_dirs.items():
        pipeline_params[param_name] = str(cache_dir / dir_name / "pytorch_model.bin")

    with open(config_file, "w") as dst:
        yaml.dump(cfg, dst)

    logger.info("Patched config.yaml with local model paths")
|
||||
|
||||
|
||||
class PyannoteDiarizationService:
    """Pyannote speaker diarization service for in-process use.

    The pipeline is loaded lazily on the first ``diarize_file`` call. One
    lock guards both the lazy load and inference.
    """

    def __init__(self):
        self._pipeline = None  # set once by load()
        self._device = "cpu"
        self._lock = threading.Lock()

    def load(self):
        """Load the diarization pipeline and move it to the selected device.

        Uses CUDA when available, otherwise CPU. With ``settings.HF_TOKEN``
        set the model is fetched from HuggingFace; otherwise it is loaded
        from the S3 bundle via ``_ensure_model``.
        """
        self._device = "cuda" if torch.cuda.is_available() else "cpu"
        hf_token = settings.HF_TOKEN

        if hf_token:
            logger.info("Loading pyannote model from HuggingFace (HF_TOKEN set)")
            loaded = Pipeline.from_pretrained(
                "pyannote/speaker-diarization-3.1",
                use_auth_token=hf_token,
            )
        else:
            logger.info("HF_TOKEN not set — loading model from S3 bundle")
            model_path = _ensure_model(BUNDLE_CACHE_DIR)
            config_path = Path(model_path) / "config.yaml"
            loaded = Pipeline.from_pretrained(str(config_path))

        loaded.to(torch.device(self._device))
        # Publish only after the device move, so a concurrent reader never
        # sees a pipeline that is still being set up.
        self._pipeline = loaded

    def diarize_file(self, file_path: str, timestamp: float = 0.0) -> dict:
        """Run speaker diarization on an audio file.

        Args:
            file_path: Path to the audio file.
            timestamp: Offset to add to all segment timestamps.

        Returns:
            dict with "diarization" key containing list of
            {"start": float, "end": float, "speaker": int} segments.
        """
        if self._pipeline is None:
            # Re-check under the lock so concurrent first calls load the
            # model only once.
            with self._lock:
                if self._pipeline is None:
                    self.load()
        waveform, sample_rate = torchaudio.load(file_path)
        with self._lock:
            diarization = self._pipeline(
                {"waveform": waveform, "sample_rate": sample_rate}
            )
        segments = []
        for diarization_segment, _, speaker in diarization.itertracks(yield_label=True):
            segments.append(
                {
                    "start": round(timestamp + diarization_segment.start, 3),
                    "end": round(timestamp + diarization_segment.end, 3),
                    # The speaker id is taken from the last two characters of
                    # the label; anything non-numeric maps to speaker 0.
                    "speaker": int(speaker[-2:])
                    if speaker and speaker[-2:].isdigit()
                    else 0,
                }
            )
        return {"diarization": segments}
|
||||
|
||||
|
||||
# Module-level singleton — shared across all pyannote diarization processors
|
||||
diarization_service = PyannoteDiarizationService()
|
||||
37
server/reflector/processors/audio_diarization_pyannote.py
Normal file
37
server/reflector/processors/audio_diarization_pyannote.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""
|
||||
Pyannote audio diarization processor using pyannote.audio in-process.
|
||||
|
||||
Downloads audio from URL, runs pyannote diarization locally,
|
||||
and returns speaker segments. No HTTP backend needed.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from reflector.processors._audio_download import download_audio_to_temp
|
||||
from reflector.processors._pyannote_diarization_service import diarization_service
|
||||
from reflector.processors.audio_diarization import AudioDiarizationProcessor
|
||||
from reflector.processors.audio_diarization_auto import AudioDiarizationAutoProcessor
|
||||
from reflector.processors.types import AudioDiarizationInput
|
||||
|
||||
|
||||
class AudioDiarizationPyannoteProcessor(AudioDiarizationProcessor):
    """Diarization processor that runs pyannote in-process on downloaded audio."""

    INPUT_TYPE = AudioDiarizationInput

    async def _diarize(self, data: AudioDiarizationInput):
        """Download the audio at ``data.audio_url`` and diarize it.

        Inference runs in the default executor; the temporary audio file is
        removed afterwards whether or not diarization succeeded.

        Returns:
            The list stored under the service result's "diarization" key.
        """
        audio_file = await download_audio_to_temp(data.audio_url)
        try:
            outcome = await asyncio.get_event_loop().run_in_executor(
                None, diarization_service.diarize_file, str(audio_file)
            )
            return outcome["diarization"]
        finally:
            # Best-effort cleanup: a missing file is not an error here.
            try:
                os.unlink(audio_file)
            except OSError:
                pass
|
||||
|
||||
|
||||
AudioDiarizationAutoProcessor.register("pyannote", AudioDiarizationPyannoteProcessor)
|
||||
23
server/reflector/processors/audio_padding.py
Normal file
23
server/reflector/processors/audio_padding.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""
|
||||
Base class for audio padding processors.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class PaddingResponse(BaseModel):
    """Result model returned by audio padding processors."""

    # NOTE(review): presumably the size of the padded output — unit not
    # visible here, confirm against the implementations.
    size: int
    # Defaults to False; set by implementations when padding was cancelled.
    cancelled: bool = False
|
||||
|
||||
|
||||
class AudioPaddingProcessor:
    """Base class for audio padding processors."""

    async def pad_track(
        self,
        track_url: str,
        output_url: str,
        start_time_seconds: float,
        track_index: int,
    ) -> PaddingResponse:
        """Pad one audio track; subclasses must override.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError
|
||||
@@ -1,9 +1,10 @@
|
||||
import importlib
|
||||
|
||||
from reflector.processors.audio_padding import AudioPaddingProcessor
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class AudioPaddingAutoProcessor:
|
||||
class AudioPaddingAutoProcessor(AudioPaddingProcessor):
|
||||
_registry = {}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -6,19 +6,14 @@ import asyncio
|
||||
import os
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
from reflector.hatchet.constants import TIMEOUT_AUDIO
|
||||
from reflector.logger import logger
|
||||
from reflector.processors.audio_padding import AudioPaddingProcessor, PaddingResponse
|
||||
from reflector.processors.audio_padding_auto import AudioPaddingAutoProcessor
|
||||
|
||||
|
||||
class PaddingResponse(BaseModel):
|
||||
size: int
|
||||
cancelled: bool = False
|
||||
|
||||
|
||||
class AudioPaddingModalProcessor:
|
||||
class AudioPaddingModalProcessor(AudioPaddingProcessor):
|
||||
"""Audio padding processor using Modal.com CPU backend via HTTP."""
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
Local audio padding processor using PyAV.
|
||||
PyAV audio padding processor.
|
||||
|
||||
Pads audio tracks with silence directly in-process (no HTTP).
|
||||
Reuses the shared PyAV utilities from reflector.utils.audio_padding.
|
||||
@@ -12,15 +12,15 @@ import tempfile
|
||||
import av
|
||||
|
||||
from reflector.logger import logger
|
||||
from reflector.processors.audio_padding import AudioPaddingProcessor, PaddingResponse
|
||||
from reflector.processors.audio_padding_auto import AudioPaddingAutoProcessor
|
||||
from reflector.processors.audio_padding_modal import PaddingResponse
|
||||
from reflector.utils.audio_padding import apply_audio_padding_to_file
|
||||
|
||||
S3_TIMEOUT = 60
|
||||
|
||||
|
||||
class AudioPaddingLocalProcessor:
|
||||
"""Audio padding processor using local PyAV (no HTTP backend)."""
|
||||
class AudioPaddingPyavProcessor(AudioPaddingProcessor):
|
||||
"""Audio padding processor using PyAV (no HTTP backend)."""
|
||||
|
||||
async def pad_track(
|
||||
self,
|
||||
@@ -29,7 +29,7 @@ class AudioPaddingLocalProcessor:
|
||||
start_time_seconds: float,
|
||||
track_index: int,
|
||||
) -> PaddingResponse:
|
||||
"""Pad audio track with silence locally via PyAV.
|
||||
"""Pad audio track with silence via PyAV.
|
||||
|
||||
Args:
|
||||
track_url: Presigned GET URL for source audio track
|
||||
@@ -130,4 +130,4 @@ class AudioPaddingLocalProcessor:
|
||||
log.warning("Failed to cleanup temp directory", error=str(e))
|
||||
|
||||
|
||||
AudioPaddingAutoProcessor.register("local", AudioPaddingLocalProcessor)
|
||||
AudioPaddingAutoProcessor.register("pyav", AudioPaddingPyavProcessor)
|
||||
@@ -3,13 +3,17 @@ from faster_whisper import WhisperModel
|
||||
from reflector.processors.audio_transcript import AudioTranscriptProcessor
|
||||
from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor
|
||||
from reflector.processors.types import AudioFile, Transcript, Word
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class AudioTranscriptWhisperProcessor(AudioTranscriptProcessor):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = WhisperModel(
|
||||
"tiny", device="cpu", compute_type="float32", num_workers=12
|
||||
settings.WHISPER_CHUNK_MODEL,
|
||||
device="cpu",
|
||||
compute_type="float32",
|
||||
num_workers=12,
|
||||
)
|
||||
|
||||
async def _transcript(self, data: AudioFile):
|
||||
|
||||
39
server/reflector/processors/file_diarization_pyannote.py
Normal file
39
server/reflector/processors/file_diarization_pyannote.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""
|
||||
Pyannote file diarization processor using pyannote.audio in-process.
|
||||
|
||||
Downloads audio from URL, runs pyannote diarization locally,
|
||||
and returns speaker segments. No HTTP backend needed.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from reflector.processors._audio_download import download_audio_to_temp
|
||||
from reflector.processors._pyannote_diarization_service import diarization_service
|
||||
from reflector.processors.file_diarization import (
|
||||
FileDiarizationInput,
|
||||
FileDiarizationOutput,
|
||||
FileDiarizationProcessor,
|
||||
)
|
||||
from reflector.processors.file_diarization_auto import FileDiarizationAutoProcessor
|
||||
|
||||
|
||||
class FileDiarizationPyannoteProcessor(FileDiarizationProcessor):
    """File diarization processor that runs pyannote in-process."""

    async def _diarize(self, data: FileDiarizationInput):
        """Run pyannote diarization on file from URL.

        Downloads the audio to a temporary file, runs the shared diarization
        service in the default executor, and always removes the temporary
        file afterwards.

        Returns:
            FileDiarizationOutput built from the service's "diarization" list.
        """
        # The URL may be presigned, in which case the query string carries
        # credentials; log only the part before "?" to keep them out of logs.
        safe_url = data.audio_url.split("?", 1)[0]
        self.logger.info(f"Starting pyannote diarization from {safe_url}")
        tmp_path = await download_audio_to_temp(data.audio_url)
        try:
            loop = asyncio.get_event_loop()
            result = await loop.run_in_executor(
                None, diarization_service.diarize_file, str(tmp_path)
            )
            return FileDiarizationOutput(diarization=result["diarization"])
        finally:
            # Best-effort cleanup of the downloaded file
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
|
||||
|
||||
|
||||
FileDiarizationAutoProcessor.register("pyannote", FileDiarizationPyannoteProcessor)
|
||||
275
server/reflector/processors/file_transcript_whisper.py
Normal file
275
server/reflector/processors/file_transcript_whisper.py
Normal file
@@ -0,0 +1,275 @@
|
||||
"""
|
||||
Local file transcription processor using faster-whisper with Silero VAD pipeline.
|
||||
|
||||
Downloads audio from URL, segments it using Silero VAD, transcribes each
|
||||
segment with faster-whisper, and merges results. No HTTP backend needed.
|
||||
|
||||
VAD pipeline ported from gpu/self_hosted/app/services/transcriber.py.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import threading
|
||||
from typing import Generator
|
||||
|
||||
import numpy as np
|
||||
from silero_vad import VADIterator, load_silero_vad
|
||||
|
||||
from reflector.processors._audio_download import download_audio_to_temp
|
||||
from reflector.processors.file_transcript import (
|
||||
FileTranscriptInput,
|
||||
FileTranscriptProcessor,
|
||||
)
|
||||
from reflector.processors.file_transcript_auto import FileTranscriptAutoProcessor
|
||||
from reflector.processors.types import Transcript, Word
|
||||
from reflector.settings import settings
|
||||
|
||||
SAMPLE_RATE = 16000
|
||||
|
||||
VAD_CONFIG = {
|
||||
"batch_max_duration": 30.0,
|
||||
"silence_padding": 0.5,
|
||||
"window_size": 512,
|
||||
}
|
||||
|
||||
|
||||
class FileTranscriptWhisperProcessor(FileTranscriptProcessor):
    """Transcribe complete audio files using local faster-whisper with VAD."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._model = None  # loaded lazily by _ensure_model()
        self._lock = threading.Lock()  # guards model loading and inference

    def _ensure_model(self):
        """Lazy-load the whisper model on first use.

        The load happens under the lock with a second check inside it, so
        concurrent first calls construct the model only once.
        """
        if self._model is not None:
            return

        with self._lock:
            if self._model is not None:
                return

            import faster_whisper
            import torch

            device = "cuda" if torch.cuda.is_available() else "cpu"
            compute_type = "float16" if device == "cuda" else "int8"
            model_name = settings.WHISPER_FILE_MODEL

            self.logger.info(
                "Loading whisper model",
                model=model_name,
                device=device,
                compute_type=compute_type,
            )
            self._model = faster_whisper.WhisperModel(
                model_name,
                device=device,
                compute_type=compute_type,
                num_workers=1,
            )

    async def _transcript(self, data: FileTranscriptInput):
        """Download file, run VAD segmentation, transcribe each segment.

        The blocking work runs in the default executor; the downloaded temp
        file is removed afterwards regardless of the outcome.
        """
        tmp_path = await download_audio_to_temp(data.audio_url)
        try:
            loop = asyncio.get_event_loop()
            result = await loop.run_in_executor(
                None,
                self._transcribe_file_blocking,
                str(tmp_path),
                data.language,
            )
            return result
        finally:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass

    def _transcribe_file_blocking(self, file_path: str, language: str) -> Transcript:
        """Blocking transcription with VAD pipeline.

        Speech segments from the VAD are merged into batches no longer than
        ``VAD_CONFIG["batch_max_duration"]`` seconds; each batch is
        transcribed separately and word times are shifted back to file time.
        """
        self._ensure_model()

        audio_array = _load_audio_via_ffmpeg(file_path, SAMPLE_RATE)

        # VAD segmentation → batch merging
        merged_batches: list[tuple[float, float]] = []
        batch_start = None
        batch_end = None
        max_duration = VAD_CONFIG["batch_max_duration"]

        for seg_start, seg_end in _vad_segments(audio_array):
            if batch_start is None:
                batch_start, batch_end = seg_start, seg_end
                continue
            # Extend the current batch while it still fits the max duration;
            # otherwise close it and start a new one at this segment.
            if seg_end - batch_start <= max_duration:
                batch_end = seg_end
            else:
                merged_batches.append((batch_start, batch_end))
                batch_start, batch_end = seg_start, seg_end

        if batch_start is not None and batch_end is not None:
            merged_batches.append((batch_start, batch_end))

        # If no speech detected, try transcribing the whole file
        if not merged_batches:
            return self._transcribe_whole_file(file_path, language)

        # Transcribe each batch
        all_words = []
        for start_time, end_time in merged_batches:
            s_idx = int(start_time * SAMPLE_RATE)
            e_idx = int(end_time * SAMPLE_RATE)
            segment = audio_array[s_idx:e_idx]
            segment = _pad_audio(segment, SAMPLE_RATE)

            with self._lock:
                segments, _ = self._model.transcribe(
                    segment,
                    language=language,
                    beam_size=5,
                    word_timestamps=True,
                    vad_filter=True,
                    vad_parameters={"min_silence_duration_ms": 500},
                )
                # Materialise the result while still holding the lock.
                segments = list(segments)

            for seg in segments:
                for w in seg.words:
                    all_words.append(
                        {
                            "word": w.word,
                            # Word times are relative to the batch; add the
                            # batch offset to get file-relative times.
                            "start": round(float(w.start) + start_time, 2),
                            "end": round(float(w.end) + start_time, 2),
                        }
                    )

        all_words = _enforce_word_timing_constraints(all_words)

        words = [
            Word(text=w["word"], start=w["start"], end=w["end"]) for w in all_words
        ]
        words.sort(key=lambda w: w.start)
        return Transcript(words=words)

    def _transcribe_whole_file(self, file_path: str, language: str) -> Transcript:
        """Fallback: transcribe entire file without VAD segmentation."""
        # Cheap no-op when the model is already loaded; makes this method
        # safe to call on its own.
        self._ensure_model()
        with self._lock:
            segments, _ = self._model.transcribe(
                file_path,
                language=language,
                beam_size=5,
                word_timestamps=True,
                vad_filter=True,
                vad_parameters={"min_silence_duration_ms": 500},
            )
            segments = list(segments)

        words = []
        for seg in segments:
            for w in seg.words:
                words.append(
                    Word(
                        text=w.word,
                        start=round(float(w.start), 2),
                        end=round(float(w.end), 2),
                    )
                )
        return Transcript(words=words)
|
||||
|
||||
|
||||
# --- VAD helpers (ported from gpu/self_hosted/app/services/transcriber.py) ---
|
||||
# IMPORTANT: This VAD segment logic is duplicated for deployment isolation.
|
||||
# If you modify this, consider updating the GPU service copy as well:
|
||||
# - gpu/self_hosted/app/services/transcriber.py
|
||||
# - gpu/modal_deployments/reflector_transcriber.py
|
||||
# - gpu/modal_deployments/reflector_transcriber_parakeet.py
|
||||
|
||||
|
||||
def _load_audio_via_ffmpeg(
    input_path: str, sample_rate: int = SAMPLE_RATE
) -> np.ndarray:
    """Decode an audio file with ffmpeg into a mono float32 sample array.

    ffmpeg writes raw little-endian float32 PCM at ``sample_rate`` to its
    stdout, which is wrapped as a NumPy array. A non-zero ffmpeg exit status
    raises ``subprocess.CalledProcessError``.
    """
    executable = shutil.which("ffmpeg") or "ffmpeg"
    input_args = ["-nostdin", "-threads", "1", "-i", input_path]
    output_args = [
        "-f",
        "f32le",
        "-acodec",
        "pcm_f32le",
        "-ac",
        "1",
        "-ar",
        str(sample_rate),
        "pipe:1",
    ]
    completed = subprocess.run(
        [executable, *input_args, *output_args],
        capture_output=True,  # same as stdout=PIPE, stderr=PIPE
        check=True,
    )
    return np.frombuffer(completed.stdout, dtype=np.float32)
|
||||
|
||||
|
||||
def _vad_segments(
    audio_array: np.ndarray,
    sample_rate: int = SAMPLE_RATE,
    window_size: int = VAD_CONFIG["window_size"],
) -> Generator[tuple[float, float], None, None]:
    """Detect speech segments using Silero VAD.

    Feeds the audio to a ``VADIterator`` in windows of ``window_size``
    samples (the final window is zero-padded) and yields one
    ``(start_seconds, end_seconds)`` tuple per detected speech region.

    Args:
        audio_array: Samples to analyse.
        sample_rate: Sampling rate of ``audio_array``; also used to convert
            the iterator's sample offsets to seconds.
        window_size: Number of samples handed to the iterator per step.
    """
    vad_model = load_silero_vad(onnx=False)
    iterator = VADIterator(vad_model, sampling_rate=sample_rate)
    rate = float(sample_rate)
    start = None

    for i in range(0, len(audio_array), window_size):
        chunk = audio_array[i : i + window_size]
        if len(chunk) < window_size:
            chunk = np.pad(chunk, (0, window_size - len(chunk)), mode="constant")
        speech = iterator(chunk)
        if not speech:
            continue
        if "start" in speech:
            start = speech["start"]
            continue
        if "end" in speech and start is not None:
            end = speech["end"]
            # Convert with the caller's sample_rate (not the module default)
            # so non-default rates produce correct timestamps.
            yield (start / rate, end / rate)
            start = None

    # Handle case where audio ends while speech is still active
    if start is not None:
        audio_duration = len(audio_array) / rate
        yield (start / rate, audio_duration)

    iterator.reset_states()
|
||||
|
||||
|
||||
def _pad_audio(audio_array: np.ndarray, sample_rate: int = SAMPLE_RATE) -> np.ndarray:
    """Append silence to clips shorter than the configured padding duration.

    Clips at least ``VAD_CONFIG["silence_padding"]`` seconds long are
    returned unchanged; shorter ones get that many seconds of float32 zeros
    appended.
    """
    padding_seconds = VAD_CONFIG["silence_padding"]
    if len(audio_array) / sample_rate >= padding_seconds:
        return audio_array
    trailing_silence = np.zeros(int(sample_rate * padding_seconds), dtype=np.float32)
    return np.concatenate([audio_array, trailing_silence])
|
||||
|
||||
|
||||
def _enforce_word_timing_constraints(words: list[dict]) -> list[dict]:
|
||||
"""Ensure no word end time exceeds the next word's start time."""
|
||||
if len(words) <= 1:
|
||||
return words
|
||||
enforced: list[dict] = []
|
||||
for i, word in enumerate(words):
|
||||
current = dict(word)
|
||||
if i < len(words) - 1:
|
||||
next_start = words[i + 1]["start"]
|
||||
if current["end"] > next_start:
|
||||
current["end"] = next_start
|
||||
enforced.append(current)
|
||||
return enforced
|
||||
|
||||
|
||||
FileTranscriptAutoProcessor.register("whisper", FileTranscriptWhisperProcessor)
|
||||
50
server/reflector/processors/transcript_translator_marian.py
Normal file
50
server/reflector/processors/transcript_translator_marian.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""
|
||||
MarianMT transcript translator processor using HuggingFace MarianMT in-process.
|
||||
|
||||
Translates transcript text using HuggingFace MarianMT models
|
||||
locally. No HTTP backend needed.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from reflector.processors._marian_translator_service import translator_service
|
||||
from reflector.processors.transcript_translator import TranscriptTranslatorProcessor
|
||||
from reflector.processors.transcript_translator_auto import (
|
||||
TranscriptTranslatorAutoProcessor,
|
||||
)
|
||||
from reflector.processors.types import TranslationLanguages
|
||||
|
||||
|
||||
class TranscriptTranslatorMarianProcessor(TranscriptTranslatorProcessor):
    """Transcript translator backed by the shared in-process MarianMT service."""

    async def _translate(self, text: str) -> str | None:
        """Translate ``text`` using the configured source/target languages.

        Both languages come from the "audio:source_language" and
        "audio:target_language" preferences (default "en"). The translation
        itself runs in the default executor.

        Returns:
            The translated string, or None when the service result holds no
            entry for the target language.
        """
        src_lang = self.get_pref("audio:source_language", "en")
        dst_lang = self.get_pref("audio:target_language", "en")

        assert TranslationLanguages().is_supported(dst_lang)

        self.logger.debug(f"MarianMT translate {text=}")

        response = await asyncio.get_event_loop().run_in_executor(
            None,
            translator_service.translate,
            text,
            src_lang,
            dst_lang,
        )

        # Missing target key -> None
        translation = response["text"].get(dst_lang)

        self.logger.debug(f"Translation result: {text=}, {translation=}")
        return translation
|
||||
|
||||
|
||||
TranscriptTranslatorAutoProcessor.register(
|
||||
"marian", TranscriptTranslatorMarianProcessor
|
||||
)
|
||||
@@ -40,11 +40,19 @@ class Settings(BaseSettings):
|
||||
# backends: silero, frames
|
||||
AUDIO_CHUNKER_BACKEND: str = "frames"
|
||||
|
||||
# HuggingFace token for gated models (pyannote diarization in --cpu mode)
|
||||
HF_TOKEN: str | None = None
|
||||
|
||||
# Audio Transcription
|
||||
# backends:
|
||||
# - whisper: in-process model loading (no HTTP, runs in same process)
|
||||
# - modal: HTTP API client (works with Modal.com OR self-hosted gpu/self_hosted/)
|
||||
TRANSCRIPT_BACKEND: str = "whisper"
|
||||
|
||||
# Whisper model sizes for local transcription
|
||||
# Options: "tiny", "base", "small", "medium", "large-v2"
|
||||
WHISPER_CHUNK_MODEL: str = "tiny"
|
||||
WHISPER_FILE_MODEL: str = "tiny"
|
||||
TRANSCRIPT_URL: str | None = None
|
||||
TRANSCRIPT_TIMEOUT: int = 90
|
||||
TRANSCRIPT_FILE_TIMEOUT: int = 600
|
||||
@@ -100,7 +108,7 @@ class Settings(BaseSettings):
|
||||
)
|
||||
|
||||
# Diarization
|
||||
# backend: modal — HTTP API client (works with Modal.com OR self-hosted gpu/self_hosted/)
|
||||
# backends: modal — HTTP API client, pyannote — in-process pyannote.audio
|
||||
DIARIZATION_ENABLED: bool = True
|
||||
DIARIZATION_BACKEND: str = "modal"
|
||||
DIARIZATION_URL: str | None = None
|
||||
@@ -111,9 +119,9 @@ class Settings(BaseSettings):
|
||||
|
||||
# Audio Padding
|
||||
# backends:
|
||||
# - local: in-process PyAV padding (no HTTP, runs in same process)
|
||||
# - pyav: in-process PyAV padding (no HTTP, runs in same process)
|
||||
# - modal: HTTP API client (works with Modal.com OR self-hosted gpu/self_hosted/)
|
||||
PADDING_BACKEND: str = "local"
|
||||
PADDING_BACKEND: str = "pyav"
|
||||
PADDING_URL: str | None = None
|
||||
PADDING_MODAL_API_KEY: str | None = None
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
import reflector.auth as auth
|
||||
from reflector.db.transcripts import AudioWaveform, transcripts_controller
|
||||
from reflector.settings import settings
|
||||
from reflector.views.transcripts import ALGORITHM
|
||||
|
||||
from ._range_requests_response import range_requests_response
|
||||
|
||||
@@ -36,16 +35,23 @@ async def transcript_get_audio_mp3(
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
if not user_id and token:
|
||||
unauthorized_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
|
||||
user_id: str = payload.get("sub")
|
||||
except jwt.PyJWTError:
|
||||
raise unauthorized_exception
|
||||
token_user = await auth.verify_raw_token(token)
|
||||
except Exception:
|
||||
token_user = None
|
||||
# Fallback: try as internal HS256 token (created by _generate_local_audio_link)
|
||||
if not token_user:
|
||||
try:
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
|
||||
user_id = payload.get("sub")
|
||||
except jwt.PyJWTError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
else:
|
||||
user_id = token_user["sub"]
|
||||
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
|
||||
Reference in New Issue
Block a user