mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 12:19:06 +00:00
feat: self-hosted gpu api (#636)
* Self-hosted gpu api * Refactor self-hosted api * Rename model api tests * Use lifespan instead of startup event * Fix self hosted imports * Add newlines * Add response models * Move gpu dir to the root * Add project description * Refactor lifespan * Update env var names for model api tests * Preload diarization service * Refactor uploaded file paths
This commit is contained in:
33
gpu/modal_deployments/.gitignore
vendored
Normal file
33
gpu/modal_deployments/.gitignore
vendored
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
# OS / Editor
|
||||||
|
.DS_Store
|
||||||
|
.vscode/
|
||||||
|
.idea/
|
||||||
|
|
||||||
|
# Python
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# Logs
|
||||||
|
*.log
|
||||||
|
|
||||||
|
# Env and secrets
|
||||||
|
.env
|
||||||
|
.env.*
|
||||||
|
*.env
|
||||||
|
*.secret
|
||||||
|
|
||||||
|
# Build / dist
|
||||||
|
build/
|
||||||
|
dist/
|
||||||
|
.eggs/
|
||||||
|
*.egg-info/
|
||||||
|
|
||||||
|
# Coverage / test
|
||||||
|
.pytest_cache/
|
||||||
|
.coverage*
|
||||||
|
htmlcov/
|
||||||
|
|
||||||
|
# Modal local state (if any)
|
||||||
|
modal_mounts/
|
||||||
|
.modal_cache/
|
||||||
2
gpu/self_hosted/.env.example
Normal file
2
gpu/self_hosted/.env.example
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
REFLECTOR_GPU_APIKEY=
|
||||||
|
HF_TOKEN=
|
||||||
38
gpu/self_hosted/.gitignore
vendored
Normal file
38
gpu/self_hosted/.gitignore
vendored
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
cache/
|
||||||
|
|
||||||
|
# OS / Editor
|
||||||
|
.DS_Store
|
||||||
|
.vscode/
|
||||||
|
.idea/
|
||||||
|
|
||||||
|
# Python
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# Env and secrets
|
||||||
|
.env
|
||||||
|
*.env
|
||||||
|
*.secret
|
||||||
|
HF_TOKEN
|
||||||
|
REFLECTOR_GPU_APIKEY
|
||||||
|
|
||||||
|
# Virtual env / uv
|
||||||
|
.venv/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
uv/
|
||||||
|
|
||||||
|
# Build / dist
|
||||||
|
build/
|
||||||
|
dist/
|
||||||
|
.eggs/
|
||||||
|
*.egg-info/
|
||||||
|
|
||||||
|
# Coverage / test
|
||||||
|
.pytest_cache/
|
||||||
|
.coverage*
|
||||||
|
htmlcov/
|
||||||
|
|
||||||
|
# Logs
|
||||||
|
*.log
|
||||||
46
gpu/self_hosted/Dockerfile
Normal file
46
gpu/self_hosted/Dockerfile
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
# Image for the self-hosted Reflector GPU API.
# Python dependencies are NOT installed at build time: runserver.sh runs
# `uv sync` on first container start.
FROM python:3.12-slim

ENV PYTHONUNBUFFERED=1 \
    UV_LINK_MODE=copy \
    UV_NO_CACHE=1

WORKDIR /tmp
# System tools. ffmpeg is invoked by the transcription service at runtime.
# The apt lists are removed in the same layer so they do not bloat the image;
# the CUDA step below runs its own `apt-get update`.
RUN apt-get update \
    && apt-get install -y \
        ffmpeg \
        curl \
        ca-certificates \
        gnupg \
        wget \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*
# Add NVIDIA CUDA repo for Debian 12 (bookworm) and install cuDNN 9 for CUDA 12
ADD https://developer.download.nvidia.com/compute/cuda/repos/debian12/x86_64/cuda-keyring_1.1-1_all.deb /cuda-keyring.deb
RUN dpkg -i /cuda-keyring.deb \
    && rm /cuda-keyring.deb \
    && apt-get update \
    && apt-get install -y --no-install-recommends \
        cuda-cudart-12-6 \
        libcublas-12-6 \
        libcudnn9-cuda-12 \
        libcudnn9-dev-cuda-12 \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*
ADD https://astral.sh/uv/install.sh /uv-installer.sh
RUN sh /uv-installer.sh && rm /uv-installer.sh
ENV PATH="/root/.local/bin/:$PATH"
# No ":$LD_LIBRARY_PATH" suffix: the variable is presumably unset in the base
# image (TODO confirm), so the suffix would expand to a trailing empty entry,
# which the dynamic loader treats as "search the current directory".
ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/lib/x86_64-linux-gnu"

RUN mkdir -p /app
WORKDIR /app
COPY pyproject.toml uv.lock /app/

COPY ./app /app/app
COPY ./main.py /app/
COPY ./runserver.sh /app/

EXPOSE 8000

CMD ["sh", "/app/runserver.sh"]
|
||||||
|
|
||||||
|
|
||||||
73
gpu/self_hosted/README.md
Normal file
73
gpu/self_hosted/README.md
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
# Self-hosted Model API
|
||||||
|
|
||||||
|
Run transcription, translation, and diarization services compatible with Reflector's GPU Model API. Works on CPU or GPU.
|
||||||
|
|
||||||
|
Environment variables
|
||||||
|
|
||||||
|
- REFLECTOR_GPU_APIKEY: Optional Bearer token. If unset, auth is disabled.
|
||||||
|
- HF_TOKEN: Hugging Face access token. Optional in general, but required for diarization, which uses it to download the pyannote pipelines
|
||||||
|
|
||||||
|
Requirements
|
||||||
|
|
||||||
|
- FFmpeg must be installed and on PATH (used for URL-based and segmented transcription)
|
||||||
|
- Python 3.12+
|
||||||
|
- NVIDIA GPU optional. If available, it will be used automatically
|
||||||
|
|
||||||
|
Local run
|
||||||
|
Set env vars in self_hosted/.env file
|
||||||
|
uv sync
|
||||||
|
|
||||||
|
uv run uvicorn main:app --host 0.0.0.0 --port 8000
|
||||||
|
|
||||||
|
Authentication
|
||||||
|
|
||||||
|
- If REFLECTOR_GPU_APIKEY is set, include header: Authorization: Bearer <key>
|
||||||
|
|
||||||
|
Endpoints
|
||||||
|
|
||||||
|
- POST /v1/audio/transcriptions
|
||||||
|
|
||||||
|
- multipart/form-data
|
||||||
|
- fields: file (single file) OR files[] (multiple files), language, batch (true/false)
|
||||||
|
- response: single { text, words, filename } or { results: [ ... ] }
|
||||||
|
|
||||||
|
- POST /v1/audio/transcriptions-from-url
|
||||||
|
|
||||||
|
- application/json
|
||||||
|
- body: { audio_file_url, language, timestamp_offset }
|
||||||
|
- response: { text, words }
|
||||||
|
|
||||||
|
- POST /translate
|
||||||
|
|
||||||
|
- text: query parameter
|
||||||
|
- body (application/json): { source_language, target_language }
|
||||||
|
- response: { text: { <src>: original, <tgt>: translated } }
|
||||||
|
|
||||||
|
- POST /diarize
|
||||||
|
- query parameters: audio_file_url, timestamp (optional)
|
||||||
|
- requires HF_TOKEN to be set (for pyannote)
|
||||||
|
- response: { diarization: [ { start, end, speaker } ] }
|
||||||
|
|
||||||
|
OpenAPI docs
|
||||||
|
|
||||||
|
- Visit /docs when the server is running
|
||||||
|
|
||||||
|
Docker
|
||||||
|
|
||||||
|
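- A Dockerfile and compose.yml are provided in this directory. Put the env vars in .env, then run `docker compose up --build`. The API is published on port 8000, and ./cache is mounted at /root/.cache so downloaded models persist between runs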
|
||||||
|
|
||||||
|
Conformance tests
|
||||||
|
|
||||||
|
# From this directory
|
||||||
|
|
||||||
|
TRANSCRIPT_URL=http://localhost:8000 \
|
||||||
|
TRANSCRIPT_API_KEY=dev-key \
|
||||||
|
uv run -m pytest -m model_api --no-cov ../../server/tests/test_model_api_transcript.py
|
||||||
|
|
||||||
|
TRANSLATION_URL=http://localhost:8000 \
|
||||||
|
TRANSLATION_API_KEY=dev-key \
|
||||||
|
uv run -m pytest -m model_api --no-cov ../../server/tests/test_model_api_translation.py
|
||||||
|
|
||||||
|
DIARIZATION_URL=http://localhost:8000 \
|
||||||
|
DIARIZATION_API_KEY=dev-key \
|
||||||
|
uv run -m pytest -m model_api --no-cov ../../server/tests/test_model_api_diarization.py
|
||||||
19
gpu/self_hosted/app/auth.py
Normal file
19
gpu/self_hosted/app/auth.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
|
||||||
|
# auto_error=False lets a request without an Authorization header reach
# apikey_auth (with apikey=None) instead of being rejected by the scheme
# itself, so the "no key configured -> auth disabled" path really is open.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)


def apikey_auth(apikey: str | None = Depends(oauth2_scheme)):
    """Validate the Bearer token against ``REFLECTOR_GPU_APIKEY``.

    Args:
        apikey: Token extracted from the ``Authorization: Bearer`` header,
            or ``None`` when the header is absent.

    Raises:
        HTTPException: 401 when a key is configured and the supplied token
            is missing or does not match.
    """
    import secrets  # local import: only needed for the constant-time compare

    required_key = os.environ.get("REFLECTOR_GPU_APIKEY")
    if not required_key:
        # No key configured: authentication is disabled.
        return
    # Compare as bytes so non-ASCII keys are accepted by compare_digest.
    if apikey is not None and secrets.compare_digest(
        apikey.encode("utf-8"), required_key.encode("utf-8")
    ):
        return
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid API key",
        headers={"WWW-Authenticate": "Bearer"},
    )
|
||||||
12
gpu/self_hosted/app/config.py
Normal file
12
gpu/self_hosted/app/config.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# File extensions accepted for uploads and URL downloads.
SUPPORTED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]

# Sample rate (Hz) that audio is decoded/resampled to before transcription.
SAMPLE_RATE = 16000

# Voice-activity-detection settings:
#   batch_max_duration - longest span (s) merged speech segments may cover
#   silence_padding    - clips shorter than this (s) get silence appended
#   window_size        - samples fed to the VAD iterator per step
VAD_CONFIG = dict(
    batch_max_duration=30.0,
    silence_padding=0.5,
    window_size=512,
)

# App-level paths
UPLOADS_PATH = Path("/tmp") / "whisper-uploads"
|
||||||
30
gpu/self_hosted/app/factory.py
Normal file
30
gpu/self_hosted/app/factory.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
from .routers.diarization import router as diarization_router
|
||||||
|
from .routers.transcription import router as transcription_router
|
||||||
|
from .routers.translation import router as translation_router
|
||||||
|
from .services.transcriber import WhisperService
|
||||||
|
from .services.diarizer import PyannoteDiarizationService
|
||||||
|
from .utils import ensure_dirs
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the uploads directory and preload the models before serving.

    The loaded services are stored on ``app.state`` as ``whisper`` and
    ``diarizer``, which is where the routers look them up.
    """
    ensure_dirs()

    preloaded = (
        ("whisper", WhisperService),
        ("diarizer", PyannoteDiarizationService),
    )
    for state_attr, service_cls in preloaded:
        service = service_cls()
        service.load()
        setattr(app.state, state_attr, service)

    yield
|
||||||
|
|
||||||
|
|
||||||
|
def create_app() -> FastAPI:
    """Build the FastAPI application with all model routers registered."""
    application = FastAPI(lifespan=lifespan)
    # Registration order matches the original: transcription, translation, diarization.
    for api_router in (transcription_router, translation_router, diarization_router):
        application.include_router(api_router)
    return application
|
||||||
30
gpu/self_hosted/app/routers/diarization.py
Normal file
30
gpu/self_hosted/app/routers/diarization.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Request
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from ..auth import apikey_auth
|
||||||
|
from ..services.diarizer import PyannoteDiarizationService
|
||||||
|
from ..utils import download_audio_file
|
||||||
|
|
||||||
|
router = APIRouter(tags=["diarization"])
|
||||||
|
|
||||||
|
|
||||||
|
class DiarizationSegment(BaseModel):
    """One speaker turn in a diarization result."""

    # Segment boundaries as reported by the diarizer (request timestamp added).
    start: float
    end: float
    # Numeric speaker index.
    speaker: int
|
||||||
|
|
||||||
|
|
||||||
|
class DiarizationResponse(BaseModel):
    """Response body of POST /diarize."""

    diarization: List[DiarizationSegment]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/diarize",
    dependencies=[Depends(apikey_auth)],
    response_model=DiarizationResponse,
)
def diarize(request: Request, audio_file_url: str, timestamp: float = 0.0):
    """Download the audio at ``audio_file_url`` and return its speaker turns.

    ``timestamp`` is forwarded to the diarizer, which adds it to every segment
    boundary. The downloaded file is removed when the request finishes.
    """
    with download_audio_file(audio_file_url) as (audio_path, _suffix):
        service: PyannoteDiarizationService = request.app.state.diarizer
        return service.diarize_file(str(audio_path), timestamp=timestamp)
|
||||||
109
gpu/self_hosted/app/routers/transcription.py
Normal file
109
gpu/self_hosted/app/routers/transcription.py
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
import uuid
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Body, Depends, Form, HTTPException, Request, UploadFile
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from pathlib import Path
|
||||||
|
from ..auth import apikey_auth
|
||||||
|
from ..config import SUPPORTED_FILE_EXTENSIONS, UPLOADS_PATH
|
||||||
|
from ..services.transcriber import MODEL_NAME
|
||||||
|
from ..utils import cleanup_uploaded_files, download_audio_file
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/v1/audio", tags=["transcription"])
|
||||||
|
|
||||||
|
|
||||||
|
class WordTiming(BaseModel):
    """A transcribed word with its start and end timestamps."""

    word: str
    start: float
    end: float
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptResult(BaseModel):
    """Transcript of a single audio input."""

    text: str
    words: list[WordTiming]
    # Set by the upload endpoint; left as None for URL-based transcription.
    filename: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptBatchResponse(BaseModel):
    """Response body when several files were transcribed in one request."""

    results: list[TranscriptResult]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/transcriptions",
    dependencies=[Depends(apikey_auth)],
    response_model=Union[TranscriptResult, TranscriptBatchResponse],
)
def transcribe(
    request: Request,
    file: UploadFile = None,
    files: list[UploadFile] | None = None,
    model: str = Form(MODEL_NAME),
    language: str = Form("en"),
    batch: bool = Form(False),
):
    """Transcribe one or more uploaded audio files.

    Args:
        request: Incoming request; the Whisper service is read from
            ``request.app.state.whisper``.
        file: Single uploaded file. Takes precedence over ``files``.
        files: Several uploaded files.
        model: Accepted for API compatibility; not used by this handler.
        language: Language code passed to the transcriber.
        batch: When true, ``files`` must be provided.

    Returns:
        A single transcript dict when exactly one file was processed,
        otherwise ``{"results": [...]}``. Each transcript carries the
        server-side (uuid-based) filename.

    Raises:
        HTTPException: 400 when no file is supplied, when ``batch`` is set
            without ``files``, or when a file has an unsupported extension.
    """
    service = request.app.state.whisper
    if not file and not files:
        raise HTTPException(
            status_code=400, detail="Either 'file' or 'files' parameter is required"
        )
    if batch and not files:
        raise HTTPException(
            status_code=400, detail="Batch transcription requires 'files'"
        )

    upload_files = [file] if file else files

    uploaded_paths: list[Path] = []
    with cleanup_uploaded_files(uploaded_paths):
        for upload_file in upload_files:
            # filename may be None for some multipart clients; treat it as
            # "no extension" so the caller gets a 400 rather than a 500.
            original_name = upload_file.filename or ""
            audio_suffix = original_name.split(".")[-1].lower()
            if audio_suffix not in SUPPORTED_FILE_EXTENSIONS:
                raise HTTPException(
                    status_code=400,
                    detail=(
                        f"Unsupported audio format. Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
                    ),
                )
            file_path = UPLOADS_PATH / f"{uuid.uuid4()}.{audio_suffix}"
            # Register before writing so a partially written file is still
            # removed if the copy fails.
            uploaded_paths.append(file_path)
            with open(file_path, "wb") as f:
                # Copy in 1 MiB chunks instead of loading the upload in memory.
                while chunk := upload_file.file.read(1024 * 1024):
                    f.write(chunk)

        # Batch and non-batch requests are processed identically; only the
        # number of results decides the response shape.
        results = []
        for path in uploaded_paths:
            result = service.transcribe_file(str(path), language=language)
            result["filename"] = path.name
            results.append(result)

        return {"results": results} if len(results) > 1 else results[0]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/transcriptions-from-url",
    dependencies=[Depends(apikey_auth)],
    response_model=TranscriptResult,
)
def transcribe_from_url(
    request: Request,
    audio_file_url: str = Body(..., description="URL of the audio file to transcribe"),
    model: str = Body(MODEL_NAME),
    language: str = Body("en"),
    timestamp_offset: float = Body(0.0),
):
    """Download the audio at ``audio_file_url`` and transcribe it.

    ``timestamp_offset`` and ``language`` are forwarded to the Whisper
    service; ``model`` is accepted but not used by this handler. The
    downloaded file is deleted once transcription has finished.
    """
    whisper = request.app.state.whisper
    with download_audio_file(audio_file_url) as (local_path, _suffix):
        transcript = whisper.transcribe_vad_url_segment(
            file_path=str(local_path),
            timestamp_offset=timestamp_offset,
            language=language,
        )
    return transcript
|
||||||
28
gpu/self_hosted/app/routers/translation.py
Normal file
28
gpu/self_hosted/app/routers/translation.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Body, Depends
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from ..auth import apikey_auth
|
||||||
|
from ..services.translator import TextTranslatorService
|
||||||
|
|
||||||
|
router = APIRouter(tags=["translation"])
|
||||||
|
|
||||||
|
translator = TextTranslatorService()
|
||||||
|
|
||||||
|
|
||||||
|
class TranslationResponse(BaseModel):
    """Response body of POST /translate."""

    # Maps language code -> text (source text and its translation).
    text: Dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/translate",
    dependencies=[Depends(apikey_auth)],
    response_model=TranslationResponse,
)
def translate(
    text: str,
    source_language: str = Body("en"),
    target_language: str = Body("fr"),
):
    """Translate ``text`` (query parameter) between the languages in the body."""
    translation = translator.translate(text, source_language, target_language)
    return translation
|
||||||
42
gpu/self_hosted/app/services/diarizer.py
Normal file
42
gpu/self_hosted/app/services/diarizer.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
import os
|
||||||
|
import threading
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from pyannote.audio import Pipeline
|
||||||
|
|
||||||
|
|
||||||
|
class PyannoteDiarizationService:
    """Speaker diarization backed by the pyannote ``speaker-diarization-3.1`` pipeline."""

    MODEL_ID = "pyannote/speaker-diarization-3.1"

    def __init__(self):
        self._pipeline = None
        self._device = "cpu"
        # Serialises pipeline inference across request threads.
        self._lock = threading.Lock()

    def load(self):
        """Load the pipeline and move it to CUDA when available, else CPU.

        Raises:
            RuntimeError: If the pipeline could not be obtained.
        """
        self._device = "cuda" if torch.cuda.is_available() else "cpu"
        pipeline = Pipeline.from_pretrained(
            self.MODEL_ID,
            use_auth_token=os.environ.get("HF_TOKEN"),
        )
        # NOTE(review): pyannote appears to return None (rather than raise)
        # when the gated model cannot be fetched; fail with a clear message
        # instead of an AttributeError on `.to`.
        if pipeline is None:
            raise RuntimeError(
                f"Could not load {self.MODEL_ID}; check that HF_TOKEN is set "
                "and has access to the model"
            )
        pipeline.to(torch.device(self._device))
        self._pipeline = pipeline

    @staticmethod
    def _speaker_index(label: str) -> int:
        """Return the trailing number of a speaker label, or 0 if there is none.

        Uses every trailing digit, so a label ending in ``_100`` maps to 100
        rather than to the last two digits.
        """
        if not label:
            return 0
        digits = label[len(label.rstrip("0123456789")):]
        return int(digits) if digits else 0

    def diarize_file(self, file_path: str, timestamp: float = 0.0) -> dict:
        """Diarize an audio file.

        Args:
            file_path: Path of the audio file to load with torchaudio.
            timestamp: Offset added to every segment boundary.

        Returns:
            ``{"diarization": [{"start", "end", "speaker"}, ...]}`` with
            boundaries rounded to 3 decimals.
        """
        if self._pipeline is None:
            self.load()
        waveform, sample_rate = torchaudio.load(file_path)
        with self._lock:
            diarization = self._pipeline(
                {"waveform": waveform, "sample_rate": sample_rate}
            )
        turns = [
            {
                "start": round(timestamp + turn.start, 3),
                "end": round(timestamp + turn.end, 3),
                "speaker": self._speaker_index(speaker),
            }
            for turn, _, speaker in diarization.itertracks(yield_label=True)
        ]
        return {"diarization": turns}
|
||||||
208
gpu/self_hosted/app/services/transcriber.py
Normal file
208
gpu/self_hosted/app/services/transcriber.py
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import threading
|
||||||
|
from typing import Generator
|
||||||
|
|
||||||
|
import faster_whisper
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from silero_vad import VADIterator, load_silero_vad
|
||||||
|
|
||||||
|
from ..config import SAMPLE_RATE, VAD_CONFIG
|
||||||
|
|
||||||
|
from ..utils import NoStdStreams

# Whisper configuration (service-local defaults)
MODEL_NAME = "large-v2"
# None delegates compute type to runtime: float16 on CUDA, int8 on CPU
MODEL_COMPUTE_TYPE = None
MODEL_NUM_WORKERS = 1
# Model download directory under the user's home.
CACHE_PATH = os.path.join(
    os.path.expanduser("~"),
    ".cache",
    "reflector-whisper",
)
|
||||||
|
|
||||||
|
|
||||||
|
class WhisperService:
    """faster-whisper transcription for uploaded files and downloaded audio.

    One model instance is shared by all requests; ``self.lock`` serialises
    inference on it.
    """

    def __init__(self):
        self.model = None
        self.device = "cpu"
        self.lock = threading.Lock()

    def load(self):
        """Instantiate the Whisper model on CUDA when available, else on CPU."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        compute_type = MODEL_COMPUTE_TYPE or (
            "float16" if self.device == "cuda" else "int8"
        )
        self.model = faster_whisper.WhisperModel(
            MODEL_NAME,
            device=self.device,
            compute_type=compute_type,
            num_workers=MODEL_NUM_WORKERS,
            download_root=CACHE_PATH,
        )

    def pad_audio(self, audio_array, sample_rate: int = SAMPLE_RATE):
        """Append ``silence_padding`` seconds of silence to clips shorter than that."""
        audio_duration = len(audio_array) / sample_rate
        if audio_duration < VAD_CONFIG["silence_padding"]:
            silence_samples = int(sample_rate * VAD_CONFIG["silence_padding"])
            silence = np.zeros(silence_samples, dtype=np.float32)
            return np.concatenate([audio_array, silence])
        return audio_array

    def enforce_word_timing_constraints(self, words: list[dict]) -> list[dict]:
        """Return copies of ``words`` where no word ends after the next one starts."""
        if len(words) <= 1:
            return words
        enforced: list[dict] = []
        for i, word in enumerate(words):
            current = dict(word)
            if i < len(words) - 1:
                next_start = words[i + 1]["start"]
                if current["end"] > next_start:
                    current["end"] = next_start
            enforced.append(current)
        return enforced

    def _transcribe_segments(self, audio, language: str) -> list:
        """Run the model on ``audio`` (path or sample array) and return its segments.

        faster-whisper yields segments lazily, so the generator is consumed
        while the lock is held and stdout/stderr are silenced; otherwise the
        actual decoding would run unprotected.
        """
        with self.lock:
            if self.model is None:
                self.load()
            with NoStdStreams():
                segments, _ = self.model.transcribe(
                    audio,
                    language=language,
                    beam_size=5,
                    word_timestamps=True,
                    vad_filter=True,
                    vad_parameters={"min_silence_duration_ms": 500},
                )
                return list(segments)

    @staticmethod
    def _words_from_segments(
        segments, start_time: float = 0.0, timestamp_offset: float = 0.0
    ) -> list[dict]:
        """Flatten segment words into dicts, shifting and rounding timestamps to 2 decimals."""
        return [
            {
                "word": w.word,
                "start": round(float(w.start) + start_time + timestamp_offset, 2),
                "end": round(float(w.end) + start_time + timestamp_offset, 2),
            }
            for seg in segments
            for w in seg.words
        ]

    def transcribe_file(self, file_path: str, language: str = "en") -> dict:
        """Transcribe a local audio file.

        Returns:
            ``{"text": str, "words": [{"word", "start", "end"}, ...]}``.
        """
        input_for_model: str | np.ndarray = file_path
        try:
            audio_array, _sample_rate = librosa.load(
                file_path, sr=SAMPLE_RATE, mono=True
            )
        except Exception:
            # Best effort: the pre-load only exists to pad very short clips.
            # If librosa cannot read the file, hand the path to the model.
            pass
        else:
            # pad_audio only pads clips shorter than the silence padding.
            input_for_model = self.pad_audio(audio_array, SAMPLE_RATE)
            if input_for_model is audio_array:
                # Not padded: let the model decode the file itself.
                input_for_model = file_path

        segments = self._transcribe_segments(input_for_model, language)
        text = "".join(segment.text for segment in segments).strip()
        words = self.enforce_word_timing_constraints(
            self._words_from_segments(segments)
        )
        return {"text": text, "words": words}

    @staticmethod
    def _load_audio_via_ffmpeg(input_path: str, sample_rate: int) -> np.ndarray:
        """Decode ``input_path`` to mono float32 PCM at ``sample_rate`` using ffmpeg.

        Raises:
            HTTPException: 400 when ffmpeg cannot be run or exits with an error.
        """
        ffmpeg_bin = shutil.which("ffmpeg") or "ffmpeg"
        cmd = [
            ffmpeg_bin,
            "-nostdin",
            "-threads",
            "1",
            "-i",
            input_path,
            "-f",
            "f32le",
            "-acodec",
            "pcm_f32le",
            "-ac",
            "1",
            "-ar",
            str(sample_rate),
            "pipe:1",
        ]
        try:
            proc = subprocess.run(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True
            )
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"ffmpeg failed: {e}") from e
        return np.frombuffer(proc.stdout, dtype=np.float32)

    @staticmethod
    def _vad_segments(
        audio_array,
        sample_rate: int = SAMPLE_RATE,
        window_size: int = VAD_CONFIG["window_size"],
    ) -> Generator[tuple[float, float], None, None]:
        """Yield ``(start, end)`` times in seconds of detected speech.

        Speech that is still open when the audio ends is yielded up to the
        end of the audio instead of being dropped.
        """
        vad_model = load_silero_vad(onnx=False)
        iterator = VADIterator(vad_model, sampling_rate=sample_rate)
        start = None
        for i in range(0, len(audio_array), window_size):
            chunk = audio_array[i : i + window_size]
            if len(chunk) < window_size:
                # The last window is zero-padded to the full window size.
                chunk = np.pad(
                    chunk, (0, window_size - len(chunk)), mode="constant"
                )
            speech = iterator(chunk)
            if not speech:
                continue
            if "start" in speech:
                start = speech["start"]
                continue
            if "end" in speech and start is not None:
                yield (start / float(sample_rate), speech["end"] / float(sample_rate))
                start = None
        if start is not None:
            yield (start / float(sample_rate), len(audio_array) / float(sample_rate))
        iterator.reset_states()

    @staticmethod
    def _merge_segments(segments, max_duration: float) -> list[tuple[float, float]]:
        """Greedily merge consecutive segments into batches spanning at most ``max_duration`` s."""
        merged: list[tuple[float, float]] = []
        batch_start = None
        batch_end = None
        for seg_start, seg_end in segments:
            if batch_start is None:
                batch_start, batch_end = seg_start, seg_end
                continue
            if seg_end - batch_start <= max_duration:
                batch_end = seg_end
            else:
                merged.append((batch_start, batch_end))
                batch_start, batch_end = seg_start, seg_end
        if batch_start is not None and batch_end is not None:
            merged.append((batch_start, batch_end))
        return merged

    def transcribe_vad_url_segment(
        self, file_path: str, timestamp_offset: float = 0.0, language: str = "en"
    ) -> dict:
        """Transcribe a downloaded file in VAD-delimited batches.

        Args:
            file_path: Local path of the audio to decode with ffmpeg.
            timestamp_offset: Seconds added to every word timestamp.
            language: Language code passed to the model.

        Returns:
            ``{"text": str, "words": [...]}`` where batch texts are joined by
            single spaces.
        """
        audio_array = self._load_audio_via_ffmpeg(file_path, SAMPLE_RATE)
        merged_batches = self._merge_segments(
            self._vad_segments(audio_array), VAD_CONFIG["batch_max_duration"]
        )

        all_text = []
        all_words = []
        for start_time, end_time in merged_batches:
            s_idx = int(start_time * SAMPLE_RATE)
            e_idx = int(end_time * SAMPLE_RATE)
            segment = self.pad_audio(audio_array[s_idx:e_idx], SAMPLE_RATE)
            segments = self._transcribe_segments(segment, language)
            text = "".join(seg.text for seg in segments).strip()
            if text:
                all_text.append(text)
            all_words.extend(
                self._words_from_segments(segments, start_time, timestamp_offset)
            )

        all_words = self.enforce_word_timing_constraints(all_words)
        return {"text": " ".join(all_text), "words": all_words}
|
||||||
44
gpu/self_hosted/app/services/translator.py
Normal file
44
gpu/self_hosted/app/services/translator.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
import threading
|
||||||
|
|
||||||
|
from transformers import MarianMTModel, MarianTokenizer, pipeline
|
||||||
|
|
||||||
|
|
||||||
|
class TextTranslatorService:
    """Simple text-to-text translator using HuggingFace MarianMT models.

    This mirrors the modal translator API shape but uses text translation only.
    One pipeline is cached per MarianMT checkpoint, so each request is served
    by the model that matches its language pair.
    """

    _DEFAULT_MODEL = "Helsinki-NLP/opus-mt-en-fr"
    # Minimal mapping; extend as needed
    _MODEL_BY_PAIR = {
        ("en", "fr"): "Helsinki-NLP/opus-mt-en-fr",
        ("fr", "en"): "Helsinki-NLP/opus-mt-fr-en",
        ("en", "es"): "Helsinki-NLP/opus-mt-en-es",
        ("es", "en"): "Helsinki-NLP/opus-mt-es-en",
        ("en", "de"): "Helsinki-NLP/opus-mt-en-de",
        ("de", "en"): "Helsinki-NLP/opus-mt-de-en",
    }

    def __init__(self):
        # Most recently used pipeline; kept so existing readers of
        # `_pipeline` keep working.
        self._pipeline = None
        # model name -> loaded translation pipeline
        self._pipelines = {}
        self._lock = threading.Lock()

    def load(self, source_language: str = "en", target_language: str = "fr"):
        """Load (or reuse) the pipeline for the given language pair."""
        model_name = self._resolve_model_name(source_language, target_language)
        with self._lock:
            self._pipeline = self._get_pipeline(model_name)

    def _get_pipeline(self, model_name: str):
        """Return the cached pipeline for ``model_name``, loading it on first use.

        The caller must hold ``self._lock``.
        """
        cached = self._pipelines.get(model_name)
        if cached is None:
            tokenizer = MarianTokenizer.from_pretrained(model_name)
            model = MarianMTModel.from_pretrained(model_name)
            cached = pipeline("translation", model=model, tokenizer=tokenizer)
            self._pipelines[model_name] = cached
        return cached

    def _resolve_model_name(self, src: str, tgt: str) -> str:
        """Map a language pair to a checkpoint name, falling back to en->fr."""
        pair = (src.lower(), tgt.lower())
        return self._MODEL_BY_PAIR.get(pair, self._DEFAULT_MODEL)

    def translate(self, text: str, source_language: str, target_language: str) -> dict:
        """Translate ``text`` and return both versions keyed by language code.

        Returns:
            ``{"text": {source_language: text, target_language: translated}}``;
            the translation is ``""`` when the pipeline returns no result.
        """
        model_name = self._resolve_model_name(source_language, target_language)
        with self._lock:
            translator = self._get_pipeline(model_name)
            self._pipeline = translator
            results = translator(
                text, src_lang=source_language, tgt_lang=target_language
            )
        translated = results[0]["translation_text"] if results else ""
        return {"text": {source_language: text, target_language: translated}}
|
||||||
107
gpu/self_hosted/app/utils.py
Normal file
107
gpu/self_hosted/app/utils.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import uuid
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Mapping
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from .config import SUPPORTED_FILE_EXTENSIONS, UPLOADS_PATH
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class NoStdStreams:
    """Context manager that redirects ``sys.stdout``/``sys.stderr`` to os.devnull.

    The devnull handle is opened on entry and closed on exit, so an instance
    can be reused and holds no file open while idle.
    """

    def __init__(self):
        self.devnull = None
        self._stdout = None
        self._stderr = None

    def __enter__(self):
        self._stdout, self._stderr = sys.stdout, sys.stderr
        # Flush pending output so it is not lost or reordered by the swap.
        self._stdout.flush()
        self._stderr.flush()
        self.devnull = open(os.devnull, "w")
        sys.stdout, sys.stderr = self.devnull, self.devnull
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the real streams before closing devnull so nothing can be
        # written to a closed file.
        sys.stdout, sys.stderr = self._stdout, self._stderr
        self.devnull.close()
        self.devnull = None
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_dirs():
    """Create the uploads directory (and missing parents) if it does not exist."""
    os.makedirs(UPLOADS_PATH, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def detect_audio_format(url: str, headers: Mapping[str, str]) -> str:
    """Return the audio file extension for a download.

    The URL path's extension wins; otherwise the ``content-type`` header is
    matched against known audio MIME types.

    Args:
        url: URL the audio was fetched from.
        headers: Response headers; the ``content-type`` key is read.

    Raises:
        HTTPException: 400 when neither the URL nor the content type
            identifies a supported format.
    """
    url_path = urlparse(url).path.lower()
    for ext in SUPPORTED_FILE_EXTENSIONS:
        if url_path.endswith(f".{ext}"):
            return ext

    # Substring markers, checked in order. The first five entries reproduce
    # the original checks; the rest cover the remaining supported formats.
    content_type_markers = (
        ("audio/mpeg", "mp3"),
        ("audio/mp3", "mp3"),
        ("audio/wav", "wav"),
        ("audio/mp4", "mp4"),
        ("audio/x-wav", "wav"),
        ("audio/vnd.wave", "wav"),
        ("video/mp4", "mp4"),
        ("audio/x-m4a", "m4a"),
        ("audio/m4a", "m4a"),
        ("audio/webm", "webm"),
        ("video/webm", "webm"),
    )
    content_type = headers.get("content-type", "").lower()
    for marker, ext in content_type_markers:
        if marker in content_type:
            return ext

    raise HTTPException(
        status_code=400,
        detail=(
            f"Unsupported audio format for URL. Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
        ),
    )
|
||||||
|
|
||||||
|
|
||||||
|
def download_audio_to_uploads(audio_file_url: str) -> tuple[Path, str]:
    """Download ``audio_file_url`` into ``UPLOADS_PATH`` under a unique name.

    Returns:
        ``(file_path, audio_suffix)`` of the stored file.

    Raises:
        HTTPException: 404 when the HEAD probe reports the file is missing,
            400 when the format cannot be determined.
        requests.RequestException: On network errors, timeouts or a non-2xx
            GET response.
    """
    # (connect, read) timeouts in seconds so a stalled server cannot hang the
    # worker. The read timeout applies per socket read, not to the whole
    # download.
    timeout = (10, 120)

    probe = requests.head(audio_file_url, allow_redirects=True, timeout=timeout)
    if probe.status_code == 404:
        raise HTTPException(status_code=404, detail="Audio file not found")

    # Stream the body to disk instead of buffering it all in memory; the
    # context manager releases the connection.
    with requests.get(
        audio_file_url, allow_redirects=True, stream=True, timeout=timeout
    ) as response:
        response.raise_for_status()

        audio_suffix = detect_audio_format(audio_file_url, response.headers)
        unique_filename = f"{uuid.uuid4()}.{audio_suffix}"
        file_path: Path = UPLOADS_PATH / unique_filename

        try:
            with open(file_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
        except BaseException:
            # Do not leave a truncated file behind if the transfer fails.
            file_path.unlink(missing_ok=True)
            raise

    return file_path, audio_suffix
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
def download_audio_file(audio_file_url: str):
    """Fetch an audio file into UPLOADS_PATH for the duration of the context.

    Yields (file_path: Path, audio_suffix: str). The file is deleted on exit;
    a failed deletion is logged and otherwise ignored.
    """
    local_path, suffix = download_audio_to_uploads(audio_file_url)
    try:
        yield local_path, suffix
    finally:
        try:
            local_path.unlink(missing_ok=True)
        except Exception as e:
            logger.error("Error deleting temporary file %s: %s", local_path, e)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
def cleanup_uploaded_files(file_paths: list[Path]):
    """Delete every path in ``file_paths`` when the context exits.

    The same list object is yielded, so entries appended inside the ``with``
    body are removed too. Deletion failures are logged, not raised.
    """
    try:
        yield file_paths
    finally:
        # Iterate over a snapshot of whatever the list holds at exit.
        for stale_path in tuple(file_paths):
            try:
                stale_path.unlink(missing_ok=True)
            except Exception as e:
                logger.error("Error deleting temporary file %s: %s", stale_path, e)
|
||||||
10
gpu/self_hosted/compose.yml
Normal file
10
gpu/self_hosted/compose.yml
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
services:
|
||||||
|
reflector_gpu:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
ports:
|
||||||
|
- "8000:8000"
|
||||||
|
env_file:
|
||||||
|
- .env
|
||||||
|
volumes:
|
||||||
|
- ./cache:/root/.cache
|
||||||
3
gpu/self_hosted/main.py
Normal file
3
gpu/self_hosted/main.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from app.factory import create_app
|
||||||
|
|
||||||
|
app = create_app()
|
||||||
19
gpu/self_hosted/pyproject.toml
Normal file
19
gpu/self_hosted/pyproject.toml
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
[project]
|
||||||
|
name = "reflector-gpu"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Self-hosted GPU service for speech transcription, diarization, and translation via FastAPI."
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.12"
|
||||||
|
dependencies = [
|
||||||
|
"fastapi[standard]>=0.116.1",
|
||||||
|
"uvicorn[standard]>=0.30.0",
|
||||||
|
"torch>=2.3.0",
|
||||||
|
"faster-whisper>=1.1.0",
|
||||||
|
"librosa==0.10.1",
|
||||||
|
"numpy<2",
|
||||||
|
"silero-vad==5.1.0",
|
||||||
|
"transformers>=4.35.0",
|
||||||
|
"sentencepiece",
|
||||||
|
"pyannote.audio==3.1.0",
|
||||||
|
"torchaudio>=2.3.0",
|
||||||
|
]
|
||||||
17
gpu/self_hosted/runserver.sh
Normal file
17
gpu/self_hosted/runserver.sh
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
#!/bin/sh
# Container entrypoint: make sure the uv-managed virtual environment exists,
# then replace this shell with the uvicorn server process.
set -e

PATH="/root/.local/bin:$PATH"
export PATH
cd /app

# Dependencies are installed at runtime: on the first run (no /app/.venv yet)
# or whenever FORCE_SYNC=1 is set. Otherwise the existing environment is reused.
if [ -d "/app/.venv" ] && [ "$FORCE_SYNC" != "1" ]; then
    echo "[startup] Using existing virtual environment at /app/.venv"
else
    echo "[startup] Installing Python dependencies with uv..."
    uv sync --compile-bytecode --locked
fi

# exec so uvicorn takes over this process instead of running as a child.
exec uv run uvicorn main:app --host 0.0.0.0 --port 8000
|
||||||
|
|
||||||
|
|
||||||
3013
gpu/self_hosted/uv.lock
generated
Normal file
3013
gpu/self_hosted/uv.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
@@ -190,5 +190,5 @@ Use the pytest-based conformance tests to validate any new implementation (inclu
|
|||||||
```
|
```
|
||||||
TRANSCRIPT_URL=https://<your-deployment-base> \
|
TRANSCRIPT_URL=https://<your-deployment-base> \
|
||||||
TRANSCRIPT_MODAL_API_KEY=your-api-key \
|
TRANSCRIPT_MODAL_API_KEY=your-api-key \
|
||||||
uv run -m pytest -m gpu_modal --no-cov server/tests/test_gpu_modal_transcript.py
|
uv run -m pytest -m model_api --no-cov server/tests/test_model_api_transcript.py
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -118,7 +118,7 @@ addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v"
|
|||||||
testpaths = ["tests"]
|
testpaths = ["tests"]
|
||||||
asyncio_mode = "auto"
|
asyncio_mode = "auto"
|
||||||
markers = [
|
markers = [
|
||||||
"gpu_modal: mark test to run only with GPU Modal endpoints (deselect with '-m \"not gpu_modal\"')",
|
"model_api: tests for the unified model-serving HTTP API (backend- and hardware-agnostic)",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
@@ -130,7 +130,7 @@ select = [
|
|||||||
|
|
||||||
[tool.ruff.lint.per-file-ignores]
|
[tool.ruff.lint.per-file-ignores]
|
||||||
"reflector/processors/summary/summary_builder.py" = ["E501"]
|
"reflector/processors/summary/summary_builder.py" = ["E501"]
|
||||||
"gpu/**.py" = ["PLC0415"]
|
"gpu/modal_deployments/**.py" = ["PLC0415"]
|
||||||
"reflector/tools/**.py" = ["PLC0415"]
|
"reflector/tools/**.py" = ["PLC0415"]
|
||||||
"migrations/versions/**.py" = ["PLC0415"]
|
"migrations/versions/**.py" = ["PLC0415"]
|
||||||
"tests/**.py" = ["PLC0415"]
|
"tests/**.py" = ["PLC0415"]
|
||||||
|
|||||||
63
server/tests/test_model_api_diarization.py
Normal file
63
server/tests/test_model_api_diarization.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
"""
|
||||||
|
Tests for diarization Model API endpoint (self-hosted service compatible shape).
|
||||||
|
|
||||||
|
Marked with the "model_api" marker and skipped unless DIARIZATION_URL is provided.
|
||||||
|
|
||||||
|
Run against a local self-hosted server:
|
||||||
|
DIARIZATION_API_KEY=dev-key \
|
||||||
|
DIARIZATION_URL=http://localhost:8000 \
|
||||||
|
uv run -m pytest -m model_api --no-cov tests/test_model_api_diarization.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Public test audio file hosted on S3 specifically for reflector pytests
|
||||||
|
TEST_AUDIO_URL = (
|
||||||
|
"https://reflector-github-pytest.s3.us-east-1.amazonaws.com/test_mathieu_hello.mp3"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_modal_diarization_url():
    """Return the base URL from DIARIZATION_URL, skipping the test if unset."""
    base_url = os.getenv("DIARIZATION_URL")
    if not base_url:
        pytest.skip(
            "DIARIZATION_URL environment variable is required for Model API tests"
        )
    return base_url
|
||||||
|
|
||||||
|
|
||||||
|
def get_auth_headers():
    """Build the bearer-token header for the diarization endpoint.

    DIARIZATION_API_KEY takes precedence; REFLECTOR_GPU_APIKEY is the
    fallback. With neither set, an empty dict is returned.
    """
    token = os.getenv("DIARIZATION_API_KEY") or os.getenv("REFLECTOR_GPU_APIKEY")
    if not token:
        return {}
    return {"Authorization": f"Bearer {token}"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.model_api
class TestModelAPIDiarization:
    """Conformance checks for the ``/diarize`` Model API endpoint."""

    def test_diarize_from_url(self):
        """Diarizing the public sample returns well-formed speaker segments."""
        base_url = get_modal_diarization_url()
        auth_headers = get_auth_headers()
        query = {"audio_file_url": TEST_AUDIO_URL, "timestamp": 0.0}

        with httpx.Client(timeout=60.0) as http:
            response = http.post(
                f"{base_url}/diarize", params=query, headers=auth_headers
            )

            assert response.status_code == 200, f"Request failed: {response.text}"
            payload = response.json()

            assert "diarization" in payload
            segments = payload["diarization"]
            assert isinstance(segments, list)
            assert len(segments) > 0

            # Every segment needs the three keys and a non-negative duration.
            for segment in segments:
                assert (
                    "start" in segment and "end" in segment and "speaker" in segment
                )
                assert isinstance(segment["start"], (int, float))
                assert isinstance(segment["end"], (int, float))
                assert segment["start"] <= segment["end"]
|
||||||
@@ -1,21 +1,21 @@
|
|||||||
"""
|
"""
|
||||||
Tests for GPU Modal transcription endpoints.
|
Tests for transcription Model API endpoints.
|
||||||
|
|
||||||
These tests are marked with the "gpu-modal" group and will not run by default.
|
These tests are marked with the "model_api" group and will not run by default.
|
||||||
Run them with: pytest -m gpu-modal tests/test_gpu_modal_transcript_parakeet.py
|
Run them with: pytest -m model_api tests/test_model_api_transcript.py
|
||||||
|
|
||||||
Required environment variables:
|
Required environment variables:
|
||||||
- TRANSCRIPT_URL: URL to the Modal.com endpoint (required)
|
- TRANSCRIPT_URL: URL to the Model API endpoint (required)
|
||||||
- TRANSCRIPT_MODAL_API_KEY: API key for authentication (optional)
|
- TRANSCRIPT_API_KEY: API key for authentication (optional)
|
||||||
- TRANSCRIPT_MODEL: Model name to use (optional, defaults to nvidia/parakeet-tdt-0.6b-v2)
|
- TRANSCRIPT_MODEL: Model name to use (optional, defaults to nvidia/parakeet-tdt-0.6b-v2)
|
||||||
|
|
||||||
Example with pytest (override default addopts to run ONLY gpu_modal tests):
|
Example with pytest (override default addopts to run ONLY model_api tests):
|
||||||
TRANSCRIPT_URL=https://monadical-sas--reflector-transcriber-parakeet-web-dev.modal.run \
|
TRANSCRIPT_URL=https://monadical-sas--reflector-transcriber-parakeet-web-dev.modal.run \
|
||||||
TRANSCRIPT_MODAL_API_KEY=your-api-key \
|
TRANSCRIPT_API_KEY=your-api-key \
|
||||||
uv run -m pytest -m gpu_modal --no-cov tests/test_gpu_modal_transcript.py
|
uv run -m pytest -m model_api --no-cov tests/test_model_api_transcript.py
|
||||||
|
|
||||||
# Or with completely clean options:
|
# Or with completely clean options:
|
||||||
uv run -m pytest -m gpu_modal -o addopts="" tests/
|
uv run -m pytest -m model_api -o addopts="" tests/
|
||||||
|
|
||||||
Running Modal locally for testing:
|
Running Modal locally for testing:
|
||||||
modal serve gpu/modal_deployments/reflector_transcriber_parakeet.py
|
modal serve gpu/modal_deployments/reflector_transcriber_parakeet.py
|
||||||
@@ -40,14 +40,16 @@ def get_modal_transcript_url():
|
|||||||
url = os.environ.get("TRANSCRIPT_URL")
|
url = os.environ.get("TRANSCRIPT_URL")
|
||||||
if not url:
|
if not url:
|
||||||
pytest.skip(
|
pytest.skip(
|
||||||
"TRANSCRIPT_URL environment variable is required for GPU Modal tests"
|
"TRANSCRIPT_URL environment variable is required for Model API tests"
|
||||||
)
|
)
|
||||||
return url
|
return url
|
||||||
|
|
||||||
|
|
||||||
def get_auth_headers():
|
def get_auth_headers():
|
||||||
"""Get authentication headers if API key is available."""
|
"""Get authentication headers if API key is available."""
|
||||||
api_key = os.environ.get("TRANSCRIPT_MODAL_API_KEY")
|
api_key = os.environ.get("TRANSCRIPT_API_KEY") or os.environ.get(
|
||||||
|
"REFLECTOR_GPU_APIKEY"
|
||||||
|
)
|
||||||
if api_key:
|
if api_key:
|
||||||
return {"Authorization": f"Bearer {api_key}"}
|
return {"Authorization": f"Bearer {api_key}"}
|
||||||
return {}
|
return {}
|
||||||
@@ -58,8 +60,8 @@ def get_model_name():
|
|||||||
return os.environ.get("TRANSCRIPT_MODEL", "nvidia/parakeet-tdt-0.6b-v2")
|
return os.environ.get("TRANSCRIPT_MODEL", "nvidia/parakeet-tdt-0.6b-v2")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.gpu_modal
|
@pytest.mark.model_api
|
||||||
class TestGPUModalTranscript:
|
class TestModelAPITranscript:
|
||||||
"""Test suite for GPU Modal transcription endpoints."""
|
"""Test suite for GPU Modal transcription endpoints."""
|
||||||
|
|
||||||
def test_transcriptions_from_url(self):
|
def test_transcriptions_from_url(self):
|
||||||
56
server/tests/test_model_api_translation.py
Normal file
56
server/tests/test_model_api_translation.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
"""
|
||||||
|
Tests for translation Model API endpoint (self-hosted service compatible shape).
|
||||||
|
|
||||||
|
Marked with the "model_api" marker and skipped unless TRANSLATION_URL is provided
|
||||||
|
or we fall back to the TRANSCRIPT_URL base (same host for self-hosted).
|
||||||
|
|
||||||
|
Run locally against self-hosted server:
|
||||||
|
TRANSLATION_API_KEY=dev-key \
|
||||||
|
TRANSLATION_URL=http://localhost:8000 \
|
||||||
|
uv run -m pytest -m model_api --no-cov tests/test_model_api_translation.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def get_translation_url():
    """Return the translation base URL, skipping the test if none is set.

    TRANSLATION_URL is preferred; TRANSCRIPT_URL is used as the fallback.
    """
    base_url = os.getenv("TRANSLATION_URL") or os.getenv("TRANSCRIPT_URL")
    if not base_url:
        pytest.skip(
            "TRANSLATION_URL or TRANSCRIPT_URL environment variable is required for Model API tests"
        )
    return base_url
|
||||||
|
|
||||||
|
|
||||||
|
def get_auth_headers():
    """Build the bearer-token header for the translation endpoint.

    TRANSLATION_API_KEY takes precedence; REFLECTOR_GPU_APIKEY is the
    fallback. With neither set, an empty dict is returned.
    """
    token = os.getenv("TRANSLATION_API_KEY") or os.getenv("REFLECTOR_GPU_APIKEY")
    if not token:
        return {}
    return {"Authorization": f"Bearer {token}"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.model_api
class TestModelAPITranslation:
    """Conformance checks for the ``/translate`` Model API endpoint."""

    def test_translate_text(self):
        """An English sentence is echoed back alongside its French translation."""
        base_url = get_translation_url()
        auth_headers = get_auth_headers()
        source_text = "The meeting will start in five minutes."
        languages = {"source_language": "en", "target_language": "fr"}

        with httpx.Client(timeout=60.0) as http:
            response = http.post(
                f"{base_url}/translate",
                params={"text": source_text},
                json=languages,
                headers=auth_headers,
            )

            assert response.status_code == 200, f"Request failed: {response.text}"
            payload = response.json()

            assert "text" in payload
            assert isinstance(payload["text"], dict)
            translations = payload["text"]

            # The response maps language code -> text, source language included.
            assert translations.get("en") == source_text
            assert isinstance(translations.get("fr", ""), str)
            assert len(translations["fr"]) > 0
            assert translations["fr"] == "La réunion commencera dans cinq minutes."
|
||||||
Reference in New Issue
Block a user