diff --git a/server/reflector/processors/file_diarization_pyannote.py b/server/reflector/processors/file_diarization_pyannote.py deleted file mode 100644 index 1e024af2..00000000 --- a/server/reflector/processors/file_diarization_pyannote.py +++ /dev/null @@ -1,144 +0,0 @@ -import os -import tarfile -import tempfile -from pathlib import Path - -import httpx -import torch -import torchaudio -import yaml -from pyannote.audio import Pipeline - -from reflector.processors.file_diarization import ( - FileDiarizationInput, - FileDiarizationOutput, - FileDiarizationProcessor, -) -from reflector.processors.file_diarization_auto import FileDiarizationAutoProcessor -from reflector.processors.types import DiarizationSegment - -DEFAULT_MODEL_URL = "https://reflector-public.s3.us-east-1.amazonaws.com/pyannote-speaker-diarization-3.1.tar.gz" -DEFAULT_CACHE_DIR = "/tmp/pyannote-cache" - - -class FileDiarizationPyannoteProcessor(FileDiarizationProcessor): - """File diarization using local pyannote.audio pipeline. - - Downloads model bundle from URL (or uses HuggingFace), runs speaker diarization. - """ - - def __init__( - self, - pyannote_model_url: str = DEFAULT_MODEL_URL, - pyannote_model_name: str | None = None, - pyannote_auth_token: str | None = None, - pyannote_device: str | None = None, - pyannote_cache_dir: str = DEFAULT_CACHE_DIR, - **kwargs, - ): - super().__init__(**kwargs) - self.auth_token = pyannote_auth_token or os.environ.get("HF_TOKEN") - self.device = pyannote_device or ( - "cuda" if torch.cuda.is_available() else "cpu" - ) - - if pyannote_model_name: - model_path = pyannote_model_name - else: - model_path = self._ensure_model( - pyannote_model_url, Path(pyannote_cache_dir) - ) - - self.logger.info("Loading pyannote model", model=model_path, device=self.device) - # from_pretrained needs a file path (config.yaml) for local models, - # or a HuggingFace repo ID for remote ones - config_path = Path(model_path) / "config.yaml" - load_path = str(config_path) if config_path.is_file() else model_path - self.diarization_pipeline = Pipeline.from_pretrained( - load_path, use_auth_token=self.auth_token - ) - self.diarization_pipeline.to(torch.device(self.device)) - - def _ensure_model(self, model_url: str, cache_dir: Path) -> str: - """Download and extract model bundle if not cached.""" - model_dir = cache_dir / "pyannote-speaker-diarization-3.1" - config_path = model_dir / "config.yaml" - - if config_path.exists(): - self.logger.info("Using cached model", path=str(model_dir)) - return str(model_dir) - - cache_dir.mkdir(parents=True, exist_ok=True) - tarball_path = cache_dir / "model.tar.gz" - - self.logger.info("Downloading model bundle", url=model_url) - with httpx.Client() as client: - with client.stream("GET", model_url, follow_redirects=True) as response: - response.raise_for_status() - with open(tarball_path, "wb") as f: - for chunk in response.iter_bytes(chunk_size=8192): - f.write(chunk) - - self.logger.info("Extracting model bundle") - with tarfile.open(tarball_path, "r:gz") as tar: - tar.extractall(path=cache_dir, filter="data") - tarball_path.unlink() - - self._patch_config(model_dir, cache_dir) - return str(model_dir) - - def _patch_config(self, model_dir: Path, cache_dir: Path) -> None: - """Rewrite config.yaml to reference local model paths.""" - config_path = model_dir / "config.yaml" - with open(config_path) as f: - config = yaml.safe_load(f) - - config["pipeline"]["params"]["segmentation"] = str( - cache_dir / "pyannote-segmentation-3.0" / "pytorch_model.bin" - ) - config["pipeline"]["params"]["embedding"] = str( - cache_dir / "pyannote-wespeaker-voxceleb-resnet34-LM" / "pytorch_model.bin" - ) - - with open(config_path, "w") as f: - yaml.dump(config, f) - - self.logger.info("Patched config.yaml with local model paths") - - async def _diarize(self, data: FileDiarizationInput) -> FileDiarizationOutput: - self.logger.info("Downloading audio for diarization", audio_url=data.audio_url) - - with tempfile.NamedTemporaryFile(suffix=".mp3", delete=True) as tmp: - async with httpx.AsyncClient() as client: - response = await client.get(data.audio_url, follow_redirects=True) - response.raise_for_status() - tmp.write(response.content) - tmp.flush() - - waveform, sample_rate = torchaudio.load(tmp.name) - - audio_input = {"waveform": waveform, "sample_rate": sample_rate} - diarization = self.diarization_pipeline(audio_input) - - segments: list[DiarizationSegment] = [] - for segment, _, speaker in diarization.itertracks(yield_label=True): - speaker_id = 0 - if speaker.startswith("SPEAKER_"): - try: - speaker_id = int(speaker.split("_")[-1]) - except (ValueError, IndexError): - speaker_id = hash(speaker) % 1000 - - segments.append( - { - "start": round(segment.start, 3), - "end": round(segment.end, 3), - "speaker": speaker_id, - } - ) - - self.logger.info("Diarization complete", segment_count=len(segments)) - return FileDiarizationOutput(diarization=segments) - - -FileDiarizationAutoProcessor.register("pyannote", FileDiarizationPyannoteProcessor) diff --git a/server/tests/test_file_diarization_pyannote.py b/server/tests/test_file_diarization_pyannote.py deleted file mode 100644 index 8270318a..00000000 --- a/server/tests/test_file_diarization_pyannote.py +++ /dev/null @@ -1,192 +0,0 @@ -import tarfile -from pathlib import Path -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -import yaml - -from reflector.processors.file_diarization_pyannote import ( - FileDiarizationPyannoteProcessor, -) - -ORIGINAL_CONFIG = { - "version": "3.1.0", - "pipeline": { - "name": "pyannote.audio.pipelines.SpeakerDiarization", - "params": { - "clustering": "AgglomerativeClustering", - "embedding": "pyannote/wespeaker-voxceleb-resnet34-LM", - "embedding_batch_size": 32, - "embedding_exclude_overlap": True, - "segmentation": "pyannote/segmentation-3.0", - "segmentation_batch_size": 32, - }, - }, - "params": { - "clustering": { - "method": "centroid", - "min_cluster_size": 12, - "threshold": 0.7045654963945799, - }, - "segmentation": {"min_duration_off": 0.0}, - }, -} - - -def _make_model_tarball(tarball_path: Path) -> None: - """Create a fake model tarball matching real structure.""" - build_dir = tarball_path.parent / "_build" - dirs = { - "pyannote-speaker-diarization-3.1": {"config.yaml": yaml.dump(ORIGINAL_CONFIG)}, - "pyannote-segmentation-3.0": { - "config.yaml": "model: {}\n", - "pytorch_model.bin": b"fake", - }, - "pyannote-wespeaker-voxceleb-resnet34-LM": { - "config.yaml": "model: {}\n", - "pytorch_model.bin": b"fake", - }, - } - for dirname, files in dirs.items(): - d = build_dir / dirname - d.mkdir(parents=True, exist_ok=True) - for fname, content in files.items(): - p = d / fname - if isinstance(content, bytes): - p.write_bytes(content) - else: - p.write_text(content) - - with tarfile.open(tarball_path, "w:gz") as tar: - for dirname in dirs: - tar.add(build_dir / dirname, arcname=dirname) - - -def _make_mock_processor() -> MagicMock: - proc = MagicMock() - proc.logger = MagicMock() - return proc - - -class TestEnsureModel: - """Test model download, extraction, and config patching.""" - - def test_extracts_and_patches_config(self, tmp_path: Path) -> None: - """Downloads tarball, extracts, patches config to local paths.""" - cache_dir = tmp_path / "cache" - tarball_path = tmp_path / "model.tar.gz" - _make_model_tarball(tarball_path) - tarball_bytes = tarball_path.read_bytes() - - mock_response = MagicMock() - mock_response.iter_bytes.return_value = [tarball_bytes] - mock_response.raise_for_status = MagicMock() - mock_response.__enter__ = MagicMock(return_value=mock_response) - mock_response.__exit__ = MagicMock(return_value=False) - - mock_client = MagicMock() - mock_client.__enter__ = MagicMock(return_value=mock_client) - mock_client.__exit__ = MagicMock(return_value=False) - mock_client.stream.return_value = mock_response - - proc = _make_mock_processor() - proc._patch_config = lambda model_dir, cache_dir: ( - FileDiarizationPyannoteProcessor._patch_config(proc, model_dir, cache_dir) - ) - - with patch( - "reflector.processors.file_diarization_pyannote.httpx.Client", - return_value=mock_client, - ): - result = FileDiarizationPyannoteProcessor._ensure_model( - proc, "http://fake/model.tar.gz", cache_dir - ) - - assert result == str(cache_dir / "pyannote-speaker-diarization-3.1") - - patched_config_path = ( - cache_dir / "pyannote-speaker-diarization-3.1" / "config.yaml" - ) - with open(patched_config_path) as f: - config = yaml.safe_load(f) - - assert config["pipeline"]["params"]["segmentation"] == str( - cache_dir / "pyannote-segmentation-3.0" / "pytorch_model.bin" - ) - assert config["pipeline"]["params"]["embedding"] == str( - cache_dir / "pyannote-wespeaker-voxceleb-resnet34-LM" / "pytorch_model.bin" - ) - # Non-patched fields preserved - assert config["pipeline"]["params"]["clustering"] == "AgglomerativeClustering" - assert config["params"]["clustering"]["threshold"] == pytest.approx( - 0.7045654963945799 - ) - - def test_uses_cache_on_second_call(self, tmp_path: Path) -> None: - """Skips download if model dir already exists.""" - cache_dir = tmp_path / "cache" - model_dir = cache_dir / "pyannote-speaker-diarization-3.1" - model_dir.mkdir(parents=True) - (model_dir / "config.yaml").write_text("cached: true") - - proc = _make_mock_processor() - - with patch( - "reflector.processors.file_diarization_pyannote.httpx.Client" - ) as mock_httpx: - result = FileDiarizationPyannoteProcessor._ensure_model( - proc, "http://fake/model.tar.gz", cache_dir - ) - mock_httpx.assert_not_called() - - assert result == str(model_dir) - - -class TestDiarizeSegmentParsing: - """Test that pyannote output is correctly converted to DiarizationSegment.""" - - @pytest.mark.asyncio - async def test_parses_speaker_segments(self) -> None: - proc = _make_mock_processor() - - mock_seg_0 = MagicMock() - mock_seg_0.start = 0.123456 - mock_seg_0.end = 1.789012 - mock_seg_1 = MagicMock() - mock_seg_1.start = 2.0 - mock_seg_1.end = 3.5 - - mock_diarization = MagicMock() - mock_diarization.itertracks.return_value = [ - (mock_seg_0, None, "SPEAKER_00"), - (mock_seg_1, None, "SPEAKER_01"), - ] - proc.diarization_pipeline = MagicMock(return_value=mock_diarization) - - mock_input = MagicMock() - mock_input.audio_url = "http://fake/audio.mp3" - - mock_response = AsyncMock() - mock_response.content = b"fake audio" - mock_response.raise_for_status = MagicMock() - - mock_async_client = AsyncMock() - mock_async_client.__aenter__ = AsyncMock(return_value=mock_async_client) - mock_async_client.__aexit__ = AsyncMock(return_value=False) - mock_async_client.get = AsyncMock(return_value=mock_response) - - with ( - patch( - "reflector.processors.file_diarization_pyannote.httpx.AsyncClient", - return_value=mock_async_client, - ), - patch( - "reflector.processors.file_diarization_pyannote.torchaudio.load", - return_value=(MagicMock(), 16000), - ), - ): - result = await FileDiarizationPyannoteProcessor._diarize(proc, mock_input) - - assert len(result.diarization) == 2 - assert result.diarization[0] == {"start": 0.123, "end": 1.789, "speaker": 0} - assert result.diarization[1] == {"start": 2.0, "end": 3.5, "speaker": 1}