feat: self-hosted gpu api (#636)

* Self-hosted gpu api

* Refactor self-hosted api

* Rename model api tests

* Use lifespan instead of startup event

* Fix self hosted imports

* Add newlines

* Add response models

* Move gpu dir to the root

* Add project description

* Refactor lifespan

* Update env var names for model api tests

* Preload diarizarion service

* Refactor uploaded file paths
This commit is contained in:
2025-09-17 18:52:03 +02:00
committed by GitHub
parent fa049e8d06
commit ab859d65a6
30 changed files with 4020 additions and 16 deletions

View File

@@ -190,5 +190,5 @@ Use the pytest-based conformance tests to validate any new implementation (inclu
```
TRANSCRIPT_URL=https://<your-deployment-base> \
TRANSCRIPT_MODAL_API_KEY=your-api-key \
uv run -m pytest -m gpu_modal --no-cov server/tests/test_gpu_modal_transcript.py
uv run -m pytest -m model_api --no-cov server/tests/test_model_api_transcript.py
```

View File

@@ -1,171 +0,0 @@
# Reflector GPU implementation - Transcription and LLM
This repository hold an API for the GPU implementation of the Reflector API service,
and use [Modal.com](https://modal.com)
- `reflector_diarizer.py` - Diarization API
- `reflector_transcriber.py` - Transcription API (Whisper)
- `reflector_transcriber_parakeet.py` - Transcription API (NVIDIA Parakeet)
- `reflector_translator.py` - Translation API
## Modal.com deployment
Create a modal secret, and name it `reflector-gpu`.
It should contain an `REFLECTOR_APIKEY` environment variable with a value.
The deployment is done using [Modal.com](https://modal.com) service.
```
$ modal deploy reflector_transcriber.py
...
└── 🔨 Created web => https://xxxx--reflector-transcriber-web.modal.run
$ modal deploy reflector_transcriber_parakeet.py
...
└── 🔨 Created web => https://xxxx--reflector-transcriber-parakeet-web.modal.run
$ modal deploy reflector_llm.py
...
└── 🔨 Created web => https://xxxx--reflector-llm-web.modal.run
```
Then in your reflector api configuration `.env`, you can set these keys:
```
TRANSCRIPT_BACKEND=modal
TRANSCRIPT_URL=https://xxxx--reflector-transcriber-web.modal.run
TRANSCRIPT_MODAL_API_KEY=REFLECTOR_APIKEY
DIARIZATION_BACKEND=modal
DIARIZATION_URL=https://xxxx--reflector-diarizer-web.modal.run
DIARIZATION_MODAL_API_KEY=REFLECTOR_APIKEY
TRANSLATION_BACKEND=modal
TRANSLATION_URL=https://xxxx--reflector-translator-web.modal.run
TRANSLATION_MODAL_API_KEY=REFLECTOR_APIKEY
```
## API
Authentication must be passed with the `Authorization` header, using the `bearer` scheme.
```
Authorization: bearer <REFLECTOR_APIKEY>
```
### LLM
`POST /llm`
**request**
```
{
"prompt": "xxx"
}
```
**response**
```
{
"text": "xxx completed"
}
```
### Transcription
#### Parakeet Transcriber (`reflector_transcriber_parakeet.py`)
NVIDIA Parakeet is a state-of-the-art ASR model optimized for real-time transcription with superior word-level timestamps.
**GPU Configuration:**
- **A10G GPU** - Used for `/v1/audio/transcriptions` endpoint (small files, live transcription)
- Higher concurrency (max_inputs=10)
- Optimized for multiple small audio files
- Supports batch processing for efficiency
- **L40S GPU** - Used for `/v1/audio/transcriptions-from-url` endpoint (large files)
- Lower concurrency but more powerful processing
- Optimized for single large audio files
- VAD-based chunking for long-form audio
##### `/v1/audio/transcriptions` - Small file transcription
**request** (multipart/form-data)
- `file` or `files[]` - audio file(s) to transcribe
- `model` - model name (default: `nvidia/parakeet-tdt-0.6b-v2`)
- `language` - language code (default: `en`)
- `batch` - whether to use batch processing for multiple files (default: `true`)
**response**
```json
{
"text": "transcribed text",
"words": [
{"word": "hello", "start": 0.0, "end": 0.5},
{"word": "world", "start": 0.5, "end": 1.0}
],
"filename": "audio.mp3"
}
```
For multiple files with batch=true:
```json
{
"results": [
{
"filename": "audio1.mp3",
"text": "transcribed text",
"words": [...]
},
{
"filename": "audio2.mp3",
"text": "transcribed text",
"words": [...]
}
]
}
```
##### `/v1/audio/transcriptions-from-url` - Large file transcription
**request** (application/json)
```json
{
"audio_file_url": "https://example.com/audio.mp3",
"model": "nvidia/parakeet-tdt-0.6b-v2",
"language": "en",
"timestamp_offset": 0.0
}
```
**response**
```json
{
"text": "transcribed text from large file",
"words": [
{"word": "hello", "start": 0.0, "end": 0.5},
{"word": "world", "start": 0.5, "end": 1.0}
]
}
```
**Supported file types:** mp3, mp4, mpeg, mpga, m4a, wav, webm
#### Whisper Transcriber (`reflector_transcriber.py`)
`POST /transcribe`
**request** (multipart/form-data)
- `file` - audio file
- `language` - language code (e.g. `en`)
**response**
```
{
"text": "xxx",
"words": [
{"text": "xxx", "start": 0.0, "end": 1.0}
]
}
```

View File

@@ -1,253 +0,0 @@
"""
Reflector GPU backend - diarizer
===================================
"""
import os
import uuid
from typing import Mapping, NewType
from urllib.parse import urlparse
import modal
PYANNOTE_MODEL_NAME: str = "pyannote/speaker-diarization-3.1"
MODEL_DIR = "/root/diarization_models"
UPLOADS_PATH = "/uploads"
SUPPORTED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
DiarizerUniqFilename = NewType("DiarizerUniqFilename", str)
AudioFileExtension = NewType("AudioFileExtension", str)
app = modal.App(name="reflector-diarizer")
# Volume for temporary file uploads
upload_volume = modal.Volume.from_name("diarizer-uploads", create_if_missing=True)
def detect_audio_format(url: str, headers: Mapping[str, str]) -> AudioFileExtension:
parsed_url = urlparse(url)
url_path = parsed_url.path
for ext in SUPPORTED_FILE_EXTENSIONS:
if url_path.lower().endswith(f".{ext}"):
return AudioFileExtension(ext)
content_type = headers.get("content-type", "").lower()
if "audio/mpeg" in content_type or "audio/mp3" in content_type:
return AudioFileExtension("mp3")
if "audio/wav" in content_type:
return AudioFileExtension("wav")
if "audio/mp4" in content_type:
return AudioFileExtension("mp4")
raise ValueError(
f"Unsupported audio format for URL: {url}. "
f"Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
)
def download_audio_to_volume(
audio_file_url: str,
) -> tuple[DiarizerUniqFilename, AudioFileExtension]:
import requests
from fastapi import HTTPException
print(f"Checking audio file at: {audio_file_url}")
response = requests.head(audio_file_url, allow_redirects=True)
if response.status_code == 404:
raise HTTPException(status_code=404, detail="Audio file not found")
print(f"Downloading audio file from: {audio_file_url}")
response = requests.get(audio_file_url, allow_redirects=True)
if response.status_code != 200:
print(f"Download failed with status {response.status_code}: {response.text}")
raise HTTPException(
status_code=response.status_code,
detail=f"Failed to download audio file: {response.status_code}",
)
audio_suffix = detect_audio_format(audio_file_url, response.headers)
unique_filename = DiarizerUniqFilename(f"{uuid.uuid4()}.{audio_suffix}")
file_path = f"{UPLOADS_PATH}/{unique_filename}"
print(f"Writing file to: {file_path} (size: {len(response.content)} bytes)")
with open(file_path, "wb") as f:
f.write(response.content)
upload_volume.commit()
print(f"File saved as: {unique_filename}")
return unique_filename, audio_suffix
def migrate_cache_llm():
"""
XXX The cache for model files in Transformers v4.22.0 has been updated.
Migrating your old cache. This is a one-time only operation. You can
interrupt this and resume the migration later on by calling
`transformers.utils.move_cache()`.
"""
from transformers.utils.hub import move_cache
print("Moving LLM cache")
move_cache(cache_dir=MODEL_DIR, new_cache_dir=MODEL_DIR)
print("LLM cache moved")
def download_pyannote_audio():
from pyannote.audio import Pipeline
Pipeline.from_pretrained(
PYANNOTE_MODEL_NAME,
cache_dir=MODEL_DIR,
use_auth_token=os.environ["HF_TOKEN"],
)
diarizer_image = (
modal.Image.debian_slim(python_version="3.10.8")
.pip_install(
"pyannote.audio==3.1.0",
"requests",
"onnx",
"torchaudio",
"onnxruntime-gpu",
"torch==2.0.0",
"transformers==4.34.0",
"sentencepiece",
"protobuf",
"numpy",
"huggingface_hub",
"hf-transfer",
)
.run_function(
download_pyannote_audio,
secrets=[modal.Secret.from_name("hf_token")],
)
.run_function(migrate_cache_llm)
.env(
{
"LD_LIBRARY_PATH": (
"/usr/local/lib/python3.10/site-packages/nvidia/cudnn/lib/:"
"/opt/conda/lib/python3.10/site-packages/nvidia/cublas/lib/"
)
}
)
)
@app.cls(
gpu="A100",
timeout=60 * 30,
image=diarizer_image,
volumes={UPLOADS_PATH: upload_volume},
enable_memory_snapshot=True,
experimental_options={"enable_gpu_snapshot": True},
secrets=[
modal.Secret.from_name("hf_token"),
],
)
@modal.concurrent(max_inputs=1)
class Diarizer:
@modal.enter(snap=True)
def enter(self):
import torch
from pyannote.audio import Pipeline
self.use_gpu = torch.cuda.is_available()
self.device = "cuda" if self.use_gpu else "cpu"
print(f"Using device: {self.device}")
self.diarization_pipeline = Pipeline.from_pretrained(
PYANNOTE_MODEL_NAME,
cache_dir=MODEL_DIR,
use_auth_token=os.environ["HF_TOKEN"],
)
self.diarization_pipeline.to(torch.device(self.device))
@modal.method()
def diarize(self, filename: str, timestamp: float = 0.0):
import torchaudio
upload_volume.reload()
file_path = f"{UPLOADS_PATH}/{filename}"
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
print(f"Diarizing audio from: {file_path}")
waveform, sample_rate = torchaudio.load(file_path)
diarization = self.diarization_pipeline(
{"waveform": waveform, "sample_rate": sample_rate}
)
words = []
for diarization_segment, _, speaker in diarization.itertracks(yield_label=True):
words.append(
{
"start": round(timestamp + diarization_segment.start, 3),
"end": round(timestamp + diarization_segment.end, 3),
"speaker": int(speaker[-2:]),
}
)
print("Diarization complete")
return {"diarization": words}
# -------------------------------------------------------------------
# Web API
# -------------------------------------------------------------------
@app.function(
timeout=60 * 10,
scaledown_window=60 * 3,
secrets=[
modal.Secret.from_name("reflector-gpu"),
],
volumes={UPLOADS_PATH: upload_volume},
image=diarizer_image,
)
@modal.concurrent(max_inputs=40)
@modal.asgi_app()
def web():
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel
diarizerstub = Diarizer()
app = FastAPI()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
headers={"WWW-Authenticate": "Bearer"},
)
class DiarizationResponse(BaseModel):
result: dict
@app.post("/diarize", dependencies=[Depends(apikey_auth)])
def diarize(audio_file_url: str, timestamp: float = 0.0) -> DiarizationResponse:
unique_filename, audio_suffix = download_audio_to_volume(audio_file_url)
try:
func = diarizerstub.diarize.spawn(
filename=unique_filename, timestamp=timestamp
)
result = func.get()
return result
finally:
try:
file_path = f"{UPLOADS_PATH}/{unique_filename}"
print(f"Deleting file: {file_path}")
os.remove(file_path)
upload_volume.commit()
except Exception as e:
print(f"Error cleaning up {unique_filename}: {e}")
return app

View File

@@ -1,608 +0,0 @@
import os
import sys
import threading
import uuid
from typing import Generator, Mapping, NamedTuple, NewType, TypedDict
from urllib.parse import urlparse
import modal
MODEL_NAME = "large-v2"
MODEL_COMPUTE_TYPE: str = "float16"
MODEL_NUM_WORKERS: int = 1
MINUTES = 60 # seconds
SAMPLERATE = 16000
UPLOADS_PATH = "/uploads"
CACHE_PATH = "/models"
SUPPORTED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
VAD_CONFIG = {
"batch_max_duration": 30.0,
"silence_padding": 0.5,
"window_size": 512,
}
WhisperUniqFilename = NewType("WhisperUniqFilename", str)
AudioFileExtension = NewType("AudioFileExtension", str)
app = modal.App("reflector-transcriber")
model_cache = modal.Volume.from_name("models", create_if_missing=True)
upload_volume = modal.Volume.from_name("whisper-uploads", create_if_missing=True)
class TimeSegment(NamedTuple):
"""Represents a time segment with start and end times."""
start: float
end: float
class AudioSegment(NamedTuple):
"""Represents an audio segment with timing and audio data."""
start: float
end: float
audio: any
class TranscriptResult(NamedTuple):
"""Represents a transcription result with text and word timings."""
text: str
words: list["WordTiming"]
class WordTiming(TypedDict):
"""Represents a word with its timing information."""
word: str
start: float
end: float
def download_model():
from faster_whisper import download_model
model_cache.reload()
download_model(MODEL_NAME, cache_dir=CACHE_PATH)
model_cache.commit()
image = (
modal.Image.debian_slim(python_version="3.12")
.env(
{
"HF_HUB_ENABLE_HF_TRANSFER": "1",
"LD_LIBRARY_PATH": (
"/usr/local/lib/python3.12/site-packages/nvidia/cudnn/lib/:"
"/opt/conda/lib/python3.12/site-packages/nvidia/cublas/lib/"
),
}
)
.apt_install("ffmpeg")
.pip_install(
"huggingface_hub==0.27.1",
"hf-transfer==0.1.9",
"torch==2.5.1",
"faster-whisper==1.1.1",
"fastapi==0.115.12",
"requests",
"librosa==0.10.1",
"numpy<2",
"silero-vad==5.1.0",
)
.run_function(download_model, volumes={CACHE_PATH: model_cache})
)
def detect_audio_format(url: str, headers: Mapping[str, str]) -> AudioFileExtension:
parsed_url = urlparse(url)
url_path = parsed_url.path
for ext in SUPPORTED_FILE_EXTENSIONS:
if url_path.lower().endswith(f".{ext}"):
return AudioFileExtension(ext)
content_type = headers.get("content-type", "").lower()
if "audio/mpeg" in content_type or "audio/mp3" in content_type:
return AudioFileExtension("mp3")
if "audio/wav" in content_type:
return AudioFileExtension("wav")
if "audio/mp4" in content_type:
return AudioFileExtension("mp4")
raise ValueError(
f"Unsupported audio format for URL: {url}. "
f"Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
)
def download_audio_to_volume(
audio_file_url: str,
) -> tuple[WhisperUniqFilename, AudioFileExtension]:
import requests
from fastapi import HTTPException
response = requests.head(audio_file_url, allow_redirects=True)
if response.status_code == 404:
raise HTTPException(status_code=404, detail="Audio file not found")
response = requests.get(audio_file_url, allow_redirects=True)
response.raise_for_status()
audio_suffix = detect_audio_format(audio_file_url, response.headers)
unique_filename = WhisperUniqFilename(f"{uuid.uuid4()}.{audio_suffix}")
file_path = f"{UPLOADS_PATH}/{unique_filename}"
with open(file_path, "wb") as f:
f.write(response.content)
upload_volume.commit()
return unique_filename, audio_suffix
def pad_audio(audio_array, sample_rate: int = SAMPLERATE):
"""Add 0.5s of silence if audio is shorter than the silence_padding window.
Whisper does not require this strictly, but aligning behavior with Parakeet
avoids edge-case crashes on extremely short inputs and makes comparisons easier.
"""
import numpy as np
audio_duration = len(audio_array) / sample_rate
if audio_duration < VAD_CONFIG["silence_padding"]:
silence_samples = int(sample_rate * VAD_CONFIG["silence_padding"])
silence = np.zeros(silence_samples, dtype=np.float32)
return np.concatenate([audio_array, silence])
return audio_array
@app.cls(
gpu="A10G",
timeout=5 * MINUTES,
scaledown_window=5 * MINUTES,
image=image,
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
)
@modal.concurrent(max_inputs=10)
class TranscriberWhisperLive:
"""Live transcriber class for small audio segments (A10G).
Mirrors the Parakeet live class API but uses Faster-Whisper under the hood.
"""
@modal.enter()
def enter(self):
import faster_whisper
import torch
self.lock = threading.Lock()
self.use_gpu = torch.cuda.is_available()
self.device = "cuda" if self.use_gpu else "cpu"
self.model = faster_whisper.WhisperModel(
MODEL_NAME,
device=self.device,
compute_type=MODEL_COMPUTE_TYPE,
num_workers=MODEL_NUM_WORKERS,
download_root=CACHE_PATH,
local_files_only=True,
)
print(f"Model is on device: {self.device}")
@modal.method()
def transcribe_segment(
self,
filename: str,
language: str = "en",
):
"""Transcribe a single uploaded audio file by filename."""
upload_volume.reload()
file_path = f"{UPLOADS_PATH}/{filename}"
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
with self.lock:
with NoStdStreams():
segments, _ = self.model.transcribe(
file_path,
language=language,
beam_size=5,
word_timestamps=True,
vad_filter=True,
vad_parameters={"min_silence_duration_ms": 500},
)
segments = list(segments)
text = "".join(segment.text for segment in segments).strip()
words = [
{
"word": word.word,
"start": round(float(word.start), 2),
"end": round(float(word.end), 2),
}
for segment in segments
for word in segment.words
]
return {"text": text, "words": words}
@modal.method()
def transcribe_batch(
self,
filenames: list[str],
language: str = "en",
):
"""Transcribe multiple uploaded audio files and return per-file results."""
upload_volume.reload()
results = []
for filename in filenames:
file_path = f"{UPLOADS_PATH}/{filename}"
if not os.path.exists(file_path):
raise FileNotFoundError(f"Batch file not found: {file_path}")
with self.lock:
with NoStdStreams():
segments, _ = self.model.transcribe(
file_path,
language=language,
beam_size=5,
word_timestamps=True,
vad_filter=True,
vad_parameters={"min_silence_duration_ms": 500},
)
segments = list(segments)
text = "".join(seg.text for seg in segments).strip()
words = [
{
"word": w.word,
"start": round(float(w.start), 2),
"end": round(float(w.end), 2),
}
for seg in segments
for w in seg.words
]
results.append(
{
"filename": filename,
"text": text,
"words": words,
}
)
return results
@app.cls(
gpu="L40S",
timeout=15 * MINUTES,
image=image,
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
)
class TranscriberWhisperFile:
"""File transcriber for larger/longer audio, using VAD-driven batching (L40S)."""
@modal.enter()
def enter(self):
import faster_whisper
import torch
from silero_vad import load_silero_vad
self.lock = threading.Lock()
self.use_gpu = torch.cuda.is_available()
self.device = "cuda" if self.use_gpu else "cpu"
self.model = faster_whisper.WhisperModel(
MODEL_NAME,
device=self.device,
compute_type=MODEL_COMPUTE_TYPE,
num_workers=MODEL_NUM_WORKERS,
download_root=CACHE_PATH,
local_files_only=True,
)
self.vad_model = load_silero_vad(onnx=False)
@modal.method()
def transcribe_segment(
self, filename: str, timestamp_offset: float = 0.0, language: str = "en"
):
import librosa
import numpy as np
from silero_vad import VADIterator
def vad_segments(
audio_array,
sample_rate: int = SAMPLERATE,
window_size: int = VAD_CONFIG["window_size"],
) -> Generator[TimeSegment, None, None]:
"""Generate speech segments as TimeSegment using Silero VAD."""
iterator = VADIterator(self.vad_model, sampling_rate=sample_rate)
start = None
for i in range(0, len(audio_array), window_size):
chunk = audio_array[i : i + window_size]
if len(chunk) < window_size:
chunk = np.pad(
chunk, (0, window_size - len(chunk)), mode="constant"
)
speech = iterator(chunk)
if not speech:
continue
if "start" in speech:
start = speech["start"]
continue
if "end" in speech and start is not None:
end = speech["end"]
yield TimeSegment(
start / float(SAMPLERATE), end / float(SAMPLERATE)
)
start = None
iterator.reset_states()
upload_volume.reload()
file_path = f"{UPLOADS_PATH}/{filename}"
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
audio_array, _sr = librosa.load(file_path, sr=SAMPLERATE, mono=True)
# Batch segments up to ~30s windows by merging contiguous VAD segments
merged_batches: list[TimeSegment] = []
batch_start = None
batch_end = None
max_duration = VAD_CONFIG["batch_max_duration"]
for segment in vad_segments(audio_array):
seg_start, seg_end = segment.start, segment.end
if batch_start is None:
batch_start, batch_end = seg_start, seg_end
continue
if seg_end - batch_start <= max_duration:
batch_end = seg_end
else:
merged_batches.append(TimeSegment(batch_start, batch_end))
batch_start, batch_end = seg_start, seg_end
if batch_start is not None and batch_end is not None:
merged_batches.append(TimeSegment(batch_start, batch_end))
all_text = []
all_words = []
for segment in merged_batches:
start_time, end_time = segment.start, segment.end
s_idx = int(start_time * SAMPLERATE)
e_idx = int(end_time * SAMPLERATE)
segment = audio_array[s_idx:e_idx]
segment = pad_audio(segment, SAMPLERATE)
with self.lock:
segments, _ = self.model.transcribe(
segment,
language=language,
beam_size=5,
word_timestamps=True,
vad_filter=True,
vad_parameters={"min_silence_duration_ms": 500},
)
segments = list(segments)
text = "".join(seg.text for seg in segments).strip()
words = [
{
"word": w.word,
"start": round(float(w.start) + start_time + timestamp_offset, 2),
"end": round(float(w.end) + start_time + timestamp_offset, 2),
}
for seg in segments
for w in seg.words
]
if text:
all_text.append(text)
all_words.extend(words)
return {"text": " ".join(all_text), "words": all_words}
def detect_audio_format(url: str, headers: dict) -> str:
from urllib.parse import urlparse
from fastapi import HTTPException
url_path = urlparse(url).path
for ext in SUPPORTED_FILE_EXTENSIONS:
if url_path.lower().endswith(f".{ext}"):
return ext
content_type = headers.get("content-type", "").lower()
if "audio/mpeg" in content_type or "audio/mp3" in content_type:
return "mp3"
if "audio/wav" in content_type:
return "wav"
if "audio/mp4" in content_type:
return "mp4"
raise HTTPException(
status_code=400,
detail=(
f"Unsupported audio format for URL. Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
),
)
def download_audio_to_volume(audio_file_url: str) -> tuple[str, str]:
import requests
from fastapi import HTTPException
response = requests.head(audio_file_url, allow_redirects=True)
if response.status_code == 404:
raise HTTPException(status_code=404, detail="Audio file not found")
response = requests.get(audio_file_url, allow_redirects=True)
response.raise_for_status()
audio_suffix = detect_audio_format(audio_file_url, response.headers)
unique_filename = f"{uuid.uuid4()}.{audio_suffix}"
file_path = f"{UPLOADS_PATH}/{unique_filename}"
with open(file_path, "wb") as f:
f.write(response.content)
upload_volume.commit()
return unique_filename, audio_suffix
@app.function(
scaledown_window=60,
timeout=600,
secrets=[
modal.Secret.from_name("reflector-gpu"),
],
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
image=image,
)
@modal.concurrent(max_inputs=40)
@modal.asgi_app()
def web():
from fastapi import (
Body,
Depends,
FastAPI,
Form,
HTTPException,
UploadFile,
status,
)
from fastapi.security import OAuth2PasswordBearer
transcriber_live = TranscriberWhisperLive()
transcriber_file = TranscriberWhisperFile()
app = FastAPI()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
if apikey == os.environ["REFLECTOR_GPU_APIKEY"]:
return
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
headers={"WWW-Authenticate": "Bearer"},
)
class TranscriptResponse(dict):
pass
@app.post("/v1/audio/transcriptions", dependencies=[Depends(apikey_auth)])
def transcribe(
file: UploadFile = None,
files: list[UploadFile] | None = None,
model: str = Form(MODEL_NAME),
language: str = Form("en"),
batch: bool = Form(False),
):
if not file and not files:
raise HTTPException(
status_code=400, detail="Either 'file' or 'files' parameter is required"
)
if batch and not files:
raise HTTPException(
status_code=400, detail="Batch transcription requires 'files'"
)
upload_files = [file] if file else files
uploaded_filenames: list[str] = []
for upload_file in upload_files:
audio_suffix = upload_file.filename.split(".")[-1]
if audio_suffix not in SUPPORTED_FILE_EXTENSIONS:
raise HTTPException(
status_code=400,
detail=(
f"Unsupported audio format. Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
),
)
unique_filename = f"{uuid.uuid4()}.{audio_suffix}"
file_path = f"{UPLOADS_PATH}/{unique_filename}"
with open(file_path, "wb") as f:
content = upload_file.file.read()
f.write(content)
uploaded_filenames.append(unique_filename)
upload_volume.commit()
try:
if batch and len(upload_files) > 1:
func = transcriber_live.transcribe_batch.spawn(
filenames=uploaded_filenames,
language=language,
)
results = func.get()
return {"results": results}
results = []
for filename in uploaded_filenames:
func = transcriber_live.transcribe_segment.spawn(
filename=filename,
language=language,
)
result = func.get()
result["filename"] = filename
results.append(result)
return {"results": results} if len(results) > 1 else results[0]
finally:
for filename in uploaded_filenames:
try:
file_path = f"{UPLOADS_PATH}/{filename}"
os.remove(file_path)
except Exception:
pass
upload_volume.commit()
@app.post("/v1/audio/transcriptions-from-url", dependencies=[Depends(apikey_auth)])
def transcribe_from_url(
audio_file_url: str = Body(
..., description="URL of the audio file to transcribe"
),
model: str = Body(MODEL_NAME),
language: str = Body("en"),
timestamp_offset: float = Body(0.0),
):
unique_filename, _audio_suffix = download_audio_to_volume(audio_file_url)
try:
func = transcriber_file.transcribe_segment.spawn(
filename=unique_filename,
timestamp_offset=timestamp_offset,
language=language,
)
result = func.get()
return result
finally:
try:
file_path = f"{UPLOADS_PATH}/{unique_filename}"
os.remove(file_path)
upload_volume.commit()
except Exception:
pass
return app
class NoStdStreams:
def __init__(self):
self.devnull = open(os.devnull, "w")
def __enter__(self):
self._stdout, self._stderr = sys.stdout, sys.stderr
self._stdout.flush()
self._stderr.flush()
sys.stdout, sys.stderr = self.devnull, self.devnull
def __exit__(self, exc_type, exc_value, traceback):
sys.stdout, sys.stderr = self._stdout, self._stderr
self.devnull.close()

View File

@@ -1,658 +0,0 @@
import logging
import os
import sys
import threading
import uuid
from typing import Generator, Mapping, NamedTuple, NewType, TypedDict
from urllib.parse import urlparse
import modal
MODEL_NAME = "nvidia/parakeet-tdt-0.6b-v2"
SUPPORTED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
SAMPLERATE = 16000
UPLOADS_PATH = "/uploads"
CACHE_PATH = "/cache"
VAD_CONFIG = {
"batch_max_duration": 30.0,
"silence_padding": 0.5,
"window_size": 512,
}
ParakeetUniqFilename = NewType("ParakeetUniqFilename", str)
AudioFileExtension = NewType("AudioFileExtension", str)
class TimeSegment(NamedTuple):
"""Represents a time segment with start and end times."""
start: float
end: float
class AudioSegment(NamedTuple):
"""Represents an audio segment with timing and audio data."""
start: float
end: float
audio: any
class TranscriptResult(NamedTuple):
"""Represents a transcription result with text and word timings."""
text: str
words: list["WordTiming"]
class WordTiming(TypedDict):
"""Represents a word with its timing information."""
word: str
start: float
end: float
app = modal.App("reflector-transcriber-parakeet")
# Volume for caching model weights
model_cache = modal.Volume.from_name("parakeet-model-cache", create_if_missing=True)
# Volume for temporary file uploads
upload_volume = modal.Volume.from_name("parakeet-uploads", create_if_missing=True)
image = (
modal.Image.from_registry(
"nvidia/cuda:12.8.0-cudnn-devel-ubuntu22.04", add_python="3.12"
)
.env(
{
"HF_HUB_ENABLE_HF_TRANSFER": "1",
"HF_HOME": "/cache",
"DEBIAN_FRONTEND": "noninteractive",
"CXX": "g++",
"CC": "g++",
}
)
.apt_install("ffmpeg")
.pip_install(
"hf_transfer==0.1.9",
"huggingface_hub[hf-xet]==0.31.2",
"nemo_toolkit[asr]==2.3.0",
"cuda-python==12.8.0",
"fastapi==0.115.12",
"numpy<2",
"librosa==0.10.1",
"requests",
"silero-vad==5.1.0",
"torch",
)
.entrypoint([]) # silence chatty logs by container on start
)
def detect_audio_format(url: str, headers: Mapping[str, str]) -> AudioFileExtension:
parsed_url = urlparse(url)
url_path = parsed_url.path
for ext in SUPPORTED_FILE_EXTENSIONS:
if url_path.lower().endswith(f".{ext}"):
return AudioFileExtension(ext)
content_type = headers.get("content-type", "").lower()
if "audio/mpeg" in content_type or "audio/mp3" in content_type:
return AudioFileExtension("mp3")
if "audio/wav" in content_type:
return AudioFileExtension("wav")
if "audio/mp4" in content_type:
return AudioFileExtension("mp4")
raise ValueError(
f"Unsupported audio format for URL: {url}. "
f"Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
)
def download_audio_to_volume(
audio_file_url: str,
) -> tuple[ParakeetUniqFilename, AudioFileExtension]:
import requests
from fastapi import HTTPException
response = requests.head(audio_file_url, allow_redirects=True)
if response.status_code == 404:
raise HTTPException(status_code=404, detail="Audio file not found")
response = requests.get(audio_file_url, allow_redirects=True)
response.raise_for_status()
audio_suffix = detect_audio_format(audio_file_url, response.headers)
unique_filename = ParakeetUniqFilename(f"{uuid.uuid4()}.{audio_suffix}")
file_path = f"{UPLOADS_PATH}/{unique_filename}"
with open(file_path, "wb") as f:
f.write(response.content)
upload_volume.commit()
return unique_filename, audio_suffix
def pad_audio(audio_array, sample_rate: int = SAMPLERATE):
"""Add 0.5 seconds of silence if audio is less than 500ms.
This is a workaround for a Parakeet bug where very short audio (<500ms) causes:
ValueError: `char_offsets`: [] and `processed_tokens`: [157, 834, 834, 841]
have to be of the same length
See: https://github.com/NVIDIA/NeMo/issues/8451
"""
import numpy as np
audio_duration = len(audio_array) / sample_rate
if audio_duration < 0.5:
silence_samples = int(sample_rate * 0.5)
silence = np.zeros(silence_samples, dtype=np.float32)
return np.concatenate([audio_array, silence])
return audio_array
@app.cls(
gpu="A10G",
timeout=600,
scaledown_window=300,
image=image,
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
enable_memory_snapshot=True,
experimental_options={"enable_gpu_snapshot": True},
)
@modal.concurrent(max_inputs=10)
class TranscriberParakeetLive:
@modal.enter(snap=True)
def enter(self):
import nemo.collections.asr as nemo_asr
logging.getLogger("nemo_logger").setLevel(logging.CRITICAL)
self.lock = threading.Lock()
self.model = nemo_asr.models.ASRModel.from_pretrained(model_name=MODEL_NAME)
device = next(self.model.parameters()).device
print(f"Model is on device: {device}")
@modal.method()
def transcribe_segment(
self,
filename: str,
):
import librosa
upload_volume.reload()
file_path = f"{UPLOADS_PATH}/{filename}"
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
padded_audio = pad_audio(audio_array, sample_rate)
with self.lock:
with NoStdStreams():
(output,) = self.model.transcribe([padded_audio], timestamps=True)
text = output.text.strip()
words: list[WordTiming] = [
WordTiming(
# XXX the space added here is to match the output of whisper
# whisper add space to each words, while parakeet don't
word=word_info["word"] + " ",
start=round(word_info["start"], 2),
end=round(word_info["end"], 2),
)
for word_info in output.timestamp["word"]
]
return {"text": text, "words": words}
@modal.method()
def transcribe_batch(
self,
filenames: list[str],
):
import librosa
upload_volume.reload()
results = []
audio_arrays = []
# Load all audio files with padding
for filename in filenames:
file_path = f"{UPLOADS_PATH}/{filename}"
if not os.path.exists(file_path):
raise FileNotFoundError(f"Batch file not found: {file_path}")
audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
padded_audio = pad_audio(audio_array, sample_rate)
audio_arrays.append(padded_audio)
with self.lock:
with NoStdStreams():
outputs = self.model.transcribe(audio_arrays, timestamps=True)
# Process results for each file
for i, (filename, output) in enumerate(zip(filenames, outputs)):
text = output.text.strip()
words: list[WordTiming] = [
WordTiming(
word=word_info["word"] + " ",
start=round(word_info["start"], 2),
end=round(word_info["end"], 2),
)
for word_info in output.timestamp["word"]
]
results.append(
{
"filename": filename,
"text": text,
"words": words,
}
)
return results
# L40S class for file transcription (bigger files)
@app.cls(
gpu="L40S",
timeout=900,
image=image,
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
enable_memory_snapshot=True,
experimental_options={"enable_gpu_snapshot": True},
)
class TranscriberParakeetFile:
@modal.enter(snap=True)
def enter(self):
import nemo.collections.asr as nemo_asr
import torch
from silero_vad import load_silero_vad
logging.getLogger("nemo_logger").setLevel(logging.CRITICAL)
self.model = nemo_asr.models.ASRModel.from_pretrained(model_name=MODEL_NAME)
device = next(self.model.parameters()).device
print(f"Model is on device: {device}")
torch.set_num_threads(1)
self.vad_model = load_silero_vad(onnx=False)
print("Silero VAD initialized")
@modal.method()
def transcribe_segment(
self,
filename: str,
timestamp_offset: float = 0.0,
):
import librosa
import numpy as np
from silero_vad import VADIterator
def load_and_convert_audio(file_path):
audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
return audio_array
def vad_segment_generator(
audio_array,
) -> Generator[TimeSegment, None, None]:
"""Generate speech segments using VAD with start/end sample indices"""
vad_iterator = VADIterator(self.vad_model, sampling_rate=SAMPLERATE)
window_size = VAD_CONFIG["window_size"]
start = None
for i in range(0, len(audio_array), window_size):
chunk = audio_array[i : i + window_size]
if len(chunk) < window_size:
chunk = np.pad(
chunk, (0, window_size - len(chunk)), mode="constant"
)
speech_dict = vad_iterator(chunk)
if not speech_dict:
continue
if "start" in speech_dict:
start = speech_dict["start"]
continue
if "end" in speech_dict and start is not None:
end = speech_dict["end"]
start_time = start / float(SAMPLERATE)
end_time = end / float(SAMPLERATE)
yield TimeSegment(start_time, end_time)
start = None
vad_iterator.reset_states()
def batch_speech_segments(
segments: Generator[TimeSegment, None, None], max_duration: int
) -> Generator[TimeSegment, None, None]:
"""
Input segments:
[0-2] [3-5] [6-8] [10-11] [12-15] [17-19] [20-22]
↓ (max_duration=10)
Output batches:
[0-8] [10-19] [20-22]
Note: silences are kept for better transcription, previous implementation was
passing segments separatly, but the output was less accurate.
"""
batch_start_time = None
batch_end_time = None
for segment in segments:
start_time, end_time = segment.start, segment.end
if batch_start_time is None or batch_end_time is None:
batch_start_time = start_time
batch_end_time = end_time
continue
total_duration = end_time - batch_start_time
if total_duration <= max_duration:
batch_end_time = end_time
continue
yield TimeSegment(batch_start_time, batch_end_time)
batch_start_time = start_time
batch_end_time = end_time
if batch_start_time is None or batch_end_time is None:
return
yield TimeSegment(batch_start_time, batch_end_time)
def batch_segment_to_audio_segment(
segments: Generator[TimeSegment, None, None],
audio_array,
) -> Generator[AudioSegment, None, None]:
"""Extract audio segments and apply padding for Parakeet compatibility.
Uses pad_audio to ensure segments are at least 0.5s long, preventing
Parakeet crashes. This padding may cause slight timing overlaps between
segments, which are corrected by enforce_word_timing_constraints.
"""
for segment in segments:
start_time, end_time = segment.start, segment.end
start_sample = int(start_time * SAMPLERATE)
end_sample = int(end_time * SAMPLERATE)
audio_segment = audio_array[start_sample:end_sample]
padded_segment = pad_audio(audio_segment, SAMPLERATE)
yield AudioSegment(start_time, end_time, padded_segment)
def transcribe_batch(model, audio_segments: list) -> list:
with NoStdStreams():
outputs = model.transcribe(audio_segments, timestamps=True)
return outputs
def enforce_word_timing_constraints(
words: list[WordTiming],
) -> list[WordTiming]:
"""Enforce that word end times don't exceed the start time of the next word.
Due to silence padding added in batch_segment_to_audio_segment for better
transcription accuracy, word timings from different segments may overlap.
This function ensures there are no overlaps by adjusting end times.
"""
if len(words) <= 1:
return words
enforced_words = []
for i, word in enumerate(words):
enforced_word = word.copy()
if i < len(words) - 1:
next_start = words[i + 1]["start"]
if enforced_word["end"] > next_start:
enforced_word["end"] = next_start
enforced_words.append(enforced_word)
return enforced_words
def emit_results(
results: list,
segments_info: list[AudioSegment],
) -> Generator[TranscriptResult, None, None]:
"""Yield transcribed text and word timings from model output, adjusting timestamps to absolute positions."""
for i, (output, segment) in enumerate(zip(results, segments_info)):
start_time, end_time = segment.start, segment.end
text = output.text.strip()
words: list[WordTiming] = [
WordTiming(
word=word_info["word"] + " ",
start=round(
word_info["start"] + start_time + timestamp_offset, 2
),
end=round(word_info["end"] + start_time + timestamp_offset, 2),
)
for word_info in output.timestamp["word"]
]
yield TranscriptResult(text, words)
upload_volume.reload()
file_path = f"{UPLOADS_PATH}/{filename}"
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
audio_array = load_and_convert_audio(file_path)
total_duration = len(audio_array) / float(SAMPLERATE)
all_text_parts: list[str] = []
all_words: list[WordTiming] = []
raw_segments = vad_segment_generator(audio_array)
speech_segments = batch_speech_segments(
raw_segments,
VAD_CONFIG["batch_max_duration"],
)
audio_segments = batch_segment_to_audio_segment(speech_segments, audio_array)
for batch in audio_segments:
audio_segment = batch.audio
results = transcribe_batch(self.model, [audio_segment])
for result in emit_results(
results,
[batch],
):
if not result.text:
continue
all_text_parts.append(result.text)
all_words.extend(result.words)
all_words = enforce_word_timing_constraints(all_words)
combined_text = " ".join(all_text_parts)
return {"text": combined_text, "words": all_words}
@app.function(
scaledown_window=60,
timeout=600,
secrets=[
modal.Secret.from_name("reflector-gpu"),
],
volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
image=image,
)
@modal.concurrent(max_inputs=40)
@modal.asgi_app()
def web():
import os
import uuid
from fastapi import (
Body,
Depends,
FastAPI,
Form,
HTTPException,
UploadFile,
status,
)
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel
transcriber_live = TranscriberParakeetLive()
transcriber_file = TranscriberParakeetFile()
app = FastAPI()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
if apikey == os.environ["REFLECTOR_GPU_APIKEY"]:
return
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
headers={"WWW-Authenticate": "Bearer"},
)
class TranscriptResponse(BaseModel):
result: dict
@app.post("/v1/audio/transcriptions", dependencies=[Depends(apikey_auth)])
def transcribe(
file: UploadFile = None,
files: list[UploadFile] | None = None,
model: str = Form(MODEL_NAME),
language: str = Form("en"),
batch: bool = Form(False),
):
# Parakeet only supports English
if language != "en":
raise HTTPException(
status_code=400,
detail=f"Parakeet model only supports English. Got language='{language}'",
)
# Handle both single file and multiple files
if not file and not files:
raise HTTPException(
status_code=400, detail="Either 'file' or 'files' parameter is required"
)
if batch and not files:
raise HTTPException(
status_code=400, detail="Batch transcription requires 'files'"
)
upload_files = [file] if file else files
# Upload files to volume
uploaded_filenames = []
for upload_file in upload_files:
audio_suffix = upload_file.filename.split(".")[-1]
assert audio_suffix in SUPPORTED_FILE_EXTENSIONS
# Generate unique filename
unique_filename = f"{uuid.uuid4()}.{audio_suffix}"
file_path = f"{UPLOADS_PATH}/{unique_filename}"
print(f"Writing file to: {file_path}")
with open(file_path, "wb") as f:
content = upload_file.file.read()
f.write(content)
uploaded_filenames.append(unique_filename)
upload_volume.commit()
try:
# Use A10G live transcriber for per-file transcription
if batch and len(upload_files) > 1:
# Use batch transcription
func = transcriber_live.transcribe_batch.spawn(
filenames=uploaded_filenames,
)
results = func.get()
return {"results": results}
# Per-file transcription
results = []
for filename in uploaded_filenames:
func = transcriber_live.transcribe_segment.spawn(
filename=filename,
)
result = func.get()
result["filename"] = filename
results.append(result)
return {"results": results} if len(results) > 1 else results[0]
finally:
for filename in uploaded_filenames:
try:
file_path = f"{UPLOADS_PATH}/{filename}"
print(f"Deleting file: {file_path}")
os.remove(file_path)
except Exception as e:
print(f"Error deleting {filename}: {e}")
upload_volume.commit()
@app.post("/v1/audio/transcriptions-from-url", dependencies=[Depends(apikey_auth)])
def transcribe_from_url(
audio_file_url: str = Body(
..., description="URL of the audio file to transcribe"
),
model: str = Body(MODEL_NAME),
language: str = Body("en", description="Language code (only 'en' supported)"),
timestamp_offset: float = Body(0.0),
):
# Parakeet only supports English
if language != "en":
raise HTTPException(
status_code=400,
detail=f"Parakeet model only supports English. Got language='{language}'",
)
unique_filename, audio_suffix = download_audio_to_volume(audio_file_url)
try:
func = transcriber_file.transcribe_segment.spawn(
filename=unique_filename,
timestamp_offset=timestamp_offset,
)
result = func.get()
return result
finally:
try:
file_path = f"{UPLOADS_PATH}/{unique_filename}"
print(f"Deleting file: {file_path}")
os.remove(file_path)
upload_volume.commit()
except Exception as e:
print(f"Error cleaning up {unique_filename}: {e}")
return app
class NoStdStreams:
def __init__(self):
self.devnull = open(os.devnull, "w")
def __enter__(self):
self._stdout, self._stderr = sys.stdout, sys.stderr
self._stdout.flush()
self._stderr.flush()
sys.stdout, sys.stderr = self.devnull, self.devnull
def __exit__(self, exc_type, exc_value, traceback):
sys.stdout, sys.stderr = self._stdout, self._stderr
self.devnull.close()

View File

@@ -1,428 +0,0 @@
"""
Reflector GPU backend - transcriber
===================================
"""
import os
import threading
from modal import App, Image, Secret, asgi_app, enter, method
from pydantic import BaseModel
# Seamless M4T
SEAMLESSM4T_MODEL_SIZE: str = "medium"
SEAMLESSM4T_MODEL_CARD_NAME: str = f"seamlessM4T_{SEAMLESSM4T_MODEL_SIZE}"
SEAMLESSM4T_VOCODER_CARD_NAME: str = "vocoder_36langs"
HF_SEAMLESS_M4TEPO: str = f"facebook/seamless-m4t-{SEAMLESSM4T_MODEL_SIZE}"
HF_SEAMLESS_M4T_VOCODEREPO: str = "facebook/seamless-m4t-vocoder"
SEAMLESS_GITEPO: str = "https://github.com/facebookresearch/seamless_communication.git"
SEAMLESS_MODEL_DIR: str = "m4t"
app = App(name="reflector-translator")
def install_seamless_communication():
import os
import subprocess
initial_dir = os.getcwd()
subprocess.run(
["ssh-keyscan", "-t", "rsa", "github.com", ">>", "~/.ssh/known_hosts"]
)
subprocess.run(["rm", "-rf", "seamless_communication"])
subprocess.run(["git", "clone", SEAMLESS_GITEPO, "." + "/seamless_communication"])
os.chdir("seamless_communication")
subprocess.run(["pip", "install", "-e", "."])
os.chdir(initial_dir)
def download_seamlessm4t_model():
from huggingface_hub import snapshot_download
print("Downloading Transcriber model & tokenizer")
snapshot_download(HF_SEAMLESS_M4TEPO, cache_dir=SEAMLESS_MODEL_DIR)
print("Transcriber model & tokenizer downloaded")
print("Downloading vocoder weights")
snapshot_download(HF_SEAMLESS_M4T_VOCODEREPO, cache_dir=SEAMLESS_MODEL_DIR)
print("Vocoder weights downloaded")
def configure_seamless_m4t():
import os
import yaml
CARDS_DIR: str = "./seamless_communication/src/seamless_communication/cards"
with open(f"{CARDS_DIR}/seamlessM4T_{SEAMLESSM4T_MODEL_SIZE}.yaml", "r") as file:
model_yaml_data = yaml.load(file, Loader=yaml.FullLoader)
with open(f"{CARDS_DIR}/vocoder_36langs.yaml", "r") as file:
vocoder_yaml_data = yaml.load(file, Loader=yaml.FullLoader)
with open(f"{CARDS_DIR}/unity_nllb-100.yaml", "r") as file:
unity_100_yaml_data = yaml.load(file, Loader=yaml.FullLoader)
with open(f"{CARDS_DIR}/unity_nllb-200.yaml", "r") as file:
unity_200_yaml_data = yaml.load(file, Loader=yaml.FullLoader)
model_dir = f"{SEAMLESS_MODEL_DIR}/models--facebook--seamless-m4t-{SEAMLESSM4T_MODEL_SIZE}/snapshots"
available_model_versions = os.listdir(model_dir)
latest_model_version = sorted(available_model_versions)[-1]
model_name = f"multitask_unity_{SEAMLESSM4T_MODEL_SIZE}.pt"
model_path = os.path.join(os.getcwd(), model_dir, latest_model_version, model_name)
vocoder_dir = (
f"{SEAMLESS_MODEL_DIR}/models--facebook--seamless-m4t-vocoder/snapshots"
)
available_vocoder_versions = os.listdir(vocoder_dir)
latest_vocoder_version = sorted(available_vocoder_versions)[-1]
vocoder_name = "vocoder_36langs.pt"
vocoder_path = os.path.join(
os.getcwd(), vocoder_dir, latest_vocoder_version, vocoder_name
)
tokenizer_name = "tokenizer.model"
tokenizer_path = os.path.join(
os.getcwd(), model_dir, latest_model_version, tokenizer_name
)
model_yaml_data["checkpoint"] = f"file://{model_path}"
vocoder_yaml_data["checkpoint"] = f"file://{vocoder_path}"
unity_100_yaml_data["tokenizer"] = f"file://{tokenizer_path}"
unity_200_yaml_data["tokenizer"] = f"file://{tokenizer_path}"
with open(f"{CARDS_DIR}/seamlessM4T_{SEAMLESSM4T_MODEL_SIZE}.yaml", "w") as file:
yaml.dump(model_yaml_data, file)
with open(f"{CARDS_DIR}/vocoder_36langs.yaml", "w") as file:
yaml.dump(vocoder_yaml_data, file)
with open(f"{CARDS_DIR}/unity_nllb-100.yaml", "w") as file:
yaml.dump(unity_100_yaml_data, file)
with open(f"{CARDS_DIR}/unity_nllb-200.yaml", "w") as file:
yaml.dump(unity_200_yaml_data, file)
transcriber_image = (
Image.debian_slim(python_version="3.10.8")
.apt_install("git")
.apt_install("wget")
.apt_install("libsndfile-dev")
.pip_install(
"requests",
"torch",
"transformers==4.34.0",
"sentencepiece",
"protobuf",
"huggingface_hub==0.16.4",
"gitpython",
"torchaudio",
"fairseq2",
"pyyaml",
"hf-transfer~=0.1",
)
.run_function(install_seamless_communication)
.run_function(download_seamlessm4t_model)
.run_function(configure_seamless_m4t)
.env(
{
"LD_LIBRARY_PATH": (
"/usr/local/lib/python3.10/site-packages/nvidia/cudnn/lib/:"
"/opt/conda/lib/python3.10/site-packages/nvidia/cublas/lib/"
)
}
)
)
@app.cls(
gpu="A10G",
timeout=60 * 5,
scaledown_window=60 * 5,
allow_concurrent_inputs=4,
image=transcriber_image,
)
class Translator:
@enter()
def enter(self):
import torch
from seamless_communication.inference.translator import Translator
self.lock = threading.Lock()
self.use_gpu = torch.cuda.is_available()
self.device = "cuda" if self.use_gpu else "cpu"
self.translator = Translator(
SEAMLESSM4T_MODEL_CARD_NAME,
SEAMLESSM4T_VOCODER_CARD_NAME,
torch.device(self.device),
dtype=torch.float32,
)
@method()
def warmup(self):
return {"status": "ok"}
def get_seamless_lang_code(self, lang_code: str):
"""
The codes for SeamlessM4T is different from regular standards.
For ex, French is "fra" and not "fr".
"""
# TODO: Enhance with complete list of lang codes
seamless_lang_code = {
# Afrikaans
"af": "afr",
# Amharic
"am": "amh",
# Modern Standard Arabic
"ar": "arb",
# Moroccan Arabic
"ary": "ary",
# Egyptian Arabic
"arz": "arz",
# Assamese
"as": "asm",
# North Azerbaijani
"az": "azj",
# Belarusian
"be": "bel",
# Bengali
"bn": "ben",
# Bosnian
"bs": "bos",
# Bulgarian
"bg": "bul",
# Catalan
"ca": "cat",
# Cebuano
"ceb": "ceb",
# Czech
"cs": "ces",
# Central Kurdish
"ku": "ckb",
# Mandarin Chinese
"cmn": "cmn_Hant",
# Welsh
"cy": "cym",
# Danish
"da": "dan",
# German
"de": "deu",
# Greek
"el": "ell",
# English
"en": "eng",
# Estonian
"et": "est",
# Basque
"eu": "eus",
# Finnish
"fi": "fin",
# French
"fr": "fra",
# Irish
"ga": "gle",
# West Central Oromo,
"gaz": "gaz",
# Galician
"gl": "glg",
# Gujarati
"gu": "guj",
# Hebrew
"he": "heb",
# Hindi
"hi": "hin",
# Croatian
"hr": "hrv",
# Hungarian
"hu": "hun",
# Armenian
"hy": "hye",
# Igbo
"ig": "ibo",
# Indonesian
"id": "ind",
# Icelandic
"is": "isl",
# Italian
"it": "ita",
# Javanese
"jv": "jav",
# Japanese
"ja": "jpn",
# Kannada
"kn": "kan",
# Georgian
"ka": "kat",
# Kazakh
"kk": "kaz",
# Halh Mongolian
"khk": "khk",
# Khmer
"km": "khm",
# Kyrgyz
"ky": "kir",
# Korean
"ko": "kor",
# Lao
"lo": "lao",
# Lithuanian
"lt": "lit",
# Ganda
"lg": "lug",
# Luo
"luo": "luo",
# Standard Latvian
"lv": "lvs",
# Maithili
"mai": "mai",
# Malayalam
"ml": "mal",
# Marathi
"mr": "mar",
# Macedonian
"mk": "mkd",
# Maltese
"mt": "mlt",
# Meitei
"mni": "mni",
# Burmese
"my": "mya",
# Dutch
"nl": "nld",
# Norwegian Nynorsk
"nn": "nno",
# Norwegian Bokmål
"nb": "nob",
# Nepali
"ne": "npi",
# Nyanja
"ny": "nya",
# Odia
"or": "ory",
# Punjabi
"pa": "pan",
# Southern Pashto
"pbt": "pbt",
# Western Persian
"pes": "pes",
# Polish
"pl": "pol",
# Portuguese
"pt": "por",
# Romanian
"ro": "ron",
# Russian
"ru": "rus",
# Slovak
"sk": "slk",
# Slovenian
"sl": "slv",
# Shona
"sn": "sna",
# Sindhi
"sd": "snd",
# Somali
"so": "som",
# Spanish
"es": "spa",
# Serbian
"sr": "srp",
# Swedish
"sv": "swe",
# Swahili
"sw": "swh",
# Tamil
"ta": "tam",
# Telugu
"te": "tel",
# Tajik
"tg": "tgk",
# Tagalog
"tl": "tgl",
# Thai
"th": "tha",
# Turkish
"tr": "tur",
# Ukrainian
"uk": "ukr",
# Urdu
"ur": "urd",
# Northern Uzbek
"uz": "uzn",
# Vietnamese
"vi": "vie",
# Yoruba
"yo": "yor",
# Cantonese
"yue": "yue",
# Standard Malay
"ms": "zsm",
# Zulu
"zu": "zul",
}
return seamless_lang_code.get(lang_code, "eng")
@method()
def translate_text(self, text: str, source_language: str, target_language: str):
with self.lock:
translation_result, _ = self.translator.predict(
text,
"t2tt",
src_lang=self.get_seamless_lang_code(source_language),
tgt_lang=self.get_seamless_lang_code(target_language),
unit_generation_ngram_filtering=True,
)
translated_text = str(translation_result[0])
return {"text": {source_language: text, target_language: translated_text}}
# -------------------------------------------------------------------
# Web API
# -------------------------------------------------------------------
@app.function(
scaledown_window=60,
timeout=60,
allow_concurrent_inputs=40,
secrets=[
Secret.from_name("reflector-gpu"),
],
)
@asgi_app()
def web():
from fastapi import Body, Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from typing_extensions import Annotated
translatorstub = Translator()
app = FastAPI()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
headers={"WWW-Authenticate": "Bearer"},
)
class TranslateResponse(BaseModel):
result: dict
@app.post("/translate", dependencies=[Depends(apikey_auth)])
async def translate(
text: str,
source_language: Annotated[str, Body(...)] = "en",
target_language: Annotated[str, Body(...)] = "fr",
) -> TranslateResponse:
func = translatorstub.translate_text.spawn(
text=text,
source_language=source_language,
target_language=target_language,
)
result = func.get()
return result
return app

View File

@@ -118,7 +118,7 @@ addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v"
testpaths = ["tests"]
asyncio_mode = "auto"
markers = [
"gpu_modal: mark test to run only with GPU Modal endpoints (deselect with '-m \"not gpu_modal\"')",
"model_api: tests for the unified model-serving HTTP API (backend- and hardware-agnostic)",
]
[tool.ruff.lint]
@@ -130,7 +130,7 @@ select = [
[tool.ruff.lint.per-file-ignores]
"reflector/processors/summary/summary_builder.py" = ["E501"]
"gpu/**.py" = ["PLC0415"]
"gpu/modal_deployments/**.py" = ["PLC0415"]
"reflector/tools/**.py" = ["PLC0415"]
"migrations/versions/**.py" = ["PLC0415"]
"tests/**.py" = ["PLC0415"]

View File

@@ -0,0 +1,63 @@
"""
Tests for diarization Model API endpoint (self-hosted service compatible shape).
Marked with the "model_api" marker and skipped unless DIARIZATION_URL is provided.
Run with for local self-hosted server:
DIARIZATION_API_KEY=dev-key \
DIARIZATION_URL=http://localhost:8000 \
uv run -m pytest -m model_api --no-cov tests/test_model_api_diarization.py
"""
import os
import httpx
import pytest
# Public test audio file hosted on S3 specifically for reflector pytests
TEST_AUDIO_URL = (
"https://reflector-github-pytest.s3.us-east-1.amazonaws.com/test_mathieu_hello.mp3"
)
def get_modal_diarization_url():
url = os.environ.get("DIARIZATION_URL")
if not url:
pytest.skip(
"DIARIZATION_URL environment variable is required for Model API tests"
)
return url
def get_auth_headers():
api_key = os.environ.get("DIARIZATION_API_KEY") or os.environ.get(
"REFLECTOR_GPU_APIKEY"
)
return {"Authorization": f"Bearer {api_key}"} if api_key else {}
@pytest.mark.model_api
class TestModelAPIDiarization:
def test_diarize_from_url(self):
url = get_modal_diarization_url()
headers = get_auth_headers()
with httpx.Client(timeout=60.0) as client:
response = client.post(
f"{url}/diarize",
params={"audio_file_url": TEST_AUDIO_URL, "timestamp": 0.0},
headers=headers,
)
assert response.status_code == 200, f"Request failed: {response.text}"
result = response.json()
assert "diarization" in result
assert isinstance(result["diarization"], list)
assert len(result["diarization"]) > 0
for seg in result["diarization"]:
assert "start" in seg and "end" in seg and "speaker" in seg
assert isinstance(seg["start"], (int, float))
assert isinstance(seg["end"], (int, float))
assert seg["start"] <= seg["end"]

View File

@@ -1,21 +1,21 @@
"""
Tests for GPU Modal transcription endpoints.
Tests for transcription Model API endpoints.
These tests are marked with the "gpu-modal" group and will not run by default.
Run them with: pytest -m gpu-modal tests/test_gpu_modal_transcript_parakeet.py
These tests are marked with the "model_api" group and will not run by default.
Run them with: pytest -m model_api tests/test_model_api_transcript.py
Required environment variables:
- TRANSCRIPT_URL: URL to the Modal.com endpoint (required)
- TRANSCRIPT_MODAL_API_KEY: API key for authentication (optional)
- TRANSCRIPT_URL: URL to the Model API endpoint (required)
- TRANSCRIPT_API_KEY: API key for authentication (optional)
- TRANSCRIPT_MODEL: Model name to use (optional, defaults to nvidia/parakeet-tdt-0.6b-v2)
Example with pytest (override default addopts to run ONLY gpu_modal tests):
Example with pytest (override default addopts to run ONLY model_api tests):
TRANSCRIPT_URL=https://monadical-sas--reflector-transcriber-parakeet-web-dev.modal.run \
TRANSCRIPT_MODAL_API_KEY=your-api-key \
uv run -m pytest -m gpu_modal --no-cov tests/test_gpu_modal_transcript.py
TRANSCRIPT_API_KEY=your-api-key \
uv run -m pytest -m model_api --no-cov tests/test_model_api_transcript.py
# Or with completely clean options:
uv run -m pytest -m gpu_modal -o addopts="" tests/
uv run -m pytest -m model_api -o addopts="" tests/
Running Modal locally for testing:
modal serve gpu/modal_deployments/reflector_transcriber_parakeet.py
@@ -40,14 +40,16 @@ def get_modal_transcript_url():
url = os.environ.get("TRANSCRIPT_URL")
if not url:
pytest.skip(
"TRANSCRIPT_URL environment variable is required for GPU Modal tests"
"TRANSCRIPT_URL environment variable is required for Model API tests"
)
return url
def get_auth_headers():
"""Get authentication headers if API key is available."""
api_key = os.environ.get("TRANSCRIPT_MODAL_API_KEY")
api_key = os.environ.get("TRANSCRIPT_API_KEY") or os.environ.get(
"REFLECTOR_GPU_APIKEY"
)
if api_key:
return {"Authorization": f"Bearer {api_key}"}
return {}
@@ -58,8 +60,8 @@ def get_model_name():
return os.environ.get("TRANSCRIPT_MODEL", "nvidia/parakeet-tdt-0.6b-v2")
@pytest.mark.gpu_modal
class TestGPUModalTranscript:
@pytest.mark.model_api
class TestModelAPITranscript:
"""Test suite for GPU Modal transcription endpoints."""
def test_transcriptions_from_url(self):

View File

@@ -0,0 +1,56 @@
"""
Tests for translation Model API endpoint (self-hosted service compatible shape).
Marked with the "model_api" marker and skipped unless TRANSLATION_URL is provided
or we fallback to TRANSCRIPT_URL base (same host for self-hosted).
Run locally against self-hosted server:
TRANSLATION_API_KEY=dev-key \
TRANSLATION_URL=http://localhost:8000 \
uv run -m pytest -m model_api --no-cov tests/test_model_api_translation.py
"""
import os
import httpx
import pytest
def get_translation_url():
url = os.environ.get("TRANSLATION_URL") or os.environ.get("TRANSCRIPT_URL")
if not url:
pytest.skip(
"TRANSLATION_URL or TRANSCRIPT_URL environment variable is required for Model API tests"
)
return url
def get_auth_headers():
api_key = os.environ.get("TRANSLATION_API_KEY") or os.environ.get(
"REFLECTOR_GPU_APIKEY"
)
return {"Authorization": f"Bearer {api_key}"} if api_key else {}
@pytest.mark.model_api
class TestModelAPITranslation:
def test_translate_text(self):
url = get_translation_url()
headers = get_auth_headers()
with httpx.Client(timeout=60.0) as client:
response = client.post(
f"{url}/translate",
params={"text": "The meeting will start in five minutes."},
json={"source_language": "en", "target_language": "fr"},
headers=headers,
)
assert response.status_code == 200, f"Request failed: {response.text}"
data = response.json()
assert "text" in data and isinstance(data["text"], dict)
assert data["text"].get("en") == "The meeting will start in five minutes."
assert isinstance(data["text"].get("fr", ""), str)
assert len(data["text"]["fr"]) > 0
assert data["text"]["fr"] == "La réunion commencera dans cinq minutes."