mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-04-15 17:56:54 +00:00
feat: add local pyannote file diarization processor
Enables file diarization without Modal by using pyannote.audio locally. Downloads model bundle from S3 on first use, caches locally, patches config to use local paths. Set DIARIZATION_BACKEND=pyannote to enable.
This commit is contained in:
192
server/tests/test_file_diarization_pyannote.py
Normal file
192
server/tests/test_file_diarization_pyannote.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import tarfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from reflector.processors.file_diarization_pyannote import (
|
||||
FileDiarizationPyannoteProcessor,
|
||||
)
|
||||
|
||||
ORIGINAL_CONFIG = {
|
||||
"version": "3.1.0",
|
||||
"pipeline": {
|
||||
"name": "pyannote.audio.pipelines.SpeakerDiarization",
|
||||
"params": {
|
||||
"clustering": "AgglomerativeClustering",
|
||||
"embedding": "pyannote/wespeaker-voxceleb-resnet34-LM",
|
||||
"embedding_batch_size": 32,
|
||||
"embedding_exclude_overlap": True,
|
||||
"segmentation": "pyannote/segmentation-3.0",
|
||||
"segmentation_batch_size": 32,
|
||||
},
|
||||
},
|
||||
"params": {
|
||||
"clustering": {
|
||||
"method": "centroid",
|
||||
"min_cluster_size": 12,
|
||||
"threshold": 0.7045654963945799,
|
||||
},
|
||||
"segmentation": {"min_duration_off": 0.0},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _make_model_tarball(tarball_path: Path) -> None:
|
||||
"""Create a fake model tarball matching real structure."""
|
||||
build_dir = tarball_path.parent / "_build"
|
||||
dirs = {
|
||||
"pyannote-speaker-diarization-3.1": {"config.yaml": yaml.dump(ORIGINAL_CONFIG)},
|
||||
"pyannote-segmentation-3.0": {
|
||||
"config.yaml": "model: {}\n",
|
||||
"pytorch_model.bin": b"fake",
|
||||
},
|
||||
"pyannote-wespeaker-voxceleb-resnet34-LM": {
|
||||
"config.yaml": "model: {}\n",
|
||||
"pytorch_model.bin": b"fake",
|
||||
},
|
||||
}
|
||||
for dirname, files in dirs.items():
|
||||
d = build_dir / dirname
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
for fname, content in files.items():
|
||||
p = d / fname
|
||||
if isinstance(content, bytes):
|
||||
p.write_bytes(content)
|
||||
else:
|
||||
p.write_text(content)
|
||||
|
||||
with tarfile.open(tarball_path, "w:gz") as tar:
|
||||
for dirname in dirs:
|
||||
tar.add(build_dir / dirname, arcname=dirname)
|
||||
|
||||
|
||||
def _make_mock_processor() -> MagicMock:
|
||||
proc = MagicMock()
|
||||
proc.logger = MagicMock()
|
||||
return proc
|
||||
|
||||
|
||||
class TestEnsureModel:
|
||||
"""Test model download, extraction, and config patching."""
|
||||
|
||||
def test_extracts_and_patches_config(self, tmp_path: Path) -> None:
|
||||
"""Downloads tarball, extracts, patches config to local paths."""
|
||||
cache_dir = tmp_path / "cache"
|
||||
tarball_path = tmp_path / "model.tar.gz"
|
||||
_make_model_tarball(tarball_path)
|
||||
tarball_bytes = tarball_path.read_bytes()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.iter_bytes.return_value = [tarball_bytes]
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_response.__enter__ = MagicMock(return_value=mock_response)
|
||||
mock_response.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.__enter__ = MagicMock(return_value=mock_client)
|
||||
mock_client.__exit__ = MagicMock(return_value=False)
|
||||
mock_client.stream.return_value = mock_response
|
||||
|
||||
proc = _make_mock_processor()
|
||||
proc._patch_config = lambda model_dir, cache_dir: (
|
||||
FileDiarizationPyannoteProcessor._patch_config(proc, model_dir, cache_dir)
|
||||
)
|
||||
|
||||
with patch(
|
||||
"reflector.processors.file_diarization_pyannote.httpx.Client",
|
||||
return_value=mock_client,
|
||||
):
|
||||
result = FileDiarizationPyannoteProcessor._ensure_model(
|
||||
proc, "http://fake/model.tar.gz", cache_dir
|
||||
)
|
||||
|
||||
assert result == str(cache_dir / "pyannote-speaker-diarization-3.1")
|
||||
|
||||
patched_config_path = (
|
||||
cache_dir / "pyannote-speaker-diarization-3.1" / "config.yaml"
|
||||
)
|
||||
with open(patched_config_path) as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
assert config["pipeline"]["params"]["segmentation"] == str(
|
||||
cache_dir / "pyannote-segmentation-3.0" / "pytorch_model.bin"
|
||||
)
|
||||
assert config["pipeline"]["params"]["embedding"] == str(
|
||||
cache_dir / "pyannote-wespeaker-voxceleb-resnet34-LM" / "pytorch_model.bin"
|
||||
)
|
||||
# Non-patched fields preserved
|
||||
assert config["pipeline"]["params"]["clustering"] == "AgglomerativeClustering"
|
||||
assert config["params"]["clustering"]["threshold"] == pytest.approx(
|
||||
0.7045654963945799
|
||||
)
|
||||
|
||||
def test_uses_cache_on_second_call(self, tmp_path: Path) -> None:
|
||||
"""Skips download if model dir already exists."""
|
||||
cache_dir = tmp_path / "cache"
|
||||
model_dir = cache_dir / "pyannote-speaker-diarization-3.1"
|
||||
model_dir.mkdir(parents=True)
|
||||
(model_dir / "config.yaml").write_text("cached: true")
|
||||
|
||||
proc = _make_mock_processor()
|
||||
|
||||
with patch(
|
||||
"reflector.processors.file_diarization_pyannote.httpx.Client"
|
||||
) as mock_httpx:
|
||||
result = FileDiarizationPyannoteProcessor._ensure_model(
|
||||
proc, "http://fake/model.tar.gz", cache_dir
|
||||
)
|
||||
mock_httpx.assert_not_called()
|
||||
|
||||
assert result == str(model_dir)
|
||||
|
||||
|
||||
class TestDiarizeSegmentParsing:
|
||||
"""Test that pyannote output is correctly converted to DiarizationSegment."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parses_speaker_segments(self) -> None:
|
||||
proc = _make_mock_processor()
|
||||
|
||||
mock_seg_0 = MagicMock()
|
||||
mock_seg_0.start = 0.123456
|
||||
mock_seg_0.end = 1.789012
|
||||
mock_seg_1 = MagicMock()
|
||||
mock_seg_1.start = 2.0
|
||||
mock_seg_1.end = 3.5
|
||||
|
||||
mock_diarization = MagicMock()
|
||||
mock_diarization.itertracks.return_value = [
|
||||
(mock_seg_0, None, "SPEAKER_00"),
|
||||
(mock_seg_1, None, "SPEAKER_01"),
|
||||
]
|
||||
proc.diarization_pipeline = MagicMock(return_value=mock_diarization)
|
||||
|
||||
mock_input = MagicMock()
|
||||
mock_input.audio_url = "http://fake/audio.mp3"
|
||||
|
||||
mock_response = AsyncMock()
|
||||
mock_response.content = b"fake audio"
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_async_client = AsyncMock()
|
||||
mock_async_client.__aenter__ = AsyncMock(return_value=mock_async_client)
|
||||
mock_async_client.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_async_client.get = AsyncMock(return_value=mock_response)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"reflector.processors.file_diarization_pyannote.httpx.AsyncClient",
|
||||
return_value=mock_async_client,
|
||||
),
|
||||
patch(
|
||||
"reflector.processors.file_diarization_pyannote.torchaudio.load",
|
||||
return_value=(MagicMock(), 16000),
|
||||
),
|
||||
):
|
||||
result = await FileDiarizationPyannoteProcessor._diarize(proc, mock_input)
|
||||
|
||||
assert len(result.diarization) == 2
|
||||
assert result.diarization[0] == {"start": 0.123, "end": 1.789, "speaker": 0}
|
||||
assert result.diarization[1] == {"start": 2.0, "end": 3.5, "speaker": 1}
|
||||
Reference in New Issue
Block a user