From 0353c23a94b292fde53af2f359d0e61bcba5096e Mon Sep 17 00:00:00 2001 From: Igor Loskutov Date: Tue, 10 Feb 2026 22:50:11 -0500 Subject: [PATCH] feat: add local pyannote file diarization processor Enables file diarization without Modal by using pyannote.audio locally. Downloads model bundle from S3 on first use, caches locally, patches config to use local paths. Set DIARIZATION_BACKEND=pyannote to enable. --- .../processors/file_diarization_pyannote.py | 144 +++++++++++++ .../tests/test_file_diarization_pyannote.py | 192 ++++++++++++++++++ 2 files changed, 336 insertions(+) create mode 100644 server/reflector/processors/file_diarization_pyannote.py create mode 100644 server/tests/test_file_diarization_pyannote.py diff --git a/server/reflector/processors/file_diarization_pyannote.py b/server/reflector/processors/file_diarization_pyannote.py new file mode 100644 index 00000000..1e024af2 --- /dev/null +++ b/server/reflector/processors/file_diarization_pyannote.py @@ -0,0 +1,144 @@ +import os +import tarfile +import tempfile +from pathlib import Path + +import httpx +import torch +import torchaudio +import yaml +from pyannote.audio import Pipeline + +from reflector.processors.file_diarization import ( + FileDiarizationInput, + FileDiarizationOutput, + FileDiarizationProcessor, +) +from reflector.processors.file_diarization_auto import FileDiarizationAutoProcessor +from reflector.processors.types import DiarizationSegment + +DEFAULT_MODEL_URL = "https://reflector-public.s3.us-east-1.amazonaws.com/pyannote-speaker-diarization-3.1.tar.gz" +DEFAULT_CACHE_DIR = "/tmp/pyannote-cache" + + +class FileDiarizationPyannoteProcessor(FileDiarizationProcessor): + """File diarization using local pyannote.audio pipeline. + + Downloads model bundle from URL (or uses HuggingFace), runs speaker diarization. + """ + + def __init__( + self, + pyannote_model_url: str = DEFAULT_MODEL_URL, + pyannote_model_name: str | None = None, + pyannote_auth_token: str | None = None, + pyannote_device: str | None = None, + pyannote_cache_dir: str = DEFAULT_CACHE_DIR, + **kwargs, + ): + super().__init__(**kwargs) + self.auth_token = pyannote_auth_token or os.environ.get("HF_TOKEN") + self.device = pyannote_device or ( + "cuda" if torch.cuda.is_available() else "cpu" + ) + + if pyannote_model_name: + model_path = pyannote_model_name + else: + model_path = self._ensure_model( + pyannote_model_url, Path(pyannote_cache_dir) + ) + + self.logger.info("Loading pyannote model", model=model_path, device=self.device) + # from_pretrained needs a file path (config.yaml) for local models, + # or a HuggingFace repo ID for remote ones + config_path = Path(model_path) / "config.yaml" + load_path = str(config_path) if config_path.is_file() else model_path + self.diarization_pipeline = Pipeline.from_pretrained( + load_path, use_auth_token=self.auth_token + ) + self.diarization_pipeline.to(torch.device(self.device)) + + def _ensure_model(self, model_url: str, cache_dir: Path) -> str: + """Download and extract model bundle if not cached.""" + model_dir = cache_dir / "pyannote-speaker-diarization-3.1" + config_path = model_dir / "config.yaml" + + if config_path.exists(): + self.logger.info("Using cached model", path=str(model_dir)) + return str(model_dir) + + cache_dir.mkdir(parents=True, exist_ok=True) + tarball_path = cache_dir / "model.tar.gz" + + self.logger.info("Downloading model bundle", url=model_url) + with httpx.Client() as client: + with client.stream("GET", model_url, follow_redirects=True) as response: + response.raise_for_status() + with open(tarball_path, "wb") as f: + for chunk in response.iter_bytes(chunk_size=8192): + f.write(chunk) + + self.logger.info("Extracting model bundle") + with tarfile.open(tarball_path, "r:gz") as tar: + tar.extractall(path=cache_dir, filter="data") + tarball_path.unlink() + + self._patch_config(model_dir, cache_dir) + return str(model_dir) + + def _patch_config(self, model_dir: Path, cache_dir: Path) -> None: + """Rewrite config.yaml to reference local model paths.""" + config_path = model_dir / "config.yaml" + with open(config_path) as f: + config = yaml.safe_load(f) + + config["pipeline"]["params"]["segmentation"] = str( + cache_dir / "pyannote-segmentation-3.0" / "pytorch_model.bin" + ) + config["pipeline"]["params"]["embedding"] = str( + cache_dir / "pyannote-wespeaker-voxceleb-resnet34-LM" / "pytorch_model.bin" + ) + + with open(config_path, "w") as f: + yaml.dump(config, f) + + self.logger.info("Patched config.yaml with local model paths") + + async def _diarize(self, data: FileDiarizationInput) -> FileDiarizationOutput: + self.logger.info("Downloading audio for diarization", audio_url=data.audio_url) + + with tempfile.NamedTemporaryFile(suffix=".mp3", delete=True) as tmp: + async with httpx.AsyncClient() as client: + response = await client.get(data.audio_url, follow_redirects=True) + response.raise_for_status() + tmp.write(response.content) + tmp.flush() + + waveform, sample_rate = torchaudio.load(tmp.name) + + audio_input = {"waveform": waveform, "sample_rate": sample_rate} + diarization = self.diarization_pipeline(audio_input) + + segments: list[DiarizationSegment] = [] + for segment, _, speaker in diarization.itertracks(yield_label=True): + speaker_id = 0 + if speaker.startswith("SPEAKER_"): + try: + speaker_id = int(speaker.split("_")[-1]) + except (ValueError, IndexError): + speaker_id = hash(speaker) % 1000 + + segments.append( + { + "start": round(segment.start, 3), + "end": round(segment.end, 3), + "speaker": speaker_id, + } + ) + + self.logger.info("Diarization complete", segment_count=len(segments)) + return FileDiarizationOutput(diarization=segments) + + +FileDiarizationAutoProcessor.register("pyannote", FileDiarizationPyannoteProcessor) diff --git a/server/tests/test_file_diarization_pyannote.py b/server/tests/test_file_diarization_pyannote.py new file mode 100644 index 00000000..8270318a --- /dev/null +++ b/server/tests/test_file_diarization_pyannote.py @@ -0,0 +1,192 @@ +import tarfile +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +import yaml + +from reflector.processors.file_diarization_pyannote import ( + FileDiarizationPyannoteProcessor, +) + +ORIGINAL_CONFIG = { + "version": "3.1.0", + "pipeline": { + "name": "pyannote.audio.pipelines.SpeakerDiarization", + "params": { + "clustering": "AgglomerativeClustering", + "embedding": "pyannote/wespeaker-voxceleb-resnet34-LM", + "embedding_batch_size": 32, + "embedding_exclude_overlap": True, + "segmentation": "pyannote/segmentation-3.0", + "segmentation_batch_size": 32, + }, + }, + "params": { + "clustering": { + "method": "centroid", + "min_cluster_size": 12, + "threshold": 0.7045654963945799, + }, + "segmentation": {"min_duration_off": 0.0}, + }, +} + + +def _make_model_tarball(tarball_path: Path) -> None: + """Create a fake model tarball matching real structure.""" + build_dir = tarball_path.parent / "_build" + dirs = { + "pyannote-speaker-diarization-3.1": {"config.yaml": yaml.dump(ORIGINAL_CONFIG)}, + "pyannote-segmentation-3.0": { + "config.yaml": "model: {}\n", + "pytorch_model.bin": b"fake", + }, + "pyannote-wespeaker-voxceleb-resnet34-LM": { + "config.yaml": "model: {}\n", + "pytorch_model.bin": b"fake", + }, + } + for dirname, files in dirs.items(): + d = build_dir / dirname + d.mkdir(parents=True, exist_ok=True) + for fname, content in files.items(): + p = d / fname + if isinstance(content, bytes): + p.write_bytes(content) + else: + p.write_text(content) + + with tarfile.open(tarball_path, "w:gz") as tar: + for dirname in dirs: + tar.add(build_dir / dirname, arcname=dirname) + + +def _make_mock_processor() -> MagicMock: + proc = MagicMock() + proc.logger = MagicMock() + return proc + + +class TestEnsureModel: + """Test model download, extraction, and config patching.""" + + def test_extracts_and_patches_config(self, tmp_path: Path) -> None: + """Downloads tarball, extracts, patches config to local paths.""" + cache_dir = tmp_path / "cache" + tarball_path = tmp_path / "model.tar.gz" + _make_model_tarball(tarball_path) + tarball_bytes = tarball_path.read_bytes() + + mock_response = MagicMock() + mock_response.iter_bytes.return_value = [tarball_bytes] + mock_response.raise_for_status = MagicMock() + mock_response.__enter__ = MagicMock(return_value=mock_response) + mock_response.__exit__ = MagicMock(return_value=False) + + mock_client = MagicMock() + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + mock_client.stream.return_value = mock_response + + proc = _make_mock_processor() + proc._patch_config = lambda model_dir, cache_dir: ( + FileDiarizationPyannoteProcessor._patch_config(proc, model_dir, cache_dir) + ) + + with patch( + "reflector.processors.file_diarization_pyannote.httpx.Client", + return_value=mock_client, + ): + result = FileDiarizationPyannoteProcessor._ensure_model( + proc, "http://fake/model.tar.gz", cache_dir + ) + + assert result == str(cache_dir / "pyannote-speaker-diarization-3.1") + + patched_config_path = ( + cache_dir / "pyannote-speaker-diarization-3.1" / "config.yaml" + ) + with open(patched_config_path) as f: + config = yaml.safe_load(f) + + assert config["pipeline"]["params"]["segmentation"] == str( + cache_dir / "pyannote-segmentation-3.0" / "pytorch_model.bin" + ) + assert config["pipeline"]["params"]["embedding"] == str( + cache_dir / "pyannote-wespeaker-voxceleb-resnet34-LM" / "pytorch_model.bin" + ) + # Non-patched fields preserved + assert config["pipeline"]["params"]["clustering"] == "AgglomerativeClustering" + assert config["params"]["clustering"]["threshold"] == pytest.approx( + 0.7045654963945799 + ) + + def test_uses_cache_on_second_call(self, tmp_path: Path) -> None: + """Skips download if model dir already exists.""" + cache_dir = tmp_path / "cache" + model_dir = cache_dir / "pyannote-speaker-diarization-3.1" + model_dir.mkdir(parents=True) + (model_dir / "config.yaml").write_text("cached: true") + + proc = _make_mock_processor() + + with patch( + "reflector.processors.file_diarization_pyannote.httpx.Client" + ) as mock_httpx: + result = FileDiarizationPyannoteProcessor._ensure_model( + proc, "http://fake/model.tar.gz", cache_dir + ) + mock_httpx.assert_not_called() + + assert result == str(model_dir) + + +class TestDiarizeSegmentParsing: + """Test that pyannote output is correctly converted to DiarizationSegment.""" + + @pytest.mark.asyncio + async def test_parses_speaker_segments(self) -> None: + proc = _make_mock_processor() + + mock_seg_0 = MagicMock() + mock_seg_0.start = 0.123456 + mock_seg_0.end = 1.789012 + mock_seg_1 = MagicMock() + mock_seg_1.start = 2.0 + mock_seg_1.end = 3.5 + + mock_diarization = MagicMock() + mock_diarization.itertracks.return_value = [ + (mock_seg_0, None, "SPEAKER_00"), + (mock_seg_1, None, "SPEAKER_01"), + ] + proc.diarization_pipeline = MagicMock(return_value=mock_diarization) + + mock_input = MagicMock() + mock_input.audio_url = "http://fake/audio.mp3" + + mock_response = AsyncMock() + mock_response.content = b"fake audio" + mock_response.raise_for_status = MagicMock() + + mock_async_client = AsyncMock() + mock_async_client.__aenter__ = AsyncMock(return_value=mock_async_client) + mock_async_client.__aexit__ = AsyncMock(return_value=False) + mock_async_client.get = AsyncMock(return_value=mock_response) + + with ( + patch( + "reflector.processors.file_diarization_pyannote.httpx.AsyncClient", + return_value=mock_async_client, + ), + patch( + "reflector.processors.file_diarization_pyannote.torchaudio.load", + return_value=(MagicMock(), 16000), + ), + ): + result = await FileDiarizationPyannoteProcessor._diarize(proc, mock_input) + + assert len(result.diarization) == 2 + assert result.diarization[0] == {"start": 0.123, "end": 1.789, "speaker": 0} + assert result.diarization[1] == {"start": 2.0, "end": 3.5, "speaker": 1}