Files
reflector/server/tests/test_file_diarization_pyannote.py
Igor Loskutov 0353c23a94 feat: add local pyannote file diarization processor
Enables file diarization without Modal by using pyannote.audio locally.
Downloads model bundle from S3 on first use, caches locally, patches
config to use local paths. Set DIARIZATION_BACKEND=pyannote to enable.
2026-02-11 13:37:12 -05:00

193 lines
6.7 KiB
Python

import tarfile
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import yaml
from reflector.processors.file_diarization_pyannote import (
FileDiarizationPyannoteProcessor,
)
# Pipeline configuration that the fake bundle ships as the diarization model's
# ``config.yaml``. ``embedding`` and ``segmentation`` start out as hub-style
# identifiers; the tests below expect them to be swapped for local file paths
# while every other field stays untouched.
ORIGINAL_CONFIG = {
    "version": "3.1.0",
    "pipeline": {
        "name": "pyannote.audio.pipelines.SpeakerDiarization",
        "params": {
            "clustering": "AgglomerativeClustering",
            "embedding": "pyannote/wespeaker-voxceleb-resnet34-LM",
            "embedding_batch_size": 32,
            "embedding_exclude_overlap": True,
            "segmentation": "pyannote/segmentation-3.0",
            "segmentation_batch_size": 32,
        },
    },
    "params": {
        "clustering": {
            "method": "centroid",
            "min_cluster_size": 12,
            "threshold": 0.7045654963945799,
        },
        "segmentation": {"min_duration_off": 0.0},
    },
}
def _make_model_tarball(tarball_path: Path) -> None:
    """Write a gzip-compressed stand-in model bundle to *tarball_path*.

    Three top-level directories are staged under ``_build`` next to the
    tarball and then archived under their own names: the diarization
    pipeline directory (holding ``ORIGINAL_CONFIG`` dumped as YAML) plus the
    segmentation and embedding model directories, each with a stub
    ``config.yaml`` and a placeholder ``pytorch_model.bin``.
    """
    staging_root = tarball_path.parent / "_build"
    layout = {
        "pyannote-speaker-diarization-3.1": {
            "config.yaml": yaml.dump(ORIGINAL_CONFIG),
        },
        "pyannote-segmentation-3.0": {
            "config.yaml": "model: {}\n",
            "pytorch_model.bin": b"fake",
        },
        "pyannote-wespeaker-voxceleb-resnet34-LM": {
            "config.yaml": "model: {}\n",
            "pytorch_model.bin": b"fake",
        },
    }

    for folder, members in layout.items():
        target_dir = staging_root / folder
        target_dir.mkdir(parents=True, exist_ok=True)
        for filename, payload in members.items():
            destination = target_dir / filename
            # Weight files are raw bytes; configs are text.
            write = (
                destination.write_bytes
                if isinstance(payload, bytes)
                else destination.write_text
            )
            write(payload)

    with tarfile.open(tarball_path, "w:gz") as archive:
        for folder in layout:
            # arcname keeps each directory at the archive root.
            archive.add(staging_root / folder, arcname=folder)
def _make_mock_processor() -> MagicMock:
proc = MagicMock()
proc.logger = MagicMock()
return proc
class TestEnsureModel:
    """Cover ``_ensure_model``: bundle download, unpacking, and config rewrite."""

    def test_extracts_and_patches_config(self, tmp_path: Path) -> None:
        """With an empty cache, the bundle is fetched and its config points at local files."""
        cache_dir = tmp_path / "cache"
        bundle_path = tmp_path / "model.tar.gz"
        _make_model_tarball(bundle_path)

        # Streamed HTTP response that hands back the whole archive in one chunk.
        streamed = MagicMock()
        streamed.iter_bytes.return_value = [bundle_path.read_bytes()]
        streamed.raise_for_status = MagicMock()
        streamed.__enter__ = MagicMock(return_value=streamed)
        streamed.__exit__ = MagicMock(return_value=False)

        # Client usable as a context manager whose ``stream`` yields that response.
        http_client = MagicMock()
        http_client.__enter__ = MagicMock(return_value=http_client)
        http_client.__exit__ = MagicMock(return_value=False)
        http_client.stream.return_value = streamed

        processor = _make_mock_processor()
        # The processor is a mock, so forward ``_patch_config`` to the real
        # implementation with the mock standing in for ``self``.
        processor._patch_config = lambda model_dir, cache_dir: (
            FileDiarizationPyannoteProcessor._patch_config(
                processor, model_dir, cache_dir
            )
        )

        client_target = "reflector.processors.file_diarization_pyannote.httpx.Client"
        with patch(client_target, return_value=http_client):
            result = FileDiarizationPyannoteProcessor._ensure_model(
                processor, "http://fake/model.tar.gz", cache_dir
            )

        pipeline_dir = cache_dir / "pyannote-speaker-diarization-3.1"
        assert result == str(pipeline_dir)

        config = yaml.safe_load((pipeline_dir / "config.yaml").read_text())
        pipeline_params = config["pipeline"]["params"]
        assert pipeline_params["segmentation"] == str(
            cache_dir / "pyannote-segmentation-3.0" / "pytorch_model.bin"
        )
        assert pipeline_params["embedding"] == str(
            cache_dir / "pyannote-wespeaker-voxceleb-resnet34-LM" / "pytorch_model.bin"
        )
        # Fields that are not model references must survive unchanged.
        assert pipeline_params["clustering"] == "AgglomerativeClustering"
        assert config["params"]["clustering"]["threshold"] == pytest.approx(
            0.7045654963945799
        )

    def test_uses_cache_on_second_call(self, tmp_path: Path) -> None:
        """An existing model directory is returned without creating an HTTP client."""
        cache_dir = tmp_path / "cache"
        cached_model_dir = cache_dir / "pyannote-speaker-diarization-3.1"
        cached_model_dir.mkdir(parents=True)
        (cached_model_dir / "config.yaml").write_text("cached: true")

        processor = _make_mock_processor()

        client_target = "reflector.processors.file_diarization_pyannote.httpx.Client"
        with patch(client_target) as client_cls:
            result = FileDiarizationPyannoteProcessor._ensure_model(
                processor, "http://fake/model.tar.gz", cache_dir
            )

        client_cls.assert_not_called()
        assert result == str(cached_model_dir)
class TestDiarizeSegmentParsing:
    """Check how ``_diarize`` turns pipeline tracks into diarization entries."""

    @pytest.mark.asyncio
    async def test_parses_speaker_segments(self) -> None:
        """Two tracks become two dicts with 3-decimal times and integer speakers."""
        processor = _make_mock_processor()

        # Segments as returned by the pipeline: only ``start``/``end`` are set.
        first_segment = MagicMock(start=0.123456, end=1.789012)
        second_segment = MagicMock(start=2.0, end=3.5)

        pipeline_output = MagicMock()
        pipeline_output.itertracks.return_value = [
            (first_segment, None, "SPEAKER_00"),
            (second_segment, None, "SPEAKER_01"),
        ]
        processor.diarization_pipeline = MagicMock(return_value=pipeline_output)

        file_input = MagicMock(audio_url="http://fake/audio.mp3")

        # Audio download: async client context manager whose ``get`` resolves
        # to a response carrying placeholder bytes.
        audio_response = AsyncMock()
        audio_response.content = b"fake audio"
        audio_response.raise_for_status = MagicMock()

        async_client = AsyncMock()
        async_client.__aenter__ = AsyncMock(return_value=async_client)
        async_client.__aexit__ = AsyncMock(return_value=False)
        async_client.get = AsyncMock(return_value=audio_response)

        client_target = (
            "reflector.processors.file_diarization_pyannote.httpx.AsyncClient"
        )
        loader_target = (
            "reflector.processors.file_diarization_pyannote.torchaudio.load"
        )
        with (
            patch(client_target, return_value=async_client),
            patch(loader_target, return_value=(MagicMock(), 16000)),
        ):
            result = await FileDiarizationPyannoteProcessor._diarize(
                processor, file_input
            )

        expected_entries = [
            {"start": 0.123, "end": 1.789, "speaker": 0},
            {"start": 2.0, "end": 3.5, "speaker": 1},
        ]
        assert len(result.diarization) == 2
        for position, expected in enumerate(expected_entries):
            assert result.diarization[position] == expected