mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
feat: self-hosted gpu api (#636)
* Self-hosted gpu api * Refactor self-hosted api * Rename model api tests * Use lifespan instead of startup event * Fix self hosted imports * Add newlines * Add response models * Move gpu dir to the root * Add project description * Refactor lifespan * Update env var names for model api tests * Preload diarization service * Refactor uploaded file paths
This commit is contained in:
63
server/tests/test_model_api_diarization.py
Normal file
63
server/tests/test_model_api_diarization.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""
|
||||
Tests for diarization Model API endpoint (self-hosted service compatible shape).
|
||||
|
||||
Marked with the "model_api" marker and skipped unless DIARIZATION_URL is provided.
|
||||
|
||||
Run against a local self-hosted server:
|
||||
DIARIZATION_API_KEY=dev-key \
|
||||
DIARIZATION_URL=http://localhost:8000 \
|
||||
uv run -m pytest -m model_api --no-cov tests/test_model_api_diarization.py
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
# Public test audio file hosted on S3 specifically for reflector pytests
|
||||
TEST_AUDIO_URL = (
|
||||
"https://reflector-github-pytest.s3.us-east-1.amazonaws.com/test_mathieu_hello.mp3"
|
||||
)
|
||||
|
||||
|
||||
def get_modal_diarization_url():
    """Return the base URL of the diarization Model API under test.

    The value is read from the ``DIARIZATION_URL`` environment variable.
    Trailing slashes are stripped so that callers can append ``/diarize``
    without producing a double slash in the request path.

    When the variable is unset, empty, or consists only of slashes, the
    calling test is skipped through ``pytest.skip``.
    """
    # Normalise before the emptiness check so a bare "/" is treated as unset.
    url = (os.environ.get("DIARIZATION_URL") or "").rstrip("/")
    if not url:
        pytest.skip(
            "DIARIZATION_URL environment variable is required for Model API tests"
        )
    return url
|
||||
|
||||
|
||||
def get_auth_headers():
    """Build the authentication headers for diarization API requests.

    ``DIARIZATION_API_KEY`` takes precedence; ``REFLECTOR_GPU_APIKEY`` is used
    when the former is unset or empty. An empty dict is returned when neither
    variable provides a key, so requests are then sent without credentials.
    """
    token = os.environ.get("DIARIZATION_API_KEY")
    if not token:
        token = os.environ.get("REFLECTOR_GPU_APIKEY")
    if not token:
        return {}
    return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
@pytest.mark.model_api
class TestModelAPIDiarization:
    """Integration checks for the ``/diarize`` endpoint of the Model API."""

    def test_diarize_from_url(self):
        """POST a public audio URL and validate the shape of each segment."""
        base_url = get_modal_diarization_url()
        auth_headers = get_auth_headers()
        query = {"audio_file_url": TEST_AUDIO_URL, "timestamp": 0.0}

        with httpx.Client(timeout=60.0) as client:
            response = client.post(
                f"{base_url}/diarize",
                params=query,
                headers=auth_headers,
            )

        assert response.status_code == 200, f"Request failed: {response.text}"
        payload = response.json()

        assert "diarization" in payload
        segments = payload["diarization"]
        assert isinstance(segments, list)
        assert len(segments) > 0

        for segment in segments:
            # Every segment must carry its time bounds and a speaker label.
            for key in ("start", "end", "speaker"):
                assert key in segment
            begin, finish = segment["start"], segment["end"]
            assert isinstance(begin, (int, float))
            assert isinstance(finish, (int, float))
            assert begin <= finish
|
||||
@@ -1,21 +1,21 @@
|
||||
"""
|
||||
Tests for GPU Modal transcription endpoints.
|
||||
Tests for transcription Model API endpoints.
|
||||
|
||||
These tests are marked with the "gpu-modal" group and will not run by default.
|
||||
Run them with: pytest -m gpu-modal tests/test_gpu_modal_transcript_parakeet.py
|
||||
These tests are marked with the "model_api" group and will not run by default.
|
||||
Run them with: pytest -m model_api tests/test_model_api_transcript.py
|
||||
|
||||
Required environment variables:
|
||||
- TRANSCRIPT_URL: URL to the Modal.com endpoint (required)
|
||||
- TRANSCRIPT_MODAL_API_KEY: API key for authentication (optional)
|
||||
- TRANSCRIPT_URL: URL to the Model API endpoint (required)
|
||||
- TRANSCRIPT_API_KEY: API key for authentication (optional)
|
||||
- TRANSCRIPT_MODEL: Model name to use (optional, defaults to nvidia/parakeet-tdt-0.6b-v2)
|
||||
|
||||
Example with pytest (override default addopts to run ONLY gpu_modal tests):
|
||||
Example with pytest (override default addopts to run ONLY model_api tests):
|
||||
TRANSCRIPT_URL=https://monadical-sas--reflector-transcriber-parakeet-web-dev.modal.run \
|
||||
TRANSCRIPT_MODAL_API_KEY=your-api-key \
|
||||
uv run -m pytest -m gpu_modal --no-cov tests/test_gpu_modal_transcript.py
|
||||
TRANSCRIPT_API_KEY=your-api-key \
|
||||
uv run -m pytest -m model_api --no-cov tests/test_model_api_transcript.py
|
||||
|
||||
# Or with completely clean options:
|
||||
uv run -m pytest -m gpu_modal -o addopts="" tests/
|
||||
uv run -m pytest -m model_api -o addopts="" tests/
|
||||
|
||||
Running Modal locally for testing:
|
||||
modal serve gpu/modal_deployments/reflector_transcriber_parakeet.py
|
||||
@@ -40,14 +40,16 @@ def get_modal_transcript_url():
|
||||
url = os.environ.get("TRANSCRIPT_URL")
|
||||
if not url:
|
||||
pytest.skip(
|
||||
"TRANSCRIPT_URL environment variable is required for GPU Modal tests"
|
||||
"TRANSCRIPT_URL environment variable is required for Model API tests"
|
||||
)
|
||||
return url
|
||||
|
||||
|
||||
def get_auth_headers():
|
||||
"""Get authentication headers if API key is available."""
|
||||
api_key = os.environ.get("TRANSCRIPT_MODAL_API_KEY")
|
||||
api_key = os.environ.get("TRANSCRIPT_API_KEY") or os.environ.get(
|
||||
"REFLECTOR_GPU_APIKEY"
|
||||
)
|
||||
if api_key:
|
||||
return {"Authorization": f"Bearer {api_key}"}
|
||||
return {}
|
||||
@@ -58,8 +60,8 @@ def get_model_name():
|
||||
return os.environ.get("TRANSCRIPT_MODEL", "nvidia/parakeet-tdt-0.6b-v2")
|
||||
|
||||
|
||||
@pytest.mark.gpu_modal
|
||||
class TestGPUModalTranscript:
|
||||
@pytest.mark.model_api
|
||||
class TestModelAPITranscript:
|
||||
"""Test suite for transcription Model API endpoints."""
|
||||
|
||||
def test_transcriptions_from_url(self):
|
||||
56
server/tests/test_model_api_translation.py
Normal file
56
server/tests/test_model_api_translation.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""
|
||||
Tests for translation Model API endpoint (self-hosted service compatible shape).
|
||||
|
||||
Marked with the "model_api" marker and skipped unless TRANSLATION_URL is provided
|
||||
or we fall back to the TRANSCRIPT_URL base (same host for self-hosted).
|
||||
|
||||
Run locally against self-hosted server:
|
||||
TRANSLATION_API_KEY=dev-key \
|
||||
TRANSLATION_URL=http://localhost:8000 \
|
||||
uv run -m pytest -m model_api --no-cov tests/test_model_api_translation.py
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
|
||||
def get_translation_url():
    """Return the base URL of the translation Model API under test.

    ``TRANSLATION_URL`` is preferred; ``TRANSCRIPT_URL`` is used when the
    former is unset or empty. Trailing slashes are stripped so that callers
    can append ``/translate`` without producing a double slash in the path.

    When neither variable yields a usable value, the calling test is skipped
    through ``pytest.skip``.
    """
    raw = os.environ.get("TRANSLATION_URL") or os.environ.get("TRANSCRIPT_URL")
    # Normalise before the emptiness check so a bare "/" is treated as unset.
    url = (raw or "").rstrip("/")
    if not url:
        pytest.skip(
            "TRANSLATION_URL or TRANSCRIPT_URL environment variable is required for Model API tests"
        )
    return url
|
||||
|
||||
|
||||
def get_auth_headers():
    """Build the authentication headers for translation API requests.

    ``TRANSLATION_API_KEY`` takes precedence; ``REFLECTOR_GPU_APIKEY`` is used
    when the former is unset or empty. An empty dict is returned when neither
    variable provides a key, so requests are then sent without credentials.
    """
    token = os.environ.get("TRANSLATION_API_KEY")
    if not token:
        token = os.environ.get("REFLECTOR_GPU_APIKEY")
    if not token:
        return {}
    return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
@pytest.mark.model_api
class TestModelAPITranslation:
    """Integration checks for the ``/translate`` endpoint of the Model API."""

    def test_translate_text(self):
        """Translate one English sentence to French and validate the response."""
        base_url = get_translation_url()
        auth_headers = get_auth_headers()
        languages = {"source_language": "en", "target_language": "fr"}

        with httpx.Client(timeout=60.0) as client:
            response = client.post(
                f"{base_url}/translate",
                params={"text": "The meeting will start in five minutes."},
                json=languages,
                headers=auth_headers,
            )

        assert response.status_code == 200, f"Request failed: {response.text}"
        body = response.json()

        # The response maps language codes to text under the "text" key.
        assert "text" in body and isinstance(body["text"], dict)
        translations = body["text"]
        assert translations.get("en") == "The meeting will start in five minutes."
        assert isinstance(translations.get("fr", ""), str)
        assert len(translations["fr"]) > 0
        # NOTE(review): exact match on the translated output; this may be
        # brittle if the served model or its version changes. Confirm intent.
        assert translations["fr"] == "La réunion commencera dans cinq minutes."
|
||||
Reference in New Issue
Block a user