# server/reflector/processors/file_diarization_pyannote.py
import os
import tarfile
import tempfile
import zlib
from pathlib import Path

import httpx
import torch
import torchaudio
import yaml
from pyannote.audio import Pipeline

from reflector.processors.file_diarization import (
    FileDiarizationInput,
    FileDiarizationOutput,
    FileDiarizationProcessor,
)
from reflector.processors.file_diarization_auto import FileDiarizationAutoProcessor
from reflector.processors.types import DiarizationSegment

DEFAULT_MODEL_URL = "https://reflector-public.s3.us-east-1.amazonaws.com/pyannote-speaker-diarization-3.1.tar.gz"
DEFAULT_CACHE_DIR = "/tmp/pyannote-cache"


class FileDiarizationPyannoteProcessor(FileDiarizationProcessor):
    """File diarization using local pyannote.audio pipeline.

    Downloads model bundle from URL (or uses HuggingFace), runs speaker diarization.
    """

    def __init__(
        self,
        pyannote_model_url: str = DEFAULT_MODEL_URL,
        pyannote_model_name: str | None = None,
        pyannote_auth_token: str | None = None,
        pyannote_device: str | None = None,
        pyannote_cache_dir: str = DEFAULT_CACHE_DIR,
        **kwargs,
    ):
        """Load the diarization pipeline onto the selected device.

        Args:
            pyannote_model_url: URL of a tar.gz model bundle to download and cache.
            pyannote_model_name: HuggingFace repo ID or local directory; when set,
                the bundle download is skipped entirely.
            pyannote_auth_token: HF token; falls back to the HF_TOKEN env var.
            pyannote_device: torch device string; defaults to cuda when available.
            pyannote_cache_dir: directory where the downloaded bundle is cached.
        """
        super().__init__(**kwargs)
        self.auth_token = pyannote_auth_token or os.environ.get("HF_TOKEN")
        self.device = pyannote_device or (
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        if pyannote_model_name:
            model_path = pyannote_model_name
        else:
            model_path = self._ensure_model(
                pyannote_model_url, Path(pyannote_cache_dir)
            )

        self.logger.info("Loading pyannote model", model=model_path, device=self.device)
        # from_pretrained needs a file path (config.yaml) for local models,
        # or a HuggingFace repo ID for remote ones
        config_path = Path(model_path) / "config.yaml"
        load_path = str(config_path) if config_path.is_file() else model_path
        self.diarization_pipeline = Pipeline.from_pretrained(
            load_path, use_auth_token=self.auth_token
        )
        self.diarization_pipeline.to(torch.device(self.device))

    def _ensure_model(self, model_url: str, cache_dir: Path) -> str:
        """Download and extract model bundle if not cached.

        Returns:
            Path (str) of the extracted model directory containing config.yaml.
        """
        model_dir = cache_dir / "pyannote-speaker-diarization-3.1"
        config_path = model_dir / "config.yaml"

        # Cache hit: config.yaml only exists after a completed extract+patch.
        if config_path.exists():
            self.logger.info("Using cached model", path=str(model_dir))
            return str(model_dir)

        cache_dir.mkdir(parents=True, exist_ok=True)
        tarball_path = cache_dir / "model.tar.gz"

        self.logger.info("Downloading model bundle", url=model_url)
        # Stream to disk in chunks so the tarball is never held in memory whole.
        with httpx.Client() as client:
            with client.stream("GET", model_url, follow_redirects=True) as response:
                response.raise_for_status()
                with open(tarball_path, "wb") as f:
                    for chunk in response.iter_bytes(chunk_size=8192):
                        f.write(chunk)

        self.logger.info("Extracting model bundle")
        # filter="data" rejects absolute paths / path traversal inside the archive
        with tarfile.open(tarball_path, "r:gz") as tar:
            tar.extractall(path=cache_dir, filter="data")

        tarball_path.unlink()

        self._patch_config(model_dir, cache_dir)
        return str(model_dir)

    def _patch_config(self, model_dir: Path, cache_dir: Path) -> None:
        """Rewrite config.yaml to reference local model paths.

        The bundled config points at HuggingFace repo IDs; replace the
        segmentation and embedding entries with the extracted local weights
        so loading never hits the network.
        """
        config_path = model_dir / "config.yaml"
        with open(config_path) as f:
            config = yaml.safe_load(f)

        config["pipeline"]["params"]["segmentation"] = str(
            cache_dir / "pyannote-segmentation-3.0" / "pytorch_model.bin"
        )
        config["pipeline"]["params"]["embedding"] = str(
            cache_dir / "pyannote-wespeaker-voxceleb-resnet34-LM" / "pytorch_model.bin"
        )

        with open(config_path, "w") as f:
            yaml.dump(config, f)

        self.logger.info("Patched config.yaml with local model paths")

    async def _diarize(self, data: FileDiarizationInput) -> FileDiarizationOutput:
        """Download the audio at data.audio_url and run speaker diarization.

        Returns:
            FileDiarizationOutput whose diarization list holds dicts with
            start/end seconds (rounded to 3 decimals) and an int speaker id.
        """
        self.logger.info("Downloading audio for diarization", audio_url=data.audio_url)

        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=True) as tmp:
            async with httpx.AsyncClient() as client:
                response = await client.get(data.audio_url, follow_redirects=True)
                response.raise_for_status()
                tmp.write(response.content)
                tmp.flush()

            waveform, sample_rate = torchaudio.load(tmp.name)

        audio_input = {"waveform": waveform, "sample_rate": sample_rate}
        diarization = self.diarization_pipeline(audio_input)

        segments: list[DiarizationSegment] = []
        for segment, _, speaker in diarization.itertracks(yield_label=True):
            # Labels are normally "SPEAKER_NN"; parse NN as the speaker id.
            speaker_id = 0
            if speaker.startswith("SPEAKER_"):
                try:
                    speaker_id = int(speaker.split("_")[-1])
                except ValueError:
                    # BUGFIX: was `hash(speaker) % 1000` — str hashing is
                    # salted per process (PYTHONHASHSEED), so ids were not
                    # stable across runs. crc32 is deterministic.
                    speaker_id = zlib.crc32(speaker.encode()) % 1000
            elif speaker:
                # BUGFIX: previously every non-standard label collapsed onto
                # speaker 0; give distinct labels distinct, stable ids.
                speaker_id = zlib.crc32(speaker.encode()) % 1000

            segments.append(
                {
                    "start": round(segment.start, 3),
                    "end": round(segment.end, 3),
                    "speaker": speaker_id,
                }
            )

        self.logger.info("Diarization complete", segment_count=len(segments))
        return FileDiarizationOutput(diarization=segments)


FileDiarizationAutoProcessor.register("pyannote", FileDiarizationPyannoteProcessor)
# server/tests/test_file_diarization_pyannote.py
import tarfile
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
import yaml

from reflector.processors.file_diarization_pyannote import (
    FileDiarizationPyannoteProcessor,
)

# Mirrors the config.yaml shipped in the real model bundle.
ORIGINAL_CONFIG = {
    "version": "3.1.0",
    "pipeline": {
        "name": "pyannote.audio.pipelines.SpeakerDiarization",
        "params": {
            "clustering": "AgglomerativeClustering",
            "embedding": "pyannote/wespeaker-voxceleb-resnet34-LM",
            "embedding_batch_size": 32,
            "embedding_exclude_overlap": True,
            "segmentation": "pyannote/segmentation-3.0",
            "segmentation_batch_size": 32,
        },
    },
    "params": {
        "clustering": {
            "method": "centroid",
            "min_cluster_size": 12,
            "threshold": 0.7045654963945799,
        },
        "segmentation": {"min_duration_off": 0.0},
    },
}


def _make_model_tarball(tarball_path: Path) -> None:
    """Create a fake model tarball matching real structure."""
    staging = tarball_path.parent / "_build"
    layout = {
        "pyannote-speaker-diarization-3.1": {"config.yaml": yaml.dump(ORIGINAL_CONFIG)},
        "pyannote-segmentation-3.0": {
            "config.yaml": "model: {}\n",
            "pytorch_model.bin": b"fake",
        },
        "pyannote-wespeaker-voxceleb-resnet34-LM": {
            "config.yaml": "model: {}\n",
            "pytorch_model.bin": b"fake",
        },
    }
    for bundle_name, files in layout.items():
        bundle_dir = staging / bundle_name
        bundle_dir.mkdir(parents=True, exist_ok=True)
        for filename, payload in files.items():
            target = bundle_dir / filename
            if isinstance(payload, bytes):
                target.write_bytes(payload)
            else:
                target.write_text(payload)

    with tarfile.open(tarball_path, "w:gz") as archive:
        for bundle_name in layout:
            archive.add(staging / bundle_name, arcname=bundle_name)


def _make_mock_processor() -> MagicMock:
    """Bare stand-in for the processor so unbound methods can be exercised."""
    stub = MagicMock()
    stub.logger = MagicMock()
    return stub


class TestEnsureModel:
    """Test model download, extraction, and config patching."""

    def test_extracts_and_patches_config(self, tmp_path: Path) -> None:
        """Downloads tarball, extracts, patches config to local paths."""
        cache_dir = tmp_path / "cache"
        tarball_path = tmp_path / "model.tar.gz"
        _make_model_tarball(tarball_path)
        payload = tarball_path.read_bytes()

        # Fake the httpx streaming response: one chunk, context-manager aware.
        fake_response = MagicMock()
        fake_response.iter_bytes.return_value = [payload]
        fake_response.raise_for_status = MagicMock()
        fake_response.__enter__ = MagicMock(return_value=fake_response)
        fake_response.__exit__ = MagicMock(return_value=False)

        fake_client = MagicMock()
        fake_client.__enter__ = MagicMock(return_value=fake_client)
        fake_client.__exit__ = MagicMock(return_value=False)
        fake_client.stream.return_value = fake_response

        proc = _make_mock_processor()

        # Route the mock's _patch_config through the real implementation.
        def _delegate(model_dir, cache_dir):
            FileDiarizationPyannoteProcessor._patch_config(proc, model_dir, cache_dir)

        proc._patch_config = _delegate

        with patch(
            "reflector.processors.file_diarization_pyannote.httpx.Client",
            return_value=fake_client,
        ):
            result = FileDiarizationPyannoteProcessor._ensure_model(
                proc, "http://fake/model.tar.gz", cache_dir
            )

        assert result == str(cache_dir / "pyannote-speaker-diarization-3.1")

        patched_config_path = (
            cache_dir / "pyannote-speaker-diarization-3.1" / "config.yaml"
        )
        with open(patched_config_path) as f:
            config = yaml.safe_load(f)

        pipeline_params = config["pipeline"]["params"]
        assert pipeline_params["segmentation"] == str(
            cache_dir / "pyannote-segmentation-3.0" / "pytorch_model.bin"
        )
        assert pipeline_params["embedding"] == str(
            cache_dir / "pyannote-wespeaker-voxceleb-resnet34-LM" / "pytorch_model.bin"
        )
        # Non-patched fields preserved
        assert pipeline_params["clustering"] == "AgglomerativeClustering"
        assert config["params"]["clustering"]["threshold"] == pytest.approx(
            0.7045654963945799
        )

    def test_uses_cache_on_second_call(self, tmp_path: Path) -> None:
        """Skips download if model dir already exists."""
        cache_dir = tmp_path / "cache"
        model_dir = cache_dir / "pyannote-speaker-diarization-3.1"
        model_dir.mkdir(parents=True)
        (model_dir / "config.yaml").write_text("cached: true")

        proc = _make_mock_processor()

        with patch(
            "reflector.processors.file_diarization_pyannote.httpx.Client"
        ) as client_cls:
            result = FileDiarizationPyannoteProcessor._ensure_model(
                proc, "http://fake/model.tar.gz", cache_dir
            )
            client_cls.assert_not_called()

        assert result == str(model_dir)


class TestDiarizeSegmentParsing:
    """Test that pyannote output is correctly converted to DiarizationSegment."""

    @pytest.mark.asyncio
    async def test_parses_speaker_segments(self) -> None:
        proc = _make_mock_processor()

        def _span(start: float, end: float) -> MagicMock:
            seg = MagicMock()
            seg.start = start
            seg.end = end
            return seg

        fake_diarization = MagicMock()
        fake_diarization.itertracks.return_value = [
            (_span(0.123456, 1.789012), None, "SPEAKER_00"),
            (_span(2.0, 3.5), None, "SPEAKER_01"),
        ]
        proc.diarization_pipeline = MagicMock(return_value=fake_diarization)

        fake_input = MagicMock()
        fake_input.audio_url = "http://fake/audio.mp3"

        fake_response = AsyncMock()
        fake_response.content = b"fake audio"
        # raise_for_status is synchronous on the real response object
        fake_response.raise_for_status = MagicMock()

        fake_client = AsyncMock()
        fake_client.__aenter__ = AsyncMock(return_value=fake_client)
        fake_client.__aexit__ = AsyncMock(return_value=False)
        fake_client.get = AsyncMock(return_value=fake_response)

        with (
            patch(
                "reflector.processors.file_diarization_pyannote.httpx.AsyncClient",
                return_value=fake_client,
            ),
            patch(
                "reflector.processors.file_diarization_pyannote.torchaudio.load",
                return_value=(MagicMock(), 16000),
            ),
        ):
            result = await FileDiarizationPyannoteProcessor._diarize(proc, fake_input)

        assert len(result.diarization) == 2
        assert result.diarization[0] == {"start": 0.123, "end": 1.789, "speaker": 0}
        assert result.diarization[1] == {"start": 2.0, "end": 3.5, "speaker": 1}