mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-04-22 21:25:18 +00:00
feat: 3-mode selfhosted refactoring (--gpu, --cpu, --hosted) + audio token auth fallback (#896)
* fix: local processing instead of http server for cpu * add fallback token if service worker doesnt work * chore: rename processors to keep processor pattern up to date and allow other processors to be createed and used with env vars
This commit is contained in:
committed by
GitHub
parent
4235ab4293
commit
a682846645
450
server/tests/test_processors_cpu.py
Normal file
450
server/tests/test_processors_cpu.py
Normal file
@@ -0,0 +1,450 @@
|
||||
"""
|
||||
Tests for in-process processor backends (--cpu mode).
|
||||
|
||||
All ML model calls are mocked — no actual model loading needed.
|
||||
Tests verify processor registration, wiring, error handling, and data flow.
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from reflector.processors.file_diarization import (
|
||||
FileDiarizationInput,
|
||||
FileDiarizationOutput,
|
||||
)
|
||||
from reflector.processors.types import (
|
||||
AudioDiarizationInput,
|
||||
TitleSummaryWithId,
|
||||
Transcript,
|
||||
Word,
|
||||
)
|
||||
|
||||
# ── Registration Tests ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_audio_diarization_pyannote_registers():
|
||||
"""Verify AudioDiarizationPyannoteProcessor registers with 'pyannote' backend."""
|
||||
# Importing the module triggers registration
|
||||
import reflector.processors.audio_diarization_pyannote # noqa: F401
|
||||
from reflector.processors.audio_diarization_auto import (
|
||||
AudioDiarizationAutoProcessor,
|
||||
)
|
||||
|
||||
assert "pyannote" in AudioDiarizationAutoProcessor._registry
|
||||
|
||||
|
||||
def test_file_diarization_pyannote_registers():
|
||||
"""Verify FileDiarizationPyannoteProcessor registers with 'pyannote' backend."""
|
||||
import reflector.processors.file_diarization_pyannote # noqa: F401
|
||||
from reflector.processors.file_diarization_auto import FileDiarizationAutoProcessor
|
||||
|
||||
assert "pyannote" in FileDiarizationAutoProcessor._registry
|
||||
|
||||
|
||||
def test_transcript_translator_marian_registers():
|
||||
"""Verify TranscriptTranslatorMarianProcessor registers with 'marian' backend."""
|
||||
import reflector.processors.transcript_translator_marian # noqa: F401
|
||||
from reflector.processors.transcript_translator_auto import (
|
||||
TranscriptTranslatorAutoProcessor,
|
||||
)
|
||||
|
||||
assert "marian" in TranscriptTranslatorAutoProcessor._registry
|
||||
|
||||
|
||||
def test_file_transcript_whisper_registers():
|
||||
"""Verify FileTranscriptWhisperProcessor registers with 'whisper' backend."""
|
||||
import reflector.processors.file_transcript_whisper # noqa: F401
|
||||
from reflector.processors.file_transcript_auto import FileTranscriptAutoProcessor
|
||||
|
||||
assert "whisper" in FileTranscriptAutoProcessor._registry
|
||||
|
||||
|
||||
# ── Audio Download Utility Tests ────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_audio_to_temp_success():
|
||||
"""Verify download_audio_to_temp downloads to a temp file and returns path."""
|
||||
from reflector.processors._audio_download import download_audio_to_temp
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.headers = {"content-type": "audio/wav"}
|
||||
mock_response.iter_content.return_value = [b"fake audio data"]
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch("reflector.processors._audio_download.requests.get") as mock_get:
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = await download_audio_to_temp("https://example.com/test.wav")
|
||||
|
||||
assert isinstance(result, Path)
|
||||
assert result.exists()
|
||||
assert result.read_bytes() == b"fake audio data"
|
||||
assert result.suffix == ".wav"
|
||||
|
||||
# Cleanup
|
||||
os.unlink(result)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_audio_to_temp_cleanup_on_error():
|
||||
"""Verify temp file is cleaned up when download fails mid-write."""
|
||||
from reflector.processors._audio_download import download_audio_to_temp
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.headers = {"content-type": "audio/wav"}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
def fail_iter(*args, **kwargs):
|
||||
raise ConnectionError("Download interrupted")
|
||||
|
||||
mock_response.iter_content = fail_iter
|
||||
|
||||
with patch("reflector.processors._audio_download.requests.get") as mock_get:
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
with pytest.raises(ConnectionError, match="Download interrupted"):
|
||||
await download_audio_to_temp("https://example.com/test.wav")
|
||||
|
||||
|
||||
def test_detect_extension_from_url():
|
||||
"""Verify extension detection from URL path."""
|
||||
from reflector.processors._audio_download import _detect_extension
|
||||
|
||||
assert _detect_extension("https://example.com/test.wav", "") == ".wav"
|
||||
assert _detect_extension("https://example.com/test.mp3?signed=1", "") == ".mp3"
|
||||
assert _detect_extension("https://example.com/test.webm", "") == ".webm"
|
||||
|
||||
|
||||
def test_detect_extension_from_content_type():
|
||||
"""Verify extension detection from content-type header."""
|
||||
from reflector.processors._audio_download import _detect_extension
|
||||
|
||||
assert _detect_extension("https://s3.aws/uuid", "audio/mpeg") == ".mp3"
|
||||
assert _detect_extension("https://s3.aws/uuid", "audio/wav") == ".wav"
|
||||
assert _detect_extension("https://s3.aws/uuid", "audio/webm") == ".webm"
|
||||
|
||||
|
||||
def test_detect_extension_fallback():
|
||||
"""Verify fallback extension when neither URL nor content-type is recognized."""
|
||||
from reflector.processors._audio_download import _detect_extension
|
||||
|
||||
assert (
|
||||
_detect_extension("https://s3.aws/uuid", "application/octet-stream") == ".audio"
|
||||
)
|
||||
|
||||
|
||||
# ── Audio Diarization Pyannote Processor Tests ──────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audio_diarization_pyannote_diarize():
|
||||
"""Verify pyannote audio diarization downloads, diarizes, and cleans up."""
|
||||
from reflector.processors.audio_diarization_pyannote import (
|
||||
AudioDiarizationPyannoteProcessor,
|
||||
)
|
||||
|
||||
mock_diarization_result = {
|
||||
"diarization": [
|
||||
{"start": 0.0, "end": 2.5, "speaker": 0},
|
||||
{"start": 2.5, "end": 5.0, "speaker": 1},
|
||||
]
|
||||
}
|
||||
|
||||
# Create a temp file to simulate download
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
||||
tmp.write(b"fake audio")
|
||||
tmp.close()
|
||||
tmp_path = Path(tmp.name)
|
||||
|
||||
processor = AudioDiarizationPyannoteProcessor()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"reflector.processors.audio_diarization_pyannote.download_audio_to_temp",
|
||||
new_callable=AsyncMock,
|
||||
return_value=tmp_path,
|
||||
),
|
||||
patch(
|
||||
"reflector.processors.audio_diarization_pyannote.diarization_service"
|
||||
) as mock_svc,
|
||||
):
|
||||
mock_svc.diarize_file.return_value = mock_diarization_result
|
||||
|
||||
data = AudioDiarizationInput(
|
||||
audio_url="https://example.com/test.wav",
|
||||
topics=[
|
||||
TitleSummaryWithId(
|
||||
id="topic-1",
|
||||
title="Test Topic",
|
||||
summary="A test topic",
|
||||
timestamp=0.0,
|
||||
duration=5.0,
|
||||
transcript=Transcript(
|
||||
words=[Word(text="hello", start=0.0, end=1.0)]
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
result = await processor._diarize(data)
|
||||
|
||||
assert result == mock_diarization_result["diarization"]
|
||||
mock_svc.diarize_file.assert_called_once()
|
||||
|
||||
|
||||
# ── File Diarization Pyannote Processor Tests ───────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_diarization_pyannote_diarize():
|
||||
"""Verify pyannote file diarization returns FileDiarizationOutput."""
|
||||
from reflector.processors.file_diarization_pyannote import (
|
||||
FileDiarizationPyannoteProcessor,
|
||||
)
|
||||
|
||||
mock_diarization_result = {
|
||||
"diarization": [
|
||||
{"start": 0.0, "end": 3.0, "speaker": 0},
|
||||
{"start": 3.0, "end": 6.0, "speaker": 1},
|
||||
]
|
||||
}
|
||||
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
||||
tmp.write(b"fake audio")
|
||||
tmp.close()
|
||||
tmp_path = Path(tmp.name)
|
||||
|
||||
processor = FileDiarizationPyannoteProcessor()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"reflector.processors.file_diarization_pyannote.download_audio_to_temp",
|
||||
new_callable=AsyncMock,
|
||||
return_value=tmp_path,
|
||||
),
|
||||
patch(
|
||||
"reflector.processors.file_diarization_pyannote.diarization_service"
|
||||
) as mock_svc,
|
||||
):
|
||||
mock_svc.diarize_file.return_value = mock_diarization_result
|
||||
|
||||
data = FileDiarizationInput(audio_url="https://example.com/test.wav")
|
||||
result = await processor._diarize(data)
|
||||
|
||||
assert isinstance(result, FileDiarizationOutput)
|
||||
assert len(result.diarization) == 2
|
||||
assert result.diarization[0]["start"] == 0.0
|
||||
assert result.diarization[1]["speaker"] == 1
|
||||
|
||||
|
||||
# ── Transcript Translator Marian Processor Tests ───────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_translator_marian_translate():
|
||||
"""Verify MarianMT translator calls service and extracts translation."""
|
||||
from reflector.processors.transcript_translator_marian import (
|
||||
TranscriptTranslatorMarianProcessor,
|
||||
)
|
||||
|
||||
mock_result = {"text": {"en": "Hello world", "fr": "Bonjour le monde"}}
|
||||
|
||||
processor = TranscriptTranslatorMarianProcessor()
|
||||
|
||||
def fake_get_pref(key, default=None):
|
||||
prefs = {"audio:source_language": "en", "audio:target_language": "fr"}
|
||||
return prefs.get(key, default)
|
||||
|
||||
with (
|
||||
patch.object(processor, "get_pref", side_effect=fake_get_pref),
|
||||
patch(
|
||||
"reflector.processors.transcript_translator_marian.translator_service"
|
||||
) as mock_svc,
|
||||
):
|
||||
mock_svc.translate.return_value = mock_result
|
||||
|
||||
result = await processor._translate("Hello world")
|
||||
|
||||
assert result == "Bonjour le monde"
|
||||
mock_svc.translate.assert_called_once_with("Hello world", "en", "fr")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_translator_marian_no_translation():
|
||||
"""Verify translator returns None when target language not in result."""
|
||||
from reflector.processors.transcript_translator_marian import (
|
||||
TranscriptTranslatorMarianProcessor,
|
||||
)
|
||||
|
||||
mock_result = {"text": {"en": "Hello world"}}
|
||||
|
||||
processor = TranscriptTranslatorMarianProcessor()
|
||||
|
||||
def fake_get_pref(key, default=None):
|
||||
prefs = {"audio:source_language": "en", "audio:target_language": "fr"}
|
||||
return prefs.get(key, default)
|
||||
|
||||
with (
|
||||
patch.object(processor, "get_pref", side_effect=fake_get_pref),
|
||||
patch(
|
||||
"reflector.processors.transcript_translator_marian.translator_service"
|
||||
) as mock_svc,
|
||||
):
|
||||
mock_svc.translate.return_value = mock_result
|
||||
|
||||
result = await processor._translate("Hello world")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
# ── File Transcript Whisper Processor Tests ─────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_transcript_whisper_transcript():
|
||||
"""Verify whisper file processor downloads, transcribes, and returns Transcript."""
|
||||
from reflector.processors.file_transcript import FileTranscriptInput
|
||||
from reflector.processors.file_transcript_whisper import (
|
||||
FileTranscriptWhisperProcessor,
|
||||
)
|
||||
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
||||
tmp.write(b"fake audio")
|
||||
tmp.close()
|
||||
tmp_path = Path(tmp.name)
|
||||
|
||||
processor = FileTranscriptWhisperProcessor()
|
||||
|
||||
# Mock the blocking transcription method
|
||||
mock_transcript = Transcript(
|
||||
words=[
|
||||
Word(text="Hello", start=0.0, end=0.5),
|
||||
Word(text=" world", start=0.5, end=1.0),
|
||||
]
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"reflector.processors.file_transcript_whisper.download_audio_to_temp",
|
||||
new_callable=AsyncMock,
|
||||
return_value=tmp_path,
|
||||
),
|
||||
patch.object(
|
||||
processor,
|
||||
"_transcribe_file_blocking",
|
||||
return_value=mock_transcript,
|
||||
),
|
||||
):
|
||||
data = FileTranscriptInput(
|
||||
audio_url="https://example.com/test.wav", language="en"
|
||||
)
|
||||
result = await processor._transcript(data)
|
||||
|
||||
assert isinstance(result, Transcript)
|
||||
assert len(result.words) == 2
|
||||
assert result.words[0].text == "Hello"
|
||||
|
||||
|
||||
# ── VAD Helper Tests ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_enforce_word_timing_constraints():
|
||||
"""Verify word timing enforcement prevents overlapping times."""
|
||||
from reflector.processors.file_transcript_whisper import (
|
||||
_enforce_word_timing_constraints,
|
||||
)
|
||||
|
||||
words = [
|
||||
{"word": "hello", "start": 0.0, "end": 1.5},
|
||||
{"word": "world", "start": 1.0, "end": 2.0}, # overlaps with previous
|
||||
{"word": "test", "start": 2.0, "end": 3.0},
|
||||
]
|
||||
|
||||
result = _enforce_word_timing_constraints(words)
|
||||
|
||||
assert result[0]["end"] == 1.0 # Clamped to next word's start
|
||||
assert result[1]["end"] == 2.0 # Clamped to next word's start
|
||||
assert result[2]["end"] == 3.0 # Last word unchanged
|
||||
|
||||
|
||||
def test_enforce_word_timing_constraints_empty():
|
||||
"""Verify timing enforcement handles empty and single-word lists."""
|
||||
from reflector.processors.file_transcript_whisper import (
|
||||
_enforce_word_timing_constraints,
|
||||
)
|
||||
|
||||
assert _enforce_word_timing_constraints([]) == []
|
||||
assert _enforce_word_timing_constraints([{"word": "a", "start": 0, "end": 1}]) == [
|
||||
{"word": "a", "start": 0, "end": 1}
|
||||
]
|
||||
|
||||
|
||||
def test_pad_audio_short():
|
||||
"""Verify short audio gets padded with silence."""
|
||||
import numpy as np
|
||||
|
||||
from reflector.processors.file_transcript_whisper import _pad_audio
|
||||
|
||||
short_audio = np.zeros(100, dtype=np.float32) # Very short
|
||||
result = _pad_audio(short_audio, sample_rate=16000)
|
||||
|
||||
# Should be padded to at least silence_padding duration
|
||||
assert len(result) > len(short_audio)
|
||||
|
||||
|
||||
def test_pad_audio_long():
|
||||
"""Verify long audio is not padded."""
|
||||
import numpy as np
|
||||
|
||||
from reflector.processors.file_transcript_whisper import _pad_audio
|
||||
|
||||
long_audio = np.zeros(32000, dtype=np.float32) # 2 seconds
|
||||
result = _pad_audio(long_audio, sample_rate=16000)
|
||||
|
||||
assert len(result) == len(long_audio)
|
||||
|
||||
|
||||
# ── Translator Service Tests ────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_translator_service_resolve_model():
|
||||
"""Verify model resolution for known and unknown language pairs."""
|
||||
from reflector.processors._marian_translator_service import MarianTranslatorService
|
||||
|
||||
svc = MarianTranslatorService()
|
||||
|
||||
assert svc._resolve_model_name("en", "fr") == "Helsinki-NLP/opus-mt-en-fr"
|
||||
assert svc._resolve_model_name("es", "en") == "Helsinki-NLP/opus-mt-es-en"
|
||||
assert svc._resolve_model_name("en", "de") == "Helsinki-NLP/opus-mt-en-de"
|
||||
# Unknown pair falls back to en->fr
|
||||
assert svc._resolve_model_name("ja", "ko") == "Helsinki-NLP/opus-mt-en-fr"
|
||||
|
||||
|
||||
# ── Diarization Service Tests ───────────────────────────────────────────
|
||||
|
||||
|
||||
def test_diarization_service_singleton():
|
||||
"""Verify diarization_service is a module-level singleton."""
|
||||
from reflector.processors._pyannote_diarization_service import (
|
||||
PyannoteDiarizationService,
|
||||
diarization_service,
|
||||
)
|
||||
|
||||
assert isinstance(diarization_service, PyannoteDiarizationService)
|
||||
assert diarization_service._pipeline is None # Not loaded until first use
|
||||
|
||||
|
||||
def test_translator_service_singleton():
|
||||
"""Verify translator_service is a module-level singleton."""
|
||||
from reflector.processors._marian_translator_service import (
|
||||
MarianTranslatorService,
|
||||
translator_service,
|
||||
)
|
||||
|
||||
assert isinstance(translator_service, MarianTranslatorService)
|
||||
assert translator_service._pipeline is None # Not loaded until first use
|
||||
327
server/tests/test_transcripts_audio_token_auth.py
Normal file
327
server/tests/test_transcripts_audio_token_auth.py
Normal file
@@ -0,0 +1,327 @@
|
||||
"""Tests for audio mp3 endpoint token query-param authentication.
|
||||
|
||||
Covers both password (HS256) and JWT/Authentik (RS256) auth backends,
|
||||
verifying that private transcripts can be accessed via ?token= query param.
|
||||
"""
|
||||
|
||||
import shutil
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
OWNER_USER_ID = "test-owner-user-id"
|
||||
|
||||
|
||||
def _create_hs256_token(user_id: str, secret: str, expired: bool = False) -> str:
|
||||
"""Create an HS256 JWT like the password auth backend does."""
|
||||
delta = timedelta(minutes=-5) if expired else timedelta(hours=24)
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"email": "test@example.com",
|
||||
"exp": datetime.now(timezone.utc) + delta,
|
||||
}
|
||||
return jwt.encode(payload, secret, algorithm="HS256")
|
||||
|
||||
|
||||
def _generate_rsa_keypair():
|
||||
"""Generate a fresh RSA keypair for tests."""
|
||||
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
|
||||
public_pem = private_key.public_key().public_bytes(
|
||||
serialization.Encoding.PEM, serialization.PublicFormat.SubjectPublicKeyInfo
|
||||
)
|
||||
return private_key, public_pem.decode()
|
||||
|
||||
|
||||
def _create_rs256_token(
|
||||
authentik_uid: str,
|
||||
private_key,
|
||||
audience: str,
|
||||
expired: bool = False,
|
||||
) -> str:
|
||||
"""Create an RS256 JWT like Authentik would issue."""
|
||||
delta = timedelta(minutes=-5) if expired else timedelta(hours=1)
|
||||
payload = {
|
||||
"sub": authentik_uid,
|
||||
"email": "authentik-user@example.com",
|
||||
"aud": audience,
|
||||
"exp": datetime.now(timezone.utc) + delta,
|
||||
}
|
||||
return jwt.encode(payload, private_key, algorithm="RS256")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def private_transcript(tmpdir):
|
||||
"""Create a private transcript owned by OWNER_USER_ID with an mp3 file.
|
||||
|
||||
Created directly via the controller (not HTTP) so no auth override
|
||||
leaks into the test scope.
|
||||
"""
|
||||
from reflector.db.transcripts import SourceKind, transcripts_controller
|
||||
from reflector.settings import settings
|
||||
|
||||
settings.DATA_DIR = Path(tmpdir)
|
||||
|
||||
transcript = await transcripts_controller.add(
|
||||
"Private audio test",
|
||||
source_kind=SourceKind.FILE,
|
||||
user_id=OWNER_USER_ID,
|
||||
share_mode="private",
|
||||
)
|
||||
await transcripts_controller.update(transcript, {"status": "ended"})
|
||||
|
||||
# Copy a real mp3 to the expected location
|
||||
audio_filename = transcript.audio_mp3_filename
|
||||
mp3_source = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
|
||||
audio_filename.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(mp3_source, audio_filename)
|
||||
|
||||
yield transcript
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core access control tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audio_mp3_private_no_auth_returns_403(private_transcript, client):
|
||||
"""Without auth, accessing a private transcript's audio returns 403."""
|
||||
response = await client.get(f"/transcripts/{private_transcript.id}/audio/mp3")
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audio_mp3_with_bearer_header(private_transcript, client):
|
||||
"""Owner accessing audio via Authorization header works."""
|
||||
from reflector.app import app
|
||||
from reflector.auth import current_user_optional
|
||||
|
||||
# Temporarily override to simulate the owner being authenticated
|
||||
app.dependency_overrides[current_user_optional] = lambda: {
|
||||
"sub": OWNER_USER_ID,
|
||||
"email": "test@example.com",
|
||||
}
|
||||
try:
|
||||
response = await client.get(f"/transcripts/{private_transcript.id}/audio/mp3")
|
||||
finally:
|
||||
del app.dependency_overrides[current_user_optional]
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "audio/mpeg"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audio_mp3_public_transcript_no_auth_ok(tmpdir, client):
|
||||
"""Public transcripts are accessible without any auth."""
|
||||
from reflector.db.transcripts import SourceKind, transcripts_controller
|
||||
from reflector.settings import settings
|
||||
|
||||
settings.DATA_DIR = Path(tmpdir)
|
||||
|
||||
transcript = await transcripts_controller.add(
|
||||
"Public audio test",
|
||||
source_kind=SourceKind.FILE,
|
||||
user_id=OWNER_USER_ID,
|
||||
share_mode="public",
|
||||
)
|
||||
await transcripts_controller.update(transcript, {"status": "ended"})
|
||||
|
||||
audio_filename = transcript.audio_mp3_filename
|
||||
mp3_source = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
|
||||
audio_filename.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(mp3_source, audio_filename)
|
||||
|
||||
response = await client.get(f"/transcripts/{transcript.id}/audio/mp3")
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "audio/mpeg"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Password auth backend tests (?token= with HS256)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audio_mp3_password_token_query_param(private_transcript, client):
|
||||
"""Password backend: valid HS256 ?token= grants access to private audio."""
|
||||
from reflector.auth.auth_password import UserInfo
|
||||
from reflector.settings import settings
|
||||
|
||||
token = _create_hs256_token(OWNER_USER_ID, settings.SECRET_KEY)
|
||||
|
||||
with patch("reflector.auth.verify_raw_token") as mock_verify:
|
||||
mock_verify.return_value = UserInfo(sub=OWNER_USER_ID, email="test@example.com")
|
||||
response = await client.get(
|
||||
f"/transcripts/{private_transcript.id}/audio/mp3?token={token}"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "audio/mpeg"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audio_mp3_password_expired_token_returns_401(private_transcript, client):
|
||||
"""Password backend: expired HS256 ?token= returns 401."""
|
||||
from reflector.settings import settings
|
||||
|
||||
expired_token = _create_hs256_token(
|
||||
OWNER_USER_ID, settings.SECRET_KEY, expired=True
|
||||
)
|
||||
|
||||
with patch("reflector.auth.verify_raw_token") as mock_verify:
|
||||
mock_verify.side_effect = jwt.ExpiredSignatureError("token expired")
|
||||
response = await client.get(
|
||||
f"/transcripts/{private_transcript.id}/audio/mp3" f"?token={expired_token}"
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audio_mp3_password_wrong_user_returns_403(private_transcript, client):
|
||||
"""Password backend: valid token for a different user returns 403."""
|
||||
from reflector.auth.auth_password import UserInfo
|
||||
from reflector.settings import settings
|
||||
|
||||
token = _create_hs256_token("other-user-id", settings.SECRET_KEY)
|
||||
|
||||
with patch("reflector.auth.verify_raw_token") as mock_verify:
|
||||
mock_verify.return_value = UserInfo(
|
||||
sub="other-user-id", email="other@example.com"
|
||||
)
|
||||
response = await client.get(
|
||||
f"/transcripts/{private_transcript.id}/audio/mp3?token={token}"
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audio_mp3_invalid_token_returns_401(private_transcript, client):
|
||||
"""Garbage token string returns 401."""
|
||||
with patch("reflector.auth.verify_raw_token") as mock_verify:
|
||||
mock_verify.return_value = None
|
||||
response = await client.get(
|
||||
f"/transcripts/{private_transcript.id}/audio/mp3" "?token=not-a-real-token"
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JWT/Authentik auth backend tests (?token= with RS256)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audio_mp3_authentik_token_query_param(private_transcript, client):
|
||||
"""Authentik backend: valid RS256 ?token= grants access to private audio."""
|
||||
from reflector.auth.auth_password import UserInfo
|
||||
|
||||
private_key, _ = _generate_rsa_keypair()
|
||||
token = _create_rs256_token("authentik-abc123", private_key, "test-audience")
|
||||
|
||||
with patch("reflector.auth.verify_raw_token") as mock_verify:
|
||||
# Authentik flow maps authentik_uid -> internal user id
|
||||
mock_verify.return_value = UserInfo(
|
||||
sub=OWNER_USER_ID, email="authentik-user@example.com"
|
||||
)
|
||||
response = await client.get(
|
||||
f"/transcripts/{private_transcript.id}/audio/mp3?token={token}"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "audio/mpeg"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audio_mp3_authentik_expired_token_returns_401(
|
||||
private_transcript, client
|
||||
):
|
||||
"""Authentik backend: expired RS256 ?token= returns 401."""
|
||||
private_key, _ = _generate_rsa_keypair()
|
||||
expired_token = _create_rs256_token(
|
||||
"authentik-abc123", private_key, "test-audience", expired=True
|
||||
)
|
||||
|
||||
with patch("reflector.auth.verify_raw_token") as mock_verify:
|
||||
mock_verify.side_effect = jwt.ExpiredSignatureError("token expired")
|
||||
response = await client.get(
|
||||
f"/transcripts/{private_transcript.id}/audio/mp3" f"?token={expired_token}"
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audio_mp3_authentik_wrong_user_returns_403(private_transcript, client):
|
||||
"""Authentik backend: valid RS256 token for different user returns 403."""
|
||||
from reflector.auth.auth_password import UserInfo
|
||||
|
||||
private_key, _ = _generate_rsa_keypair()
|
||||
token = _create_rs256_token("authentik-other", private_key, "test-audience")
|
||||
|
||||
with patch("reflector.auth.verify_raw_token") as mock_verify:
|
||||
mock_verify.return_value = UserInfo(
|
||||
sub="different-user-id", email="other@example.com"
|
||||
)
|
||||
response = await client.get(
|
||||
f"/transcripts/{private_transcript.id}/audio/mp3?token={token}"
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _generate_local_audio_link produces HS256 tokens — must be verifiable
|
||||
# by any auth backend
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_audio_link_token_works_with_authentik_backend(
|
||||
private_transcript, client
|
||||
):
|
||||
"""_generate_local_audio_link creates an HS256 token via create_access_token.
|
||||
|
||||
When the Authentik (RS256) auth backend is active, verify_raw_token uses
|
||||
JWTAuth which expects RS256 + public key. The HS256 token created by
|
||||
_generate_local_audio_link will fail verification, returning 401.
|
||||
|
||||
This test documents the bug: the internal audio URL generated for the
|
||||
diarization pipeline is unusable under the JWT auth backend.
|
||||
"""
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
# Generate the internal audio link (uses create_access_token → HS256)
|
||||
url = private_transcript._generate_local_audio_link()
|
||||
parsed = urlparse(url)
|
||||
token = parse_qs(parsed.query)["token"][0]
|
||||
|
||||
# Simulate what happens when the JWT/Authentik backend tries to verify
|
||||
# this HS256 token: JWTAuth.verify_token expects RS256, so it raises.
|
||||
with patch("reflector.auth.verify_raw_token") as mock_verify:
|
||||
mock_verify.side_effect = jwt.exceptions.InvalidAlgorithmError(
|
||||
"the specified alg value is not allowed"
|
||||
)
|
||||
response = await client.get(
|
||||
f"/transcripts/{private_transcript.id}/audio/mp3?token={token}"
|
||||
)
|
||||
|
||||
# BUG: this should be 200 (the token was created by our own server),
|
||||
# but the Authentik backend rejects it because it's HS256, not RS256.
|
||||
assert response.status_code == 200
|
||||
Reference in New Issue
Block a user