feat: 3-mode selfhosted refactoring (--gpu, --cpu, --hosted) + audio token auth fallback (#896)

* fix: local processing instead of http server for cpu

* add fallback token if service worker doesn't work

* chore: rename processors to keep processor pattern up to date and allow other processors to be created and used with env vars
This commit is contained in:
Juan Diego García
2026-03-04 16:31:08 -05:00
committed by GitHub
parent 4235ab4293
commit a682846645
34 changed files with 2640 additions and 172 deletions

View File

@@ -0,0 +1,450 @@
"""
Tests for in-process processor backends (--cpu mode).
All ML model calls are mocked — no actual model loading needed.
Tests verify processor registration, wiring, error handling, and data flow.
"""
import os
import tempfile
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from reflector.processors.file_diarization import (
FileDiarizationInput,
FileDiarizationOutput,
)
from reflector.processors.types import (
AudioDiarizationInput,
TitleSummaryWithId,
Transcript,
Word,
)
# ── Registration Tests ──────────────────────────────────────────────────
def test_audio_diarization_pyannote_registers():
    """Check that the audio-diarization auto processor lists a 'pyannote' backend."""
    # The backend module is imported only because importing it triggers
    # registration; nothing from it is used directly.
    import reflector.processors.audio_diarization_pyannote  # noqa: F401
    from reflector.processors.audio_diarization_auto import (
        AudioDiarizationAutoProcessor,
    )

    registered_backends = AudioDiarizationAutoProcessor._registry
    assert "pyannote" in registered_backends
def test_file_diarization_pyannote_registers():
    """Check that the file-diarization auto processor lists a 'pyannote' backend."""
    # Imported for its registration side effect only.
    import reflector.processors.file_diarization_pyannote  # noqa: F401
    from reflector.processors.file_diarization_auto import FileDiarizationAutoProcessor

    registered_backends = FileDiarizationAutoProcessor._registry
    assert "pyannote" in registered_backends
def test_transcript_translator_marian_registers():
    """Check that the translator auto processor lists a 'marian' backend."""
    # Imported for its registration side effect only.
    import reflector.processors.transcript_translator_marian  # noqa: F401
    from reflector.processors.transcript_translator_auto import (
        TranscriptTranslatorAutoProcessor,
    )

    registered_backends = TranscriptTranslatorAutoProcessor._registry
    assert "marian" in registered_backends
def test_file_transcript_whisper_registers():
    """Check that the file-transcript auto processor lists a 'whisper' backend."""
    # Imported for its registration side effect only.
    import reflector.processors.file_transcript_whisper  # noqa: F401
    from reflector.processors.file_transcript_auto import FileTranscriptAutoProcessor

    registered_backends = FileTranscriptAutoProcessor._registry
    assert "whisper" in registered_backends
# ── Audio Download Utility Tests ────────────────────────────────────────
@pytest.mark.asyncio
async def test_download_audio_to_temp_success():
    """Verify download_audio_to_temp writes the response body to a temp file.

    ``requests.get`` is patched inside ``_audio_download`` so no network
    request is made; the mocked response yields a single chunk of bytes.
    The returned path must be an existing ``Path`` with a ``.wav`` suffix
    whose contents equal that chunk.
    """
    from reflector.processors._audio_download import download_audio_to_temp

    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.headers = {"content-type": "audio/wav"}
    mock_response.iter_content.return_value = [b"fake audio data"]
    mock_response.raise_for_status = MagicMock()

    with patch("reflector.processors._audio_download.requests.get") as mock_get:
        mock_get.return_value = mock_response
        result = await download_audio_to_temp("https://example.com/test.wav")

    assert isinstance(result, Path)
    try:
        assert result.exists()
        assert result.read_bytes() == b"fake audio data"
        assert result.suffix == ".wav"
    finally:
        # Clean up in ``finally`` so a failing assertion above does not leave
        # the downloaded file behind in the temp directory.
        result.unlink(missing_ok=True)
@pytest.mark.asyncio
async def test_download_audio_to_temp_cleanup_on_error():
    """Verify a mid-download failure propagates out of download_audio_to_temp.

    The mocked response raises ``ConnectionError`` as soon as
    ``iter_content`` is called.

    NOTE(review): despite the test name, nothing here asserts that the
    partially written temp file was removed; only the exception is checked.
    """
    from reflector.processors._audio_download import download_audio_to_temp

    failing_response = MagicMock()
    failing_response.headers = {"content-type": "audio/wav"}
    failing_response.raise_for_status = MagicMock()
    # Any call to iter_content(...) blows up, simulating a dropped connection.
    failing_response.iter_content.side_effect = ConnectionError(
        "Download interrupted"
    )

    with patch("reflector.processors._audio_download.requests.get") as fake_get:
        fake_get.return_value = failing_response
        with pytest.raises(ConnectionError, match="Download interrupted"):
            await download_audio_to_temp("https://example.com/test.wav")
def test_detect_extension_from_url():
    """Check _detect_extension picks the suffix from the URL path.

    The content-type argument is empty in every case, and one URL carries a
    query string that must not affect the result.
    """
    from reflector.processors._audio_download import _detect_extension

    expected_by_url = {
        "https://example.com/test.wav": ".wav",
        "https://example.com/test.mp3?signed=1": ".mp3",
        "https://example.com/test.webm": ".webm",
    }
    for url, extension in expected_by_url.items():
        assert _detect_extension(url, "") == extension
def test_detect_extension_from_content_type():
    """Check _detect_extension maps a content-type header to an extension.

    The URL has no file suffix, so only the content-type can decide.
    """
    from reflector.processors._audio_download import _detect_extension

    extensionless_url = "https://s3.aws/uuid"
    expected_by_content_type = {
        "audio/mpeg": ".mp3",
        "audio/wav": ".wav",
        "audio/webm": ".webm",
    }
    for content_type, extension in expected_by_content_type.items():
        assert _detect_extension(extensionless_url, content_type) == extension
def test_detect_extension_fallback():
    """Check _detect_extension returns '.audio' when nothing is recognized.

    Neither the suffix-less URL nor the generic content-type identifies an
    audio format.
    """
    from reflector.processors._audio_download import _detect_extension

    detected = _detect_extension("https://s3.aws/uuid", "application/octet-stream")
    assert detected == ".audio"
# ── Audio Diarization Pyannote Processor Tests ──────────────────────────
@pytest.mark.asyncio
async def test_audio_diarization_pyannote_diarize():
    """Verify the pyannote audio processor returns the service's diarization.

    ``download_audio_to_temp`` is replaced by an AsyncMock returning a real
    temp file, and ``diarization_service`` is mocked, so no download or model
    run happens. The test asserts that ``_diarize`` returns the mocked
    ``"diarization"`` list and that ``diarize_file`` is called exactly once.
    It does not assert that the processor deletes the temp file.
    """
    from reflector.processors.audio_diarization_pyannote import (
        AudioDiarizationPyannoteProcessor,
    )

    mock_diarization_result = {
        "diarization": [
            {"start": 0.0, "end": 2.5, "speaker": 0},
            {"start": 2.5, "end": 5.0, "speaker": 1},
        ]
    }

    # Create a temp file to simulate download; the context manager closes it
    # while delete=False keeps it on disk for the processor.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp.write(b"fake audio")
    tmp_path = Path(tmp.name)

    try:
        processor = AudioDiarizationPyannoteProcessor()
        with (
            patch(
                "reflector.processors.audio_diarization_pyannote.download_audio_to_temp",
                new_callable=AsyncMock,
                return_value=tmp_path,
            ),
            patch(
                "reflector.processors.audio_diarization_pyannote.diarization_service"
            ) as mock_svc,
        ):
            mock_svc.diarize_file.return_value = mock_diarization_result
            data = AudioDiarizationInput(
                audio_url="https://example.com/test.wav",
                topics=[
                    TitleSummaryWithId(
                        id="topic-1",
                        title="Test Topic",
                        summary="A test topic",
                        timestamp=0.0,
                        duration=5.0,
                        transcript=Transcript(
                            words=[Word(text="hello", start=0.0, end=1.0)]
                        ),
                    )
                ],
            )
            result = await processor._diarize(data)

        assert result == mock_diarization_result["diarization"]
        mock_svc.diarize_file.assert_called_once()
    finally:
        # The file may already have been removed by the processor; either
        # way, make sure the test never leaves it behind.
        tmp_path.unlink(missing_ok=True)
# ── File Diarization Pyannote Processor Tests ───────────────────────────
@pytest.mark.asyncio
async def test_file_diarization_pyannote_diarize():
    """Verify the pyannote file processor wraps results in FileDiarizationOutput.

    The download helper and ``diarization_service`` are mocked; the test
    checks the output type and that both mocked segments are carried through.
    """
    from reflector.processors.file_diarization_pyannote import (
        FileDiarizationPyannoteProcessor,
    )

    mock_diarization_result = {
        "diarization": [
            {"start": 0.0, "end": 3.0, "speaker": 0},
            {"start": 3.0, "end": 6.0, "speaker": 1},
        ]
    }

    # Real file standing in for the downloaded audio; delete=False keeps it
    # on disk after the context manager closes the handle.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp.write(b"fake audio")
    tmp_path = Path(tmp.name)

    try:
        processor = FileDiarizationPyannoteProcessor()
        with (
            patch(
                "reflector.processors.file_diarization_pyannote.download_audio_to_temp",
                new_callable=AsyncMock,
                return_value=tmp_path,
            ),
            patch(
                "reflector.processors.file_diarization_pyannote.diarization_service"
            ) as mock_svc,
        ):
            mock_svc.diarize_file.return_value = mock_diarization_result
            data = FileDiarizationInput(audio_url="https://example.com/test.wav")
            result = await processor._diarize(data)

        assert isinstance(result, FileDiarizationOutput)
        assert len(result.diarization) == 2
        assert result.diarization[0]["start"] == 0.0
        assert result.diarization[1]["speaker"] == 1
    finally:
        # The file may already have been removed by the processor; either
        # way, make sure the test never leaves it behind.
        tmp_path.unlink(missing_ok=True)
# ── Transcript Translator Marian Processor Tests ───────────────────────
@pytest.mark.asyncio
async def test_transcript_translator_marian_translate():
    """Verify the Marian processor returns the target-language text.

    ``get_pref`` is stubbed to report en -> fr and ``translator_service`` is
    mocked; ``_translate`` must return the "fr" entry of the service result
    and call ``translate`` once with (text, "en", "fr").
    """
    from reflector.processors.transcript_translator_marian import (
        TranscriptTranslatorMarianProcessor,
    )

    service_reply = {"text": {"en": "Hello world", "fr": "Bonjour le monde"}}
    language_prefs = {"audio:source_language": "en", "audio:target_language": "fr"}
    processor = TranscriptTranslatorMarianProcessor()

    with (
        patch.object(
            processor,
            "get_pref",
            side_effect=lambda key, default=None: language_prefs.get(key, default),
        ),
        patch(
            "reflector.processors.transcript_translator_marian.translator_service"
        ) as fake_service,
    ):
        fake_service.translate.return_value = service_reply
        translated = await processor._translate("Hello world")

    assert translated == "Bonjour le monde"
    fake_service.translate.assert_called_once_with("Hello world", "en", "fr")
@pytest.mark.asyncio
async def test_transcript_translator_marian_no_translation():
    """Verify _translate returns None when the target language is missing.

    The mocked service result contains only the "en" text, while the stubbed
    preferences ask for "fr".
    """
    from reflector.processors.transcript_translator_marian import (
        TranscriptTranslatorMarianProcessor,
    )

    service_reply = {"text": {"en": "Hello world"}}
    language_prefs = {"audio:source_language": "en", "audio:target_language": "fr"}
    processor = TranscriptTranslatorMarianProcessor()

    with (
        patch.object(
            processor,
            "get_pref",
            side_effect=lambda key, default=None: language_prefs.get(key, default),
        ),
        patch(
            "reflector.processors.transcript_translator_marian.translator_service"
        ) as fake_service,
    ):
        fake_service.translate.return_value = service_reply
        translated = await processor._translate("Hello world")

    assert translated is None
# ── File Transcript Whisper Processor Tests ─────────────────────────────
@pytest.mark.asyncio
async def test_file_transcript_whisper_transcript():
    """Verify the whisper file processor returns the transcribed Transcript.

    The download helper is replaced by an AsyncMock returning a real temp
    file, and ``_transcribe_file_blocking`` is patched to return a canned
    two-word transcript, so no model is loaded.
    """
    from reflector.processors.file_transcript import FileTranscriptInput
    from reflector.processors.file_transcript_whisper import (
        FileTranscriptWhisperProcessor,
    )

    # Real file standing in for the downloaded audio; delete=False keeps it
    # on disk after the context manager closes the handle.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp.write(b"fake audio")
    tmp_path = Path(tmp.name)

    try:
        processor = FileTranscriptWhisperProcessor()
        # Mock the blocking transcription method
        mock_transcript = Transcript(
            words=[
                Word(text="Hello", start=0.0, end=0.5),
                Word(text=" world", start=0.5, end=1.0),
            ]
        )
        with (
            patch(
                "reflector.processors.file_transcript_whisper.download_audio_to_temp",
                new_callable=AsyncMock,
                return_value=tmp_path,
            ),
            patch.object(
                processor,
                "_transcribe_file_blocking",
                return_value=mock_transcript,
            ),
        ):
            data = FileTranscriptInput(
                audio_url="https://example.com/test.wav", language="en"
            )
            result = await processor._transcript(data)

        assert isinstance(result, Transcript)
        assert len(result.words) == 2
        assert result.words[0].text == "Hello"
    finally:
        # The file may already have been removed by the processor; either
        # way, make sure the test never leaves it behind.
        tmp_path.unlink(missing_ok=True)
# ── VAD Helper Tests ────────────────────────────────────────────────────
def test_enforce_word_timing_constraints():
    """Check that a word's end never exceeds the following word's start."""
    from reflector.processors.file_transcript_whisper import (
        _enforce_word_timing_constraints,
    )

    overlapping_words = [
        {"word": "hello", "start": 0.0, "end": 1.5},
        {"word": "world", "start": 1.0, "end": 2.0},  # starts before "hello" ends
        {"word": "test", "start": 2.0, "end": 3.0},
    ]
    adjusted = _enforce_word_timing_constraints(overlapping_words)

    expected_ends = (
        1.0,  # pulled back to the start of "world"
        2.0,  # already equal to the start of "test"
        3.0,  # last word has no successor
    )
    for index, expected_end in enumerate(expected_ends):
        assert adjusted[index]["end"] == expected_end
def test_enforce_word_timing_constraints_empty():
    """Check the degenerate inputs: no words, and exactly one word."""
    from reflector.processors.file_transcript_whisper import (
        _enforce_word_timing_constraints,
    )

    no_words = _enforce_word_timing_constraints([])
    assert no_words == []

    one_word = _enforce_word_timing_constraints([{"word": "a", "start": 0, "end": 1}])
    assert one_word == [{"word": "a", "start": 0, "end": 1}]
def test_pad_audio_short():
    """Check that a very short clip comes back longer from _pad_audio."""
    import numpy as np
    from reflector.processors.file_transcript_whisper import _pad_audio

    # 100 samples at 16 kHz is only a few milliseconds of audio.
    clip = np.zeros(100, dtype=np.float32)
    padded = _pad_audio(clip, sample_rate=16000)

    # Only growth is asserted here, not the exact padded length.
    assert len(padded) > len(clip)
def test_pad_audio_long():
    """Check that a two-second clip keeps its length through _pad_audio."""
    import numpy as np
    from reflector.processors.file_transcript_whisper import _pad_audio

    # 32000 samples at 16 kHz == 2 seconds.
    clip = np.zeros(32000, dtype=np.float32)
    padded = _pad_audio(clip, sample_rate=16000)

    assert len(padded) == len(clip)
# ── Translator Service Tests ────────────────────────────────────────────
def test_translator_service_resolve_model():
    """Check _resolve_model_name for three known pairs and one unknown pair."""
    from reflector.processors._marian_translator_service import MarianTranslatorService

    service = MarianTranslatorService()
    expected_models = [
        (("en", "fr"), "Helsinki-NLP/opus-mt-en-fr"),
        (("es", "en"), "Helsinki-NLP/opus-mt-es-en"),
        (("en", "de"), "Helsinki-NLP/opus-mt-en-de"),
        # Unknown pair falls back to en->fr
        (("ja", "ko"), "Helsinki-NLP/opus-mt-en-fr"),
    ]
    for (source_lang, target_lang), model_name in expected_models:
        assert service._resolve_model_name(source_lang, target_lang) == model_name
# ── Diarization Service Tests ───────────────────────────────────────────
def test_diarization_service_singleton():
    """Check the module exposes a ready-made, not-yet-loaded diarization service."""
    from reflector.processors import _pyannote_diarization_service as service_module

    instance = service_module.diarization_service
    assert isinstance(instance, service_module.PyannoteDiarizationService)
    # The pipeline attribute is expected to stay unset until first use.
    assert instance._pipeline is None
def test_translator_service_singleton():
    """Check the module exposes a ready-made, not-yet-loaded translator service."""
    from reflector.processors import _marian_translator_service as service_module

    instance = service_module.translator_service
    assert isinstance(instance, service_module.MarianTranslatorService)
    # The pipeline attribute is expected to stay unset until first use.
    # NOTE(review): assumes no earlier test triggered a real model load on
    # this shared module-level instance; confirm if test order changes.
    assert instance._pipeline is None

View File

@@ -0,0 +1,327 @@
"""Tests for audio mp3 endpoint token query-param authentication.
Covers both password (HS256) and JWT/Authentik (RS256) auth backends,
verifying that private transcripts can be accessed via ?token= query param.
"""
import shutil
from datetime import datetime, timedelta, timezone
from pathlib import Path
from unittest.mock import patch
import jwt
import pytest
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
OWNER_USER_ID = "test-owner-user-id"
def _create_hs256_token(user_id: str, secret: str, expired: bool = False) -> str:
    """Build an HS256-signed JWT shaped like the password auth backend's tokens.

    The token carries ``sub``, a fixed test ``email`` and an ``exp`` claim.
    With ``expired=True`` the expiry lies five minutes in the past;
    otherwise it is 24 hours ahead.
    """
    if expired:
        lifetime = timedelta(minutes=-5)
    else:
        lifetime = timedelta(hours=24)
    claims = {
        "sub": user_id,
        "email": "test@example.com",
        "exp": datetime.now(timezone.utc) + lifetime,
    }
    return jwt.encode(claims, secret, algorithm="HS256")
def _generate_rsa_keypair():
    """Return ``(private_key, public_key_pem)`` for a new 2048-bit RSA key.

    The public half is PEM-encoded SubjectPublicKeyInfo, decoded to ``str``.
    """
    key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    public_key = key.public_key()
    pem_bytes = public_key.public_bytes(
        serialization.Encoding.PEM,
        serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    return key, pem_bytes.decode()
def _create_rs256_token(
    authentik_uid: str,
    private_key,
    audience: str,
    expired: bool = False,
) -> str:
    """Build an RS256-signed JWT shaped like one issued by Authentik.

    The token carries ``sub``, a fixed test ``email``, ``aud`` and ``exp``.
    With ``expired=True`` the expiry lies five minutes in the past;
    otherwise it is one hour ahead.
    """
    if expired:
        lifetime = timedelta(minutes=-5)
    else:
        lifetime = timedelta(hours=1)
    claims = {
        "sub": authentik_uid,
        "email": "authentik-user@example.com",
        "aud": audience,
        "exp": datetime.now(timezone.utc) + lifetime,
    }
    return jwt.encode(claims, private_key, algorithm="RS256")
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
async def private_transcript(tmpdir):
    """Create a private transcript owned by OWNER_USER_ID with an mp3 file.

    Created directly via the controller (not HTTP) so no auth override
    leaks into the test scope. ``settings.DATA_DIR`` is pointed at
    ``tmpdir`` for the lifetime of the fixture and put back afterwards so
    the global settings object is not left pointing at a per-test directory.

    Yields:
        The transcript, with status "ended" and a sample mp3 copied to its
        ``audio_mp3_filename`` location.
    """
    from reflector.db.transcripts import SourceKind, transcripts_controller
    from reflector.settings import settings

    previous_data_dir = settings.DATA_DIR
    settings.DATA_DIR = Path(tmpdir)
    try:
        transcript = await transcripts_controller.add(
            "Private audio test",
            source_kind=SourceKind.FILE,
            user_id=OWNER_USER_ID,
            share_mode="private",
        )
        await transcripts_controller.update(transcript, {"status": "ended"})

        # Copy a real mp3 to the expected location
        audio_filename = transcript.audio_mp3_filename
        mp3_source = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
        audio_filename.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(mp3_source, audio_filename)

        yield transcript
    finally:
        # Undo the global settings mutation once the test is done.
        settings.DATA_DIR = previous_data_dir
# ---------------------------------------------------------------------------
# Core access control tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_audio_mp3_private_no_auth_returns_403(private_transcript, client):
    """An unauthenticated request for a private transcript's mp3 gets 403."""
    audio_url = f"/transcripts/{private_transcript.id}/audio/mp3"
    reply = await client.get(audio_url)
    assert reply.status_code == 403
@pytest.mark.asyncio
async def test_audio_mp3_with_bearer_header(private_transcript, client):
    """The authenticated owner can fetch the private transcript's audio.

    No Authorization header is actually sent: authentication is simulated by
    overriding the ``current_user_optional`` dependency so that it returns
    the owner's user info.
    """
    from reflector.app import app
    from reflector.auth import current_user_optional

    owner = {
        "sub": OWNER_USER_ID,
        "email": "test@example.com",
    }
    # patch.dict restores dependency_overrides to its previous contents on
    # exit, so an override installed elsewhere is put back, not deleted.
    with patch.dict(app.dependency_overrides, {current_user_optional: lambda: owner}):
        response = await client.get(f"/transcripts/{private_transcript.id}/audio/mp3")

    assert response.status_code == 200
    assert response.headers["content-type"] == "audio/mpeg"
@pytest.mark.asyncio
async def test_audio_mp3_public_transcript_no_auth_ok(tmpdir, client):
    """Public transcripts are accessible without any auth.

    A transcript with ``share_mode="public"`` is created via the controller,
    a sample mp3 is copied to its audio path, and an unauthenticated GET must
    return 200 with an ``audio/mpeg`` content type. ``settings.DATA_DIR`` is
    restored afterwards so the change does not leak into other tests.
    """
    from reflector.db.transcripts import SourceKind, transcripts_controller
    from reflector.settings import settings

    previous_data_dir = settings.DATA_DIR
    settings.DATA_DIR = Path(tmpdir)
    try:
        transcript = await transcripts_controller.add(
            "Public audio test",
            source_kind=SourceKind.FILE,
            user_id=OWNER_USER_ID,
            share_mode="public",
        )
        await transcripts_controller.update(transcript, {"status": "ended"})

        # Copy a real mp3 to the expected location
        audio_filename = transcript.audio_mp3_filename
        mp3_source = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
        audio_filename.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(mp3_source, audio_filename)

        response = await client.get(f"/transcripts/{transcript.id}/audio/mp3")
    finally:
        # Undo the global settings mutation even if setup or the request fails.
        settings.DATA_DIR = previous_data_dir

    assert response.status_code == 200
    assert response.headers["content-type"] == "audio/mpeg"
# ---------------------------------------------------------------------------
# Password auth backend tests (?token= with HS256)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_audio_mp3_password_token_query_param(private_transcript, client):
    """Password backend: a ?token= resolving to the owner grants audio access.

    ``reflector.auth.verify_raw_token`` is patched to return the owner's
    ``UserInfo``, so the HS256 token only has to be present in the URL.
    """
    from reflector.auth.auth_password import UserInfo
    from reflector.settings import settings

    owner_token = _create_hs256_token(OWNER_USER_ID, settings.SECRET_KEY)
    audio_url = f"/transcripts/{private_transcript.id}/audio/mp3?token={owner_token}"

    with patch("reflector.auth.verify_raw_token") as fake_verify:
        fake_verify.return_value = UserInfo(sub=OWNER_USER_ID, email="test@example.com")
        reply = await client.get(audio_url)

    assert reply.status_code == 200
    assert reply.headers["content-type"] == "audio/mpeg"
@pytest.mark.asyncio
async def test_audio_mp3_password_expired_token_returns_401(private_transcript, client):
    """Password backend: an expired ?token= is answered with 401.

    ``reflector.auth.verify_raw_token`` is patched to raise
    ``jwt.ExpiredSignatureError``.
    """
    from reflector.settings import settings

    stale_token = _create_hs256_token(
        OWNER_USER_ID, settings.SECRET_KEY, expired=True
    )
    audio_url = f"/transcripts/{private_transcript.id}/audio/mp3?token={stale_token}"

    with patch("reflector.auth.verify_raw_token") as fake_verify:
        fake_verify.side_effect = jwt.ExpiredSignatureError("token expired")
        reply = await client.get(audio_url)

    assert reply.status_code == 401
@pytest.mark.asyncio
async def test_audio_mp3_password_wrong_user_returns_403(private_transcript, client):
    """Password backend: a token resolving to a non-owner is answered with 403.

    ``reflector.auth.verify_raw_token`` is patched to return a ``UserInfo``
    whose ``sub`` differs from the transcript owner.
    """
    from reflector.auth.auth_password import UserInfo
    from reflector.settings import settings

    stranger_token = _create_hs256_token("other-user-id", settings.SECRET_KEY)
    stranger = UserInfo(sub="other-user-id", email="other@example.com")
    audio_url = (
        f"/transcripts/{private_transcript.id}/audio/mp3?token={stranger_token}"
    )

    with patch("reflector.auth.verify_raw_token") as fake_verify:
        fake_verify.return_value = stranger
        reply = await client.get(audio_url)

    assert reply.status_code == 403
@pytest.mark.asyncio
async def test_audio_mp3_invalid_token_returns_401(private_transcript, client):
    """A ?token= that verification resolves to no user is answered with 401.

    ``reflector.auth.verify_raw_token`` is patched to return ``None``.
    """
    audio_url = f"/transcripts/{private_transcript.id}/audio/mp3?token=not-a-real-token"

    with patch("reflector.auth.verify_raw_token") as fake_verify:
        fake_verify.return_value = None
        reply = await client.get(audio_url)

    assert reply.status_code == 401
# ---------------------------------------------------------------------------
# JWT/Authentik auth backend tests (?token= with RS256)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_audio_mp3_authentik_token_query_param(private_transcript, client):
    """Authentik backend: an RS256 ?token= resolving to the owner grants access.

    ``reflector.auth.verify_raw_token`` is patched to return the owner's
    ``UserInfo``, standing in for the Authentik uid -> internal user mapping.
    """
    from reflector.auth.auth_password import UserInfo

    signing_key, _ = _generate_rsa_keypair()
    owner_token = _create_rs256_token("authentik-abc123", signing_key, "test-audience")
    audio_url = f"/transcripts/{private_transcript.id}/audio/mp3?token={owner_token}"

    with patch("reflector.auth.verify_raw_token") as fake_verify:
        # Authentik flow maps authentik_uid -> internal user id
        fake_verify.return_value = UserInfo(
            sub=OWNER_USER_ID, email="authentik-user@example.com"
        )
        reply = await client.get(audio_url)

    assert reply.status_code == 200
    assert reply.headers["content-type"] == "audio/mpeg"
@pytest.mark.asyncio
async def test_audio_mp3_authentik_expired_token_returns_401(
    private_transcript, client
):
    """Authentik backend: an expired RS256 ?token= is answered with 401.

    ``reflector.auth.verify_raw_token`` is patched to raise
    ``jwt.ExpiredSignatureError``.
    """
    signing_key, _ = _generate_rsa_keypair()
    stale_token = _create_rs256_token(
        "authentik-abc123", signing_key, "test-audience", expired=True
    )
    audio_url = f"/transcripts/{private_transcript.id}/audio/mp3?token={stale_token}"

    with patch("reflector.auth.verify_raw_token") as fake_verify:
        fake_verify.side_effect = jwt.ExpiredSignatureError("token expired")
        reply = await client.get(audio_url)

    assert reply.status_code == 401
@pytest.mark.asyncio
async def test_audio_mp3_authentik_wrong_user_returns_403(private_transcript, client):
    """Authentik backend: a token resolving to a non-owner is answered with 403.

    ``reflector.auth.verify_raw_token`` is patched to return a ``UserInfo``
    whose ``sub`` differs from the transcript owner.
    """
    from reflector.auth.auth_password import UserInfo

    signing_key, _ = _generate_rsa_keypair()
    stranger_token = _create_rs256_token(
        "authentik-other", signing_key, "test-audience"
    )
    stranger = UserInfo(sub="different-user-id", email="other@example.com")
    audio_url = (
        f"/transcripts/{private_transcript.id}/audio/mp3?token={stranger_token}"
    )

    with patch("reflector.auth.verify_raw_token") as fake_verify:
        fake_verify.return_value = stranger
        reply = await client.get(audio_url)

    assert reply.status_code == 403
# ---------------------------------------------------------------------------
# _generate_local_audio_link produces HS256 tokens — must be verifiable
# by any auth backend
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_local_audio_link_token_works_with_authentik_backend(
    private_transcript, client
):
    """A token from _generate_local_audio_link must still unlock the audio.

    The token is taken from the URL returned by
    ``_generate_local_audio_link``. ``reflector.auth.verify_raw_token`` is
    patched to raise ``InvalidAlgorithmError``, standing in for an RS256-only
    (Authentik) backend rejecting an HS256 token. The endpoint is
    nevertheless required to answer 200.

    NOTE(review): the earlier wording of this docstring said the request
    returns 401 and that the test "documents the bug", which contradicts the
    200 assertion below. Confirm whether the endpoint now has a fallback
    verification path or whether this test is expected to fail.
    """
    from urllib.parse import parse_qs, urlparse

    # Pull the token out of the internally generated audio link.
    internal_link = private_transcript._generate_local_audio_link()
    query_params = parse_qs(urlparse(internal_link).query)
    link_token = query_params["token"][0]
    audio_url = f"/transcripts/{private_transcript.id}/audio/mp3?token={link_token}"

    # Make primary verification fail the way an algorithm mismatch would.
    with patch("reflector.auth.verify_raw_token") as fake_verify:
        fake_verify.side_effect = jwt.exceptions.InvalidAlgorithmError(
            "the specified alg value is not allowed"
        )
        reply = await client.get(audio_url)

    # Desired behaviour: a token minted by our own server is accepted.
    assert reply.status_code == 200