mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-03-29 18:36:49 +00:00
feat: standalone uses self-hosted GPU service for transcription+diarization
Replace in-process pyannote approach with self-hosted gpu/self_hosted/ service. Same HTTP API as Modal — just TRANSCRIPT_URL/DIARIZATION_URL point to local container. - Add gpu/self_hosted/Dockerfile.cpu (GPU Dockerfile minus NVIDIA CUDA) - Add S3 model bundle fallback in diarizer.py when HF_TOKEN not set - Add gpu service to docker-compose.standalone.yml with compose env overrides - Fix /browse empty in PUBLIC_MODE (search+list queries filtered out roomless transcripts) - Remove audio_diarization_pyannote.py, file_diarization_pyannote.py and tests - Remove pyannote-audio from server local deps
This commit is contained in:
@@ -1,144 +0,0 @@
|
||||
import os
import tarfile
import tempfile
import zlib
from pathlib import Path

import httpx
import torch
import torchaudio
import yaml
from pyannote.audio import Pipeline

from reflector.processors.file_diarization import (
    FileDiarizationInput,
    FileDiarizationOutput,
    FileDiarizationProcessor,
)
from reflector.processors.file_diarization_auto import FileDiarizationAutoProcessor
from reflector.processors.types import DiarizationSegment
|
||||
|
||||
# Public S3 tarball bundling the pyannote 3.1 pipeline config plus its
# segmentation/embedding weights; used when no HuggingFace model name is given.
DEFAULT_MODEL_URL = "https://reflector-public.s3.us-east-1.amazonaws.com/pyannote-speaker-diarization-3.1.tar.gz"
# Where the downloaded/extracted bundle is cached between runs.
DEFAULT_CACHE_DIR = "/tmp/pyannote-cache"
|
||||
|
||||
|
||||
class FileDiarizationPyannoteProcessor(FileDiarizationProcessor):
    """File diarization using local pyannote.audio pipeline.

    Downloads a model bundle from a URL (or uses HuggingFace when a model
    name is supplied), then runs speaker diarization on audio fetched from
    the input's ``audio_url``.
    """

    def __init__(
        self,
        pyannote_model_url: str = DEFAULT_MODEL_URL,
        pyannote_model_name: str | None = None,
        pyannote_auth_token: str | None = None,
        pyannote_device: str | None = None,
        pyannote_cache_dir: str = DEFAULT_CACHE_DIR,
        **kwargs,
    ):
        """Load the diarization pipeline, downloading the bundle if needed.

        Args:
            pyannote_model_url: HTTP(S) tarball with pipeline config and
                weights; used when no model name is given.
            pyannote_model_name: HuggingFace repo ID or local path; takes
                precedence over the URL bundle.
            pyannote_auth_token: HF token; falls back to the HF_TOKEN env var.
            pyannote_device: torch device string; auto-selects "cuda" when
                available, otherwise "cpu".
            pyannote_cache_dir: directory where the bundle is cached.
        """
        super().__init__(**kwargs)
        self.auth_token = pyannote_auth_token or os.environ.get("HF_TOKEN")
        self.device = pyannote_device or (
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        if pyannote_model_name:
            # Explicit model name wins over the downloadable bundle.
            model_path = pyannote_model_name
        else:
            model_path = self._ensure_model(
                pyannote_model_url, Path(pyannote_cache_dir)
            )

        self.logger.info("Loading pyannote model", model=model_path, device=self.device)
        # from_pretrained needs a file path (config.yaml) for local models,
        # or a HuggingFace repo ID for remote ones
        config_path = Path(model_path) / "config.yaml"
        load_path = str(config_path) if config_path.is_file() else model_path
        self.diarization_pipeline = Pipeline.from_pretrained(
            load_path, use_auth_token=self.auth_token
        )
        self.diarization_pipeline.to(torch.device(self.device))

    def _ensure_model(self, model_url: str, cache_dir: Path) -> str:
        """Download and extract model bundle if not cached.

        Returns:
            Path (as str) of the extracted pipeline directory.
        """
        model_dir = cache_dir / "pyannote-speaker-diarization-3.1"
        config_path = model_dir / "config.yaml"

        # Presence of config.yaml marks a previously completed extraction.
        if config_path.exists():
            self.logger.info("Using cached model", path=str(model_dir))
            return str(model_dir)

        cache_dir.mkdir(parents=True, exist_ok=True)
        tarball_path = cache_dir / "model.tar.gz"

        self.logger.info("Downloading model bundle", url=model_url)
        with httpx.Client() as client:
            with client.stream("GET", model_url, follow_redirects=True) as response:
                response.raise_for_status()
                with open(tarball_path, "wb") as f:
                    for chunk in response.iter_bytes(chunk_size=8192):
                        f.write(chunk)

        self.logger.info("Extracting model bundle")
        with tarfile.open(tarball_path, "r:gz") as tar:
            # filter="data" rejects path-traversal / special-file entries.
            tar.extractall(path=cache_dir, filter="data")
        tarball_path.unlink()

        self._patch_config(model_dir, cache_dir)
        return str(model_dir)

    def _patch_config(self, model_dir: Path, cache_dir: Path) -> None:
        """Rewrite config.yaml to reference local model paths."""
        config_path = model_dir / "config.yaml"
        with open(config_path) as f:
            config = yaml.safe_load(f)

        # The bundled config references models by repo-style names; point the
        # two model entries at the weights extracted alongside the pipeline.
        config["pipeline"]["params"]["segmentation"] = str(
            cache_dir / "pyannote-segmentation-3.0" / "pytorch_model.bin"
        )
        config["pipeline"]["params"]["embedding"] = str(
            cache_dir / "pyannote-wespeaker-voxceleb-resnet34-LM" / "pytorch_model.bin"
        )

        with open(config_path, "w") as f:
            yaml.dump(config, f)

        self.logger.info("Patched config.yaml with local model paths")

    async def _diarize(self, data: FileDiarizationInput) -> FileDiarizationOutput:
        """Fetch audio from data.audio_url and run the diarization pipeline.

        Returns:
            FileDiarizationOutput whose segments carry start/end rounded to
            3 decimals and an int speaker id parsed from SPEAKER_NN labels.

        NOTE(review): torchaudio.load and the pipeline call below execute on
        the event loop thread and will block it for the inference duration.
        """
        self.logger.info("Downloading audio for diarization", audio_url=data.audio_url)

        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=True) as tmp:
            async with httpx.AsyncClient() as client:
                response = await client.get(data.audio_url, follow_redirects=True)
                response.raise_for_status()
                tmp.write(response.content)
                tmp.flush()

            waveform, sample_rate = torchaudio.load(tmp.name)

            audio_input = {"waveform": waveform, "sample_rate": sample_rate}
            diarization = self.diarization_pipeline(audio_input)

            segments: list[DiarizationSegment] = []
            for segment, _, speaker in diarization.itertracks(yield_label=True):
                speaker_id = 0
                if speaker.startswith("SPEAKER_"):
                    try:
                        speaker_id = int(speaker.split("_")[-1])
                    except (ValueError, IndexError):
                        # Fix: hash() is randomized per process
                        # (PYTHONHASHSEED), so hash(speaker) % 1000 gave a
                        # different id for the same label on every run.
                        # crc32 is stable across processes and platforms.
                        speaker_id = zlib.crc32(speaker.encode("utf-8")) % 1000

                segments.append(
                    {
                        "start": round(segment.start, 3),
                        "end": round(segment.end, 3),
                        "speaker": speaker_id,
                    }
                )

            self.logger.info("Diarization complete", segment_count=len(segments))
            return FileDiarizationOutput(diarization=segments)
|
||||
|
||||
|
||||
# Register this implementation under the "pyannote" key of the auto-selector.
FileDiarizationAutoProcessor.register("pyannote", FileDiarizationPyannoteProcessor)
|
||||
@@ -1,192 +0,0 @@
|
||||
import tarfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from reflector.processors.file_diarization_pyannote import (
|
||||
FileDiarizationPyannoteProcessor,
|
||||
)
|
||||
|
||||
# Sample pipeline config mirroring the pyannote-speaker-diarization-3.1
# config.yaml layout (presumably matches the production S3 bundle — confirm).
# Tests assert that _patch_config rewrites ONLY the segmentation/embedding
# entries and leaves every other field untouched.
ORIGINAL_CONFIG = {
    "version": "3.1.0",
    "pipeline": {
        "name": "pyannote.audio.pipelines.SpeakerDiarization",
        "params": {
            "clustering": "AgglomerativeClustering",
            "embedding": "pyannote/wespeaker-voxceleb-resnet34-LM",
            "embedding_batch_size": 32,
            "embedding_exclude_overlap": True,
            "segmentation": "pyannote/segmentation-3.0",
            "segmentation_batch_size": 32,
        },
    },
    "params": {
        "clustering": {
            "method": "centroid",
            "min_cluster_size": 12,
            "threshold": 0.7045654963945799,
        },
        "segmentation": {"min_duration_off": 0.0},
    },
}
|
||||
|
||||
|
||||
def _make_model_tarball(tarball_path: Path) -> None:
    """Build a fake model bundle tarball mirroring the real directory layout."""
    staging = tarball_path.parent / "_build"
    bundle = {
        "pyannote-speaker-diarization-3.1": {"config.yaml": yaml.dump(ORIGINAL_CONFIG)},
        "pyannote-segmentation-3.0": {
            "config.yaml": "model: {}\n",
            "pytorch_model.bin": b"fake",
        },
        "pyannote-wespeaker-voxceleb-resnet34-LM": {
            "config.yaml": "model: {}\n",
            "pytorch_model.bin": b"fake",
        },
    }
    # Materialise each directory under the staging area, choosing the
    # binary/text writer based on the payload type.
    for folder, entries in bundle.items():
        target = staging / folder
        target.mkdir(parents=True, exist_ok=True)
        for filename, payload in entries.items():
            destination = target / filename
            writer = (
                destination.write_bytes
                if isinstance(payload, bytes)
                else destination.write_text
            )
            writer(payload)

    # Archive each top-level folder under its bare name, as the real bundle does.
    with tarfile.open(tarball_path, "w:gz") as archive:
        for folder in bundle:
            archive.add(staging / folder, arcname=folder)
|
||||
|
||||
|
||||
def _make_mock_processor() -> MagicMock:
|
||||
proc = MagicMock()
|
||||
proc.logger = MagicMock()
|
||||
return proc
|
||||
|
||||
|
||||
class TestEnsureModel:
    """Model download, extraction, and config-patching behaviour."""

    def test_extracts_and_patches_config(self, tmp_path: Path) -> None:
        """Downloads tarball, extracts, patches config to local paths."""
        cache_dir = tmp_path / "cache"
        tarball = tmp_path / "model.tar.gz"
        _make_model_tarball(tarball)
        payload = tarball.read_bytes()

        # Fake a streaming HTTP response that yields the tarball in one chunk.
        fake_response = MagicMock()
        fake_response.iter_bytes.return_value = [payload]
        fake_response.raise_for_status = MagicMock()
        fake_response.__enter__ = MagicMock(return_value=fake_response)
        fake_response.__exit__ = MagicMock(return_value=False)

        fake_client = MagicMock()
        fake_client.__enter__ = MagicMock(return_value=fake_client)
        fake_client.__exit__ = MagicMock(return_value=False)
        fake_client.stream.return_value = fake_response

        proc = _make_mock_processor()
        # Route the mock's _patch_config through the real implementation so
        # the rewritten config can be inspected afterwards.
        proc._patch_config = lambda model_dir, cache_dir: (
            FileDiarizationPyannoteProcessor._patch_config(proc, model_dir, cache_dir)
        )

        with patch(
            "reflector.processors.file_diarization_pyannote.httpx.Client",
            return_value=fake_client,
        ):
            result = FileDiarizationPyannoteProcessor._ensure_model(
                proc, "http://fake/model.tar.gz", cache_dir
            )

        assert result == str(cache_dir / "pyannote-speaker-diarization-3.1")

        config_file = cache_dir / "pyannote-speaker-diarization-3.1" / "config.yaml"
        with open(config_file) as f:
            patched = yaml.safe_load(f)

        assert patched["pipeline"]["params"]["segmentation"] == str(
            cache_dir / "pyannote-segmentation-3.0" / "pytorch_model.bin"
        )
        assert patched["pipeline"]["params"]["embedding"] == str(
            cache_dir / "pyannote-wespeaker-voxceleb-resnet34-LM" / "pytorch_model.bin"
        )
        # Non-patched fields preserved
        assert patched["pipeline"]["params"]["clustering"] == "AgglomerativeClustering"
        assert patched["params"]["clustering"]["threshold"] == pytest.approx(
            0.7045654963945799
        )

    def test_uses_cache_on_second_call(self, tmp_path: Path) -> None:
        """Skips download if model dir already exists."""
        cache_dir = tmp_path / "cache"
        model_dir = cache_dir / "pyannote-speaker-diarization-3.1"
        model_dir.mkdir(parents=True)
        (model_dir / "config.yaml").write_text("cached: true")

        proc = _make_mock_processor()

        with patch(
            "reflector.processors.file_diarization_pyannote.httpx.Client"
        ) as client_cls:
            result = FileDiarizationPyannoteProcessor._ensure_model(
                proc, "http://fake/model.tar.gz", cache_dir
            )
            # The cached path must be returned without any HTTP activity.
            client_cls.assert_not_called()

        assert result == str(model_dir)
|
||||
|
||||
|
||||
class TestDiarizeSegmentParsing:
    """Conversion of pyannote output into DiarizationSegment dicts."""

    @pytest.mark.asyncio
    async def test_parses_speaker_segments(self) -> None:
        proc = _make_mock_processor()

        # Two fake pyannote segments with sub-millisecond precision starts/ends.
        first = MagicMock()
        first.start = 0.123456
        first.end = 1.789012
        second = MagicMock()
        second.start = 2.0
        second.end = 3.5

        annotation = MagicMock()
        annotation.itertracks.return_value = [
            (first, None, "SPEAKER_00"),
            (second, None, "SPEAKER_01"),
        ]
        proc.diarization_pipeline = MagicMock(return_value=annotation)

        payload = MagicMock()
        payload.audio_url = "http://fake/audio.mp3"

        audio_response = AsyncMock()
        audio_response.content = b"fake audio"
        audio_response.raise_for_status = MagicMock()

        http_client = AsyncMock()
        http_client.__aenter__ = AsyncMock(return_value=http_client)
        http_client.__aexit__ = AsyncMock(return_value=False)
        http_client.get = AsyncMock(return_value=audio_response)

        # Stub out the HTTP download and the audio decode so only the
        # segment-parsing logic runs for real.
        with (
            patch(
                "reflector.processors.file_diarization_pyannote.httpx.AsyncClient",
                return_value=http_client,
            ),
            patch(
                "reflector.processors.file_diarization_pyannote.torchaudio.load",
                return_value=(MagicMock(), 16000),
            ),
        ):
            result = await FileDiarizationPyannoteProcessor._diarize(proc, payload)

        assert len(result.diarization) == 2
        # start/end rounded to 3 decimals; SPEAKER_NN suffix becomes the int id.
        assert result.diarization[0] == {"start": 0.123, "end": 1.789, "speaker": 0}
        assert result.diarization[1] == {"start": 2.0, "end": 3.5, "speaker": 1}
|
||||
Reference in New Issue
Block a user