feat: add local pyannote file diarization processor

Enables file diarization without Modal by using pyannote.audio locally.
Downloads model bundle from S3 on first use, caches locally, patches
config to use local paths. Set DIARIZATION_BACKEND=pyannote to enable.
This commit is contained in:
Igor Loskutov
2026-02-10 22:50:11 -05:00
parent 7372f80530
commit 0353c23a94
2 changed files with 336 additions and 0 deletions

View File

@@ -0,0 +1,144 @@
import os
import tarfile
import tempfile
import zlib
from pathlib import Path

import httpx
import torch
import torchaudio
import yaml
from pyannote.audio import Pipeline

from reflector.processors.file_diarization import (
    FileDiarizationInput,
    FileDiarizationOutput,
    FileDiarizationProcessor,
)
from reflector.processors.file_diarization_auto import FileDiarizationAutoProcessor
from reflector.processors.types import DiarizationSegment
# Public S3 mirror of the pyannote speaker-diarization-3.1 bundle (tar.gz of
# the pipeline config.yaml plus segmentation/embedding weight directories)
DEFAULT_MODEL_URL = "https://reflector-public.s3.us-east-1.amazonaws.com/pyannote-speaker-diarization-3.1.tar.gz"
# Directory where the downloaded/extracted bundle is cached between runs
DEFAULT_CACHE_DIR = "/tmp/pyannote-cache"
class FileDiarizationPyannoteProcessor(FileDiarizationProcessor):
    """File diarization using a local pyannote.audio pipeline.

    Downloads a model bundle from a URL on first use and caches it locally,
    or loads a HuggingFace/local model directly when ``pyannote_model_name``
    is given. Audio is fetched from a URL and diarized in-process.
    """

    def __init__(
        self,
        pyannote_model_url: str = DEFAULT_MODEL_URL,
        pyannote_model_name: str | None = None,
        pyannote_auth_token: str | None = None,
        pyannote_device: str | None = None,
        pyannote_cache_dir: str = DEFAULT_CACHE_DIR,
        **kwargs,
    ):
        """Resolve the model (downloading if needed) and load the pipeline.

        Args:
            pyannote_model_url: tarball URL used when no model name is given.
            pyannote_model_name: HuggingFace repo ID or local model directory;
                when set, the bundle download is skipped entirely.
            pyannote_auth_token: HuggingFace token; falls back to $HF_TOKEN.
            pyannote_device: torch device string; auto-detects CUDA when None.
            pyannote_cache_dir: directory where downloaded bundles are cached.
        """
        super().__init__(**kwargs)
        self.auth_token = pyannote_auth_token or os.environ.get("HF_TOKEN")
        self.device = pyannote_device or (
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        if pyannote_model_name:
            # Explicit model name wins: HuggingFace repo ID or local directory
            model_path = pyannote_model_name
        else:
            model_path = self._ensure_model(
                pyannote_model_url, Path(pyannote_cache_dir)
            )
        self.logger.info("Loading pyannote model", model=model_path, device=self.device)
        # from_pretrained needs a file path (config.yaml) for local models,
        # or a HuggingFace repo ID for remote ones
        config_path = Path(model_path) / "config.yaml"
        load_path = str(config_path) if config_path.is_file() else model_path
        self.diarization_pipeline = Pipeline.from_pretrained(
            load_path, use_auth_token=self.auth_token
        )
        self.diarization_pipeline.to(torch.device(self.device))

    def _ensure_model(self, model_url: str, cache_dir: Path) -> str:
        """Download and extract the model bundle if not cached.

        Returns the path to the extracted model directory, whose config.yaml
        presence doubles as the cache-hit marker.
        """
        model_dir = cache_dir / "pyannote-speaker-diarization-3.1"
        config_path = model_dir / "config.yaml"
        if config_path.exists():
            self.logger.info("Using cached model", path=str(model_dir))
            return str(model_dir)
        cache_dir.mkdir(parents=True, exist_ok=True)
        tarball_path = cache_dir / "model.tar.gz"
        self.logger.info("Downloading model bundle", url=model_url)
        try:
            with httpx.Client() as client:
                with client.stream("GET", model_url, follow_redirects=True) as response:
                    response.raise_for_status()
                    with open(tarball_path, "wb") as f:
                        for chunk in response.iter_bytes(chunk_size=8192):
                            f.write(chunk)
            self.logger.info("Extracting model bundle")
            with tarfile.open(tarball_path, "r:gz") as tar:
                # filter="data" rejects absolute paths / parent traversal
                # inside the archive
                tar.extractall(path=cache_dir, filter="data")
        finally:
            # Always remove the tarball, even when download or extraction
            # fails, so a partial archive is never left to confuse a retry
            tarball_path.unlink(missing_ok=True)
        self._patch_config(model_dir, cache_dir)
        return str(model_dir)

    def _patch_config(self, model_dir: Path, cache_dir: Path) -> None:
        """Rewrite config.yaml to reference local model paths."""
        config_path = model_dir / "config.yaml"
        with open(config_path) as f:
            config = yaml.safe_load(f)
        # The bundled config references sub-models by name; point both at the
        # weight files extracted alongside the pipeline config
        config["pipeline"]["params"]["segmentation"] = str(
            cache_dir / "pyannote-segmentation-3.0" / "pytorch_model.bin"
        )
        config["pipeline"]["params"]["embedding"] = str(
            cache_dir / "pyannote-wespeaker-voxceleb-resnet34-LM" / "pytorch_model.bin"
        )
        with open(config_path, "w") as f:
            yaml.dump(config, f)
        self.logger.info("Patched config.yaml with local model paths")

    async def _diarize(self, data: FileDiarizationInput) -> FileDiarizationOutput:
        """Fetch the audio at data.audio_url and run speaker diarization.

        Returns segments with start/end rounded to milliseconds and an int
        speaker id parsed from pyannote's "SPEAKER_NN" labels.
        """
        self.logger.info("Downloading audio for diarization", audio_url=data.audio_url)
        # NOTE(review): suffix assumes the remote audio is mp3 — confirm
        # against upstream callers
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=True) as tmp:
            async with httpx.AsyncClient() as client:
                response = await client.get(data.audio_url, follow_redirects=True)
                response.raise_for_status()
                tmp.write(response.content)
                tmp.flush()
            # Decode once and hand the pipeline an in-memory waveform
            waveform, sample_rate = torchaudio.load(tmp.name)
            audio_input = {"waveform": waveform, "sample_rate": sample_rate}
            diarization = self.diarization_pipeline(audio_input)
        segments: list[DiarizationSegment] = []
        for segment, _, speaker in diarization.itertracks(yield_label=True):
            speaker_id = 0
            if speaker.startswith("SPEAKER_"):
                try:
                    speaker_id = int(speaker.split("_")[-1])
                except (ValueError, IndexError):
                    # crc32 is stable across processes, unlike builtin hash()
                    # whose string hashing is randomized by PYTHONHASHSEED
                    speaker_id = zlib.crc32(speaker.encode("utf-8")) % 1000
            segments.append(
                {
                    "start": round(segment.start, 3),
                    "end": round(segment.end, 3),
                    "speaker": speaker_id,
                }
            )
        self.logger.info("Diarization complete", segment_count=len(segments))
        return FileDiarizationOutput(diarization=segments)
FileDiarizationAutoProcessor.register("pyannote", FileDiarizationPyannoteProcessor)

View File

@@ -0,0 +1,192 @@
import tarfile
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import yaml
from reflector.processors.file_diarization_pyannote import (
FileDiarizationPyannoteProcessor,
)
# Mirrors the shape of pyannote speaker-diarization-3.1's config.yaml; used as
# the pre-patch fixture so tests can verify _patch_config rewrites only the
# segmentation/embedding references and preserves everything else.
ORIGINAL_CONFIG = {
    "version": "3.1.0",
    "pipeline": {
        "name": "pyannote.audio.pipelines.SpeakerDiarization",
        "params": {
            "clustering": "AgglomerativeClustering",
            # The two model references _patch_config replaces with local paths
            "embedding": "pyannote/wespeaker-voxceleb-resnet34-LM",
            "embedding_batch_size": 32,
            "embedding_exclude_overlap": True,
            "segmentation": "pyannote/segmentation-3.0",
            "segmentation_batch_size": 32,
        },
    },
    "params": {
        "clustering": {
            "method": "centroid",
            "min_cluster_size": 12,
            "threshold": 0.7045654963945799,
        },
        "segmentation": {"min_duration_off": 0.0},
    },
}
def _make_model_tarball(tarball_path: Path) -> None:
    """Create a fake model tarball matching real structure."""
    stub_model_config = "model: {}\n"
    # Top-level archive entries: pipeline config plus two weight directories
    layout = {
        "pyannote-speaker-diarization-3.1": {"config.yaml": yaml.dump(ORIGINAL_CONFIG)},
        "pyannote-segmentation-3.0": {
            "config.yaml": stub_model_config,
            "pytorch_model.bin": b"fake",
        },
        "pyannote-wespeaker-voxceleb-resnet34-LM": {
            "config.yaml": stub_model_config,
            "pytorch_model.bin": b"fake",
        },
    }
    staging = tarball_path.parent / "_build"
    for top_level, files in layout.items():
        directory = staging / top_level
        directory.mkdir(parents=True, exist_ok=True)
        for filename, payload in files.items():
            destination = directory / filename
            if isinstance(payload, str):
                destination.write_text(payload)
            else:
                destination.write_bytes(payload)
    with tarfile.open(tarball_path, "w:gz") as archive:
        for top_level in layout:
            archive.add(staging / top_level, arcname=top_level)
def _make_mock_processor() -> MagicMock:
proc = MagicMock()
proc.logger = MagicMock()
return proc
class TestEnsureModel:
    """Test model download, extraction, and config patching."""
    def test_extracts_and_patches_config(self, tmp_path: Path) -> None:
        """Downloads tarball, extracts, patches config to local paths."""
        cache_dir = tmp_path / "cache"
        tarball_path = tmp_path / "model.tar.gz"
        _make_model_tarball(tarball_path)
        tarball_bytes = tarball_path.read_bytes()
        # Fake the httpx streaming response: one chunk carrying the whole
        # tarball, with the context-manager protocol wired up by hand
        mock_response = MagicMock()
        mock_response.iter_bytes.return_value = [tarball_bytes]
        mock_response.raise_for_status = MagicMock()
        mock_response.__enter__ = MagicMock(return_value=mock_response)
        mock_response.__exit__ = MagicMock(return_value=False)
        mock_client = MagicMock()
        mock_client.__enter__ = MagicMock(return_value=mock_client)
        mock_client.__exit__ = MagicMock(return_value=False)
        mock_client.stream.return_value = mock_response
        proc = _make_mock_processor()
        # _ensure_model calls self._patch_config; route that through the real
        # implementation so the config rewriting is exercised too
        proc._patch_config = lambda model_dir, cache_dir: (
            FileDiarizationPyannoteProcessor._patch_config(proc, model_dir, cache_dir)
        )
        with patch(
            "reflector.processors.file_diarization_pyannote.httpx.Client",
            return_value=mock_client,
        ):
            result = FileDiarizationPyannoteProcessor._ensure_model(
                proc, "http://fake/model.tar.gz", cache_dir
            )
        assert result == str(cache_dir / "pyannote-speaker-diarization-3.1")
        patched_config_path = (
            cache_dir / "pyannote-speaker-diarization-3.1" / "config.yaml"
        )
        with open(patched_config_path) as f:
            config = yaml.safe_load(f)
        # Both sub-model references must now point inside the cache dir
        assert config["pipeline"]["params"]["segmentation"] == str(
            cache_dir / "pyannote-segmentation-3.0" / "pytorch_model.bin"
        )
        assert config["pipeline"]["params"]["embedding"] == str(
            cache_dir / "pyannote-wespeaker-voxceleb-resnet34-LM" / "pytorch_model.bin"
        )
        # Non-patched fields preserved
        assert config["pipeline"]["params"]["clustering"] == "AgglomerativeClustering"
        assert config["params"]["clustering"]["threshold"] == pytest.approx(
            0.7045654963945799
        )
    def test_uses_cache_on_second_call(self, tmp_path: Path) -> None:
        """Skips download if model dir already exists."""
        cache_dir = tmp_path / "cache"
        model_dir = cache_dir / "pyannote-speaker-diarization-3.1"
        model_dir.mkdir(parents=True)
        # Any existing config.yaml marks the cache as warm
        (model_dir / "config.yaml").write_text("cached: true")
        proc = _make_mock_processor()
        with patch(
            "reflector.processors.file_diarization_pyannote.httpx.Client"
        ) as mock_httpx:
            result = FileDiarizationPyannoteProcessor._ensure_model(
                proc, "http://fake/model.tar.gz", cache_dir
            )
        # On a cache hit no HTTP client should even be constructed
        mock_httpx.assert_not_called()
        assert result == str(model_dir)
class TestDiarizeSegmentParsing:
    """Test that pyannote output is correctly converted to DiarizationSegment."""
    @pytest.mark.asyncio
    async def test_parses_speaker_segments(self) -> None:
        """_diarize rounds times to 3 decimals and maps SPEAKER_NN labels to ints."""
        proc = _make_mock_processor()
        # Segment times with sub-millisecond precision to prove rounding
        mock_seg_0 = MagicMock()
        mock_seg_0.start = 0.123456
        mock_seg_0.end = 1.789012
        mock_seg_1 = MagicMock()
        mock_seg_1.start = 2.0
        mock_seg_1.end = 3.5
        # itertracks(yield_label=True) yields (segment, track, label) triples
        mock_diarization = MagicMock()
        mock_diarization.itertracks.return_value = [
            (mock_seg_0, None, "SPEAKER_00"),
            (mock_seg_1, None, "SPEAKER_01"),
        ]
        proc.diarization_pipeline = MagicMock(return_value=mock_diarization)
        mock_input = MagicMock()
        mock_input.audio_url = "http://fake/audio.mp3"
        # Fake the audio download; content is never decoded because
        # torchaudio.load is patched below
        mock_response = AsyncMock()
        mock_response.content = b"fake audio"
        mock_response.raise_for_status = MagicMock()
        mock_async_client = AsyncMock()
        mock_async_client.__aenter__ = AsyncMock(return_value=mock_async_client)
        mock_async_client.__aexit__ = AsyncMock(return_value=False)
        mock_async_client.get = AsyncMock(return_value=mock_response)
        with (
            patch(
                "reflector.processors.file_diarization_pyannote.httpx.AsyncClient",
                return_value=mock_async_client,
            ),
            patch(
                "reflector.processors.file_diarization_pyannote.torchaudio.load",
                return_value=(MagicMock(), 16000),
            ),
        ):
            result = await FileDiarizationPyannoteProcessor._diarize(proc, mock_input)
        assert len(result.diarization) == 2
        assert result.diarization[0] == {"start": 0.123, "end": 1.789, "speaker": 0}
        assert result.diarization[1] == {"start": 2.0, "end": 3.5, "speaker": 1}