mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
* Document the transcriber API
* Update the Whisper transcriber API to match Parakeet
* Update the API transcription spec
* Return 400 for unsupported file types
* Add params to the API spec
* Update the Whisper transcriber implementation to match Parakeet
609 lines · 19 KiB · Python
import os
import sys
import threading
import uuid
from typing import Any, Generator, Mapping, NamedTuple, NewType, TypedDict
from urllib.parse import urlparse

import modal
|
|
|
|
# Faster-Whisper checkpoint and CTranslate2 runtime settings.
MODEL_NAME = "large-v2"
MODEL_COMPUTE_TYPE: str = "float16"
MODEL_NUM_WORKERS: int = 1

# Modal timeouts are given in seconds; this lets them read as "5 * MINUTES".
MINUTES = 60

# Rate (Hz) that audio is resampled to before VAD and transcription.
SAMPLERATE = 16000

# Mount points of the Modal volumes used by the transcriber classes.
UPLOADS_PATH = "/uploads"
CACHE_PATH = "/models"

# Accepted audio file extensions (no leading dot). The list is also joined
# into the user-facing "unsupported format" error messages.
SUPPORTED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]

# Voice-activity-detection knobs used by the file transcriber.
VAD_CONFIG = dict(
    batch_max_duration=30.0,  # seconds: upper bound on a merged speech batch
    silence_padding=0.5,  # seconds: clips shorter than this get padded
    window_size=512,  # samples fed to Silero VAD per step
)
|
|
|
|
|
|
# Distinct static types over plain ``str`` so a type checker can tell a bare
# extension apart from a generated upload filename.
AudioFileExtension = NewType("AudioFileExtension", str)  # e.g. "mp3" (no dot)
WhisperUniqFilename = NewType("WhisperUniqFilename", str)  # "<uuid4>.<ext>"
|
|
|
|
# The Modal application that every class and function below registers with.
app = modal.App("reflector-transcriber")

# Named persistent volumes, created on first use when missing:
#   model_cache   <- "models"          (Whisper weights, mounted at CACHE_PATH)
#   upload_volume <- "whisper-uploads" (temporary audio, mounted at UPLOADS_PATH)
model_cache, upload_volume = (
    modal.Volume.from_name(volume_name, create_if_missing=True)
    for volume_name in ("models", "whisper-uploads")
)
|
|
|
|
|
|
class TimeSegment(NamedTuple):
    """A span of the audio timeline, expressed in seconds."""

    start: float  # where the span begins
    end: float  # where the span stops
|
|
|
|
|
|
class AudioSegment(NamedTuple):
    """Represents an audio segment with timing and audio data."""

    start: float
    end: float
    # Raw samples for the span. Annotated with ``typing.Any`` because the
    # container type is not fixed here; the original annotation used the
    # builtin *function* ``any``, which is not a valid type.
    audio: Any
|
|
|
|
|
|
class TranscriptResult(NamedTuple):
    """Transcript text together with per-word timing records."""

    text: str  # the full transcript
    words: list["WordTiming"]  # one timing entry per recognized word
|
|
|
|
|
|
class WordTiming(TypedDict):
    """Timing record for a single word of a transcript."""

    word: str  # the word text
    start: float  # time the word begins
    end: float  # time the word ends
|
|
|
|
|
|
def download_model():
    """Fetch the ``MODEL_NAME`` Whisper weights into the model-cache volume.

    The image build runs this via ``run_function`` with the volume mounted at
    ``CACHE_PATH``. The volume is reloaded before downloading and committed
    afterwards.
    """
    from faster_whisper import download_model as fetch_whisper_weights

    model_cache.reload()
    fetch_whisper_weights(MODEL_NAME, cache_dir=CACHE_PATH)
    model_cache.commit()
|
|
|
|
|
|
# Container image shared by every Modal class/function in this module.
#
# On ``debian_slim`` the interpreter lives under /usr/local, so the NVIDIA
# wheels installed by pip (pulled in by torch) land in
# /usr/local/lib/python3.12/site-packages/nvidia/*. Both cuDNN and cuBLAS are
# therefore added from there. This image has no conda prefix, so an
# /opt/conda/... path would never exist.
image = (
    modal.Image.debian_slim(python_version="3.12")
    .env(
        {
            "HF_HUB_ENABLE_HF_TRANSFER": "1",
            "LD_LIBRARY_PATH": (
                "/usr/local/lib/python3.12/site-packages/nvidia/cudnn/lib/:"
                "/usr/local/lib/python3.12/site-packages/nvidia/cublas/lib/"
            ),
        }
    )
    .apt_install("ffmpeg")
    .pip_install(
        "huggingface_hub==0.27.1",
        "hf-transfer==0.1.9",
        "torch==2.5.1",
        "faster-whisper==1.1.1",
        "fastapi==0.115.12",
        "requests",
        "librosa==0.10.1",
        "numpy<2",
        "silero-vad==5.1.0",
    )
    # Bake the model weights into the cache volume at build time.
    .run_function(download_model, volumes={CACHE_PATH: model_cache})
)
|
|
|
|
|
|
# ``detect_audio_format`` is defined once, further down next to the ``web``
# endpoint. A second, earlier ``def`` of the same name here would be silently
# shadowed by that later definition (Python binds the last one), so only the
# later, HTTP-400-raising version is kept.
|
|
|
|
|
|
# ``download_audio_to_volume`` is defined once, further down next to the
# ``web`` endpoint. An earlier duplicate ``def`` here would be silently
# shadowed by that later definition, so only one copy is kept.
|
|
|
|
|
|
def pad_audio(audio_array, sample_rate: int = SAMPLERATE):
    """Return ``audio_array``, extended with trailing silence if it is very short.

    When the clip lasts less than ``VAD_CONFIG["silence_padding"]`` seconds,
    that many seconds of float32 zeros are appended. Longer clips are returned
    unchanged (the same object, not a copy).

    Whisper does not strictly need this, but mirroring the Parakeet
    transcriber avoids edge-case crashes on extremely short inputs and keeps
    the two backends easy to compare.
    """
    import numpy as np

    padding_seconds = VAD_CONFIG["silence_padding"]
    if len(audio_array) / sample_rate >= padding_seconds:
        return audio_array

    trailing_silence = np.zeros(int(sample_rate * padding_seconds), dtype=np.float32)
    return np.concatenate([audio_array, trailing_silence])
|
|
|
|
|
|
@app.cls(
    gpu="A10G",
    timeout=5 * MINUTES,
    scaledown_window=5 * MINUTES,
    image=image,
    volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
)
@modal.concurrent(max_inputs=10)
class TranscriberWhisperLive:
    """Live transcriber class for small audio segments (A10G).

    Mirrors the Parakeet live class API but uses Faster-Whisper under the hood.
    A container accepts up to 10 concurrent inputs, and ``self.lock``
    serializes their use of the single ``WhisperModel``.
    """

    @modal.enter()
    def enter(self):
        """Load the Whisper model from the cache volume once per container."""
        import faster_whisper
        import torch

        self.lock = threading.Lock()
        self.use_gpu = torch.cuda.is_available()
        self.device = "cuda" if self.use_gpu else "cpu"
        self.model = faster_whisper.WhisperModel(
            MODEL_NAME,
            device=self.device,
            compute_type=MODEL_COMPUTE_TYPE,
            num_workers=MODEL_NUM_WORKERS,
            download_root=CACHE_PATH,
            local_files_only=True,
        )
        print(f"Model is on device: {self.device}")

    def _transcribe_path(self, file_path: str, language: str) -> dict:
        """Transcribe one audio file and return ``{"text": ..., "words": [...]}``.

        ``WhisperModel.transcribe`` returns a *lazy* generator, and decoding
        only happens while it is iterated. The generator is therefore consumed
        inside the lock (with output suppressed). Otherwise the lock would only
        cover the setup call and concurrent inputs would decode simultaneously.
        """
        with self.lock, NoStdStreams():
            segments, _ = self.model.transcribe(
                file_path,
                language=language,
                beam_size=5,
                word_timestamps=True,
                vad_filter=True,
                vad_parameters={"min_silence_duration_ms": 500},
            )
            segments = list(segments)

        text = "".join(segment.text for segment in segments).strip()
        words = [
            {
                "word": word.word,
                "start": round(float(word.start), 2),
                "end": round(float(word.end), 2),
            }
            for segment in segments
            # ``words`` is populated because word_timestamps=True; guard anyway.
            for word in (segment.words or [])
        ]
        return {"text": text, "words": words}

    @modal.method()
    def transcribe_segment(
        self,
        filename: str,
        language: str = "en",
    ):
        """Transcribe a single uploaded audio file by filename.

        Args:
            filename: Name of the file inside ``UPLOADS_PATH``.
            language: Language code forwarded to Whisper.

        Returns:
            ``{"text": str, "words": [{"word", "start", "end"}, ...]}`` with
            times rounded to two decimals.

        Raises:
            FileNotFoundError: if the file is not on the uploads volume.
        """
        upload_volume.reload()

        file_path = f"{UPLOADS_PATH}/{filename}"
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        return self._transcribe_path(file_path, language)

    @modal.method()
    def transcribe_batch(
        self,
        filenames: list[str],
        language: str = "en",
    ):
        """Transcribe multiple uploaded audio files and return per-file results.

        Files are processed in order. Each result is the
        ``transcribe_segment`` dict plus a ``"filename"`` key.

        Raises:
            FileNotFoundError: at the first file missing from the uploads volume.
        """
        upload_volume.reload()

        results = []
        for filename in filenames:
            file_path = f"{UPLOADS_PATH}/{filename}"
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"Batch file not found: {file_path}")

            transcript = self._transcribe_path(file_path, language)
            results.append(
                {
                    "filename": filename,
                    "text": transcript["text"],
                    "words": transcript["words"],
                }
            )

        return results
|
|
|
|
|
|
@app.cls(
    gpu="L40S",
    timeout=15 * MINUTES,
    image=image,
    volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
)
class TranscriberWhisperFile:
    """File transcriber for larger/longer audio, using VAD-driven batching (L40S).

    Pipeline:
    1. Resample the file to ``SAMPLERATE``.
    2. Find speech regions with Silero VAD.
    3. Merge regions into batches of at most ``VAD_CONFIG["batch_max_duration"]``
       seconds.
    4. Transcribe each batch and shift its word timestamps back onto the
       file's timeline.
    """

    @modal.enter()
    def enter(self):
        """Load the Whisper model and the Silero VAD model once per container."""
        import faster_whisper
        import torch
        from silero_vad import load_silero_vad

        self.lock = threading.Lock()
        self.use_gpu = torch.cuda.is_available()
        self.device = "cuda" if self.use_gpu else "cpu"
        self.model = faster_whisper.WhisperModel(
            MODEL_NAME,
            device=self.device,
            compute_type=MODEL_COMPUTE_TYPE,
            num_workers=MODEL_NUM_WORKERS,
            download_root=CACHE_PATH,
            local_files_only=True,
        )
        self.vad_model = load_silero_vad(onnx=False)

    def _vad_segments(
        self,
        audio_array,
        sample_rate: int = SAMPLERATE,
        window_size: int = VAD_CONFIG["window_size"],
    ) -> Generator[TimeSegment, None, None]:
        """Yield speech regions (in seconds) detected by Silero VAD.

        Audio is fed in ``window_size``-sample chunks, and the last chunk is
        zero-padded. A region that is still open when the audio ends is closed
        at the end of the audio rather than dropped.
        """
        import numpy as np
        from silero_vad import VADIterator

        iterator = VADIterator(self.vad_model, sampling_rate=sample_rate)
        start = None
        try:
            for i in range(0, len(audio_array), window_size):
                chunk = audio_array[i : i + window_size]
                if len(chunk) < window_size:
                    chunk = np.pad(
                        chunk, (0, window_size - len(chunk)), mode="constant"
                    )
                speech = iterator(chunk)
                if not speech:
                    continue
                if "start" in speech:
                    start = speech["start"]
                    continue
                if "end" in speech and start is not None:
                    yield TimeSegment(
                        start / float(sample_rate), speech["end"] / float(sample_rate)
                    )
                    start = None
            if start is not None:
                # Speech continued to the end of the file without an "end" event.
                yield TimeSegment(
                    start / float(sample_rate), len(audio_array) / float(sample_rate)
                )
        finally:
            # The VAD model is shared across calls; always clear its state.
            iterator.reset_states()

    @staticmethod
    def _merge_segments(segments, max_duration: float) -> list[TimeSegment]:
        """Greedily merge consecutive speech regions into batches.

        A region joins the current batch while (region end - batch start) stays
        within ``max_duration`` seconds. A single region longer than that still
        forms its own batch. Silence between merged regions is included.
        """
        batches: list[TimeSegment] = []
        batch_start = batch_end = None
        for seg_start, seg_end in segments:
            if batch_start is None:
                batch_start, batch_end = seg_start, seg_end
            elif seg_end - batch_start <= max_duration:
                batch_end = seg_end
            else:
                batches.append(TimeSegment(batch_start, batch_end))
                batch_start, batch_end = seg_start, seg_end
        if batch_start is not None:
            batches.append(TimeSegment(batch_start, batch_end))
        return batches

    def _transcribe_audio(
        self, audio, language: str, start_time: float, timestamp_offset: float
    ) -> tuple[str, list[dict]]:
        """Transcribe one batch of samples and return ``(text, words)``.

        Word times are shifted by ``start_time`` (batch position in the file)
        and ``timestamp_offset`` (caller-supplied), then rounded to 2 decimals.
        ``WhisperModel.transcribe`` returns a lazy generator, so it is consumed
        inside the lock, where the decoding actually happens.
        """
        with self.lock:
            segments, _ = self.model.transcribe(
                audio,
                language=language,
                beam_size=5,
                word_timestamps=True,
                vad_filter=True,
                vad_parameters={"min_silence_duration_ms": 500},
            )
            segments = list(segments)

        text = "".join(seg.text for seg in segments).strip()
        words = [
            {
                "word": w.word,
                "start": round(float(w.start) + start_time + timestamp_offset, 2),
                "end": round(float(w.end) + start_time + timestamp_offset, 2),
            }
            for seg in segments
            for w in (seg.words or [])
        ]
        return text, words

    @modal.method()
    def transcribe_segment(
        self, filename: str, timestamp_offset: float = 0.0, language: str = "en"
    ):
        """Transcribe an uploaded file of arbitrary length.

        Args:
            filename: Name of the file inside ``UPLOADS_PATH``.
            timestamp_offset: Seconds added to every word timestamp.
            language: Language code forwarded to Whisper.

        Returns:
            ``{"text": str, "words": [...]}``. The text joins each batch's
            non-empty text with single spaces. Returns empty text and no words
            when no speech is detected.

        Raises:
            FileNotFoundError: if the file is not on the uploads volume.
        """
        import librosa

        upload_volume.reload()
        file_path = f"{UPLOADS_PATH}/{filename}"
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        audio_array, _sr = librosa.load(file_path, sr=SAMPLERATE, mono=True)

        batches = self._merge_segments(
            self._vad_segments(audio_array), VAD_CONFIG["batch_max_duration"]
        )

        all_text: list[str] = []
        all_words: list[dict] = []
        for batch in batches:
            start_idx = int(batch.start * SAMPLERATE)
            end_idx = int(batch.end * SAMPLERATE)
            batch_audio = pad_audio(audio_array[start_idx:end_idx], SAMPLERATE)

            text, words = self._transcribe_audio(
                batch_audio, language, batch.start, timestamp_offset
            )
            if text:
                all_text.append(text)
            all_words.extend(words)

        return {"text": " ".join(all_text), "words": all_words}
|
|
|
|
|
|
def detect_audio_format(url: str, headers: dict) -> str:
    """Determine the audio file extension for a download from ``url``.

    The URL path's extension wins when it is one of
    ``SUPPORTED_FILE_EXTENSIONS`` (compared case-insensitively). Otherwise the
    response ``Content-Type`` header is matched against known audio types.

    Args:
        url: URL the audio was fetched from.
        headers: Response headers; only ``content-type`` is read.

    Returns:
        A lowercase extension from ``SUPPORTED_FILE_EXTENSIONS`` (no dot).

    Raises:
        fastapi.HTTPException: 400 when the format cannot be determined.
    """
    from fastapi import HTTPException

    url_path = urlparse(url).path.lower()
    for ext in SUPPORTED_FILE_EXTENSIONS:
        if url_path.endswith(f".{ext}"):
            return ext

    content_type = (headers.get("content-type") or "").lower()
    # Substring markers, checked in order. Servers frequently label WAV as
    # audio/x-wav or audio/vnd.wave, and M4A as audio/x-m4a. These, plus
    # audio/webm, map to extensions that are already supported.
    content_type_markers = (
        ("audio/mpeg", "mp3"),
        ("audio/mp3", "mp3"),
        ("audio/wav", "wav"),
        ("audio/x-wav", "wav"),
        ("audio/vnd.wave", "wav"),
        ("audio/mp4", "mp4"),
        ("audio/x-m4a", "m4a"),
        ("audio/webm", "webm"),
    )
    for marker, ext in content_type_markers:
        if marker in content_type:
            return ext

    raise HTTPException(
        status_code=400,
        detail=(
            f"Unsupported audio format for URL. Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
        ),
    )
|
|
|
|
|
|
def download_audio_to_volume(
    audio_file_url: str,
    timeout: float | tuple[float, float] = (10.0, 120.0),
) -> tuple[str, str]:
    """Download ``audio_file_url`` onto the uploads volume under a unique name.

    Args:
        audio_file_url: HTTP(S) URL of the audio file.
        timeout: ``requests`` timeout in seconds, either a single value or a
            ``(connect, read)`` pair. The read timeout bounds the wait between
            received bytes, not the total download time.

    Returns:
        ``(unique_filename, audio_suffix)``. ``unique_filename`` is
        ``"<uuid4>.<suffix>"`` relative to ``UPLOADS_PATH``.

    Raises:
        fastapi.HTTPException: 404 if the server reports the file missing;
            400 if the audio format cannot be determined.
        requests.HTTPError: for any other non-2xx response.
        requests.RequestException: on connection failures or timeouts.
    """
    import requests
    from fastapi import HTTPException

    # A single streamed GET:
    # - the body goes straight to disk instead of being held in memory;
    # - there is no separate HEAD probe, which some servers and pre-signed
    #   URLs may reject.
    with requests.get(
        audio_file_url, allow_redirects=True, stream=True, timeout=timeout
    ) as response:
        if response.status_code == 404:
            raise HTTPException(status_code=404, detail="Audio file not found")
        response.raise_for_status()

        audio_suffix = detect_audio_format(audio_file_url, response.headers)
        unique_filename = f"{uuid.uuid4()}.{audio_suffix}"
        file_path = f"{UPLOADS_PATH}/{unique_filename}"

        try:
            with open(file_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
        except BaseException:
            # Do not leave a truncated file behind on the shared volume.
            try:
                os.remove(file_path)
            except OSError:
                pass
            raise

    upload_volume.commit()
    return unique_filename, audio_suffix
|
|
|
|
|
|
@app.function(
    scaledown_window=60,
    timeout=600,
    secrets=[
        modal.Secret.from_name("reflector-gpu"),
    ],
    volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
    image=image,
)
@modal.concurrent(max_inputs=40)
@modal.asgi_app()
def web():
    """Build the FastAPI application served by this Modal function.

    Both routes require a bearer token equal to the ``REFLECTOR_GPU_APIKEY``
    environment variable. The ``reflector-gpu`` Modal secret is attached to
    this function.

    * ``POST /v1/audio/transcriptions``: multipart ``file`` or ``files``
      uploads, transcribed by ``TranscriberWhisperLive``.
    * ``POST /v1/audio/transcriptions-from-url``: JSON body with
      ``audio_file_url``, downloaded and transcribed by
      ``TranscriberWhisperFile``.

    Audio is stored temporarily on ``upload_volume`` and removed when the
    request finishes. The ``model`` field is accepted for API compatibility
    but not used.
    """
    import secrets
    import shutil

    from fastapi import (
        Body,
        Depends,
        FastAPI,
        Form,
        HTTPException,
        UploadFile,
        status,
    )
    from fastapi.security import OAuth2PasswordBearer

    transcriber_live = TranscriberWhisperLive()
    transcriber_file = TranscriberWhisperFile()

    web_app = FastAPI()

    oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

    def apikey_auth(apikey: str = Depends(oauth2_scheme)):
        # Constant-time comparison: the response time must not reveal how
        # many leading characters of a guessed key were correct.
        expected = os.environ["REFLECTOR_GPU_APIKEY"]
        if secrets.compare_digest(apikey.encode("utf-8"), expected.encode("utf-8")):
            return
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
            headers={"WWW-Authenticate": "Bearer"},
        )

    def remove_uploads(filenames: list[str]) -> None:
        """Best-effort deletion of temporary audio files, then commit the volume.

        Failures are logged to stderr instead of raised, so cleanup can never
        replace the endpoint's real response or error.
        """
        for filename in filenames:
            try:
                os.remove(f"{UPLOADS_PATH}/{filename}")
            except FileNotFoundError:
                pass
            except OSError as exc:
                print(f"Failed to remove upload {filename}: {exc}", file=sys.stderr)
        try:
            upload_volume.commit()
        except Exception as exc:
            print(f"Failed to commit uploads volume: {exc}", file=sys.stderr)

    @web_app.post("/v1/audio/transcriptions", dependencies=[Depends(apikey_auth)])
    def transcribe(
        file: UploadFile = None,
        files: list[UploadFile] | None = None,
        model: str = Form(MODEL_NAME),
        language: str = Form("en"),
        batch: bool = Form(False),
    ):
        if not file and not files:
            raise HTTPException(
                status_code=400, detail="Either 'file' or 'files' parameter is required"
            )
        if batch and not files:
            raise HTTPException(
                status_code=400, detail="Batch transcription requires 'files'"
            )

        upload_files = [file] if file else files

        # Validate every upload before writing anything, so a bad file later
        # in the list cannot leave earlier ones orphaned on the volume.
        # Extensions are compared case-insensitively ("talk.MP3" is accepted).
        audio_suffixes: list[str] = []
        for upload_file in upload_files:
            audio_suffix = (upload_file.filename or "").rsplit(".", 1)[-1].lower()
            if audio_suffix not in SUPPORTED_FILE_EXTENSIONS:
                raise HTTPException(
                    status_code=400,
                    detail=(
                        f"Unsupported audio format. Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
                    ),
                )
            audio_suffixes.append(audio_suffix)

        uploaded_filenames: list[str] = []
        try:
            for upload_file, audio_suffix in zip(upload_files, audio_suffixes):
                unique_filename = f"{uuid.uuid4()}.{audio_suffix}"
                # Record the name before writing, so a partially written file
                # is still removed if the copy fails.
                uploaded_filenames.append(unique_filename)
                with open(f"{UPLOADS_PATH}/{unique_filename}", "wb") as f:
                    # Stream the upload in chunks rather than reading it whole.
                    shutil.copyfileobj(upload_file.file, f)
            upload_volume.commit()

            if batch and len(upload_files) > 1:
                results = transcriber_live.transcribe_batch.spawn(
                    filenames=uploaded_filenames,
                    language=language,
                ).get()
                return {"results": results}

            # Spawn every call before waiting on any, so several files are
            # transcribed concurrently rather than one after another.
            calls = [
                transcriber_live.transcribe_segment.spawn(
                    filename=filename,
                    language=language,
                )
                for filename in uploaded_filenames
            ]
            results = []
            for filename, call in zip(uploaded_filenames, calls):
                result = call.get()
                result["filename"] = filename
                results.append(result)

            return {"results": results} if len(results) > 1 else results[0]
        finally:
            remove_uploads(uploaded_filenames)

    @web_app.post(
        "/v1/audio/transcriptions-from-url", dependencies=[Depends(apikey_auth)]
    )
    def transcribe_from_url(
        audio_file_url: str = Body(
            ..., description="URL of the audio file to transcribe"
        ),
        model: str = Body(MODEL_NAME),
        language: str = Body("en"),
        timestamp_offset: float = Body(0.0),
    ):
        unique_filename, _audio_suffix = download_audio_to_volume(audio_file_url)
        try:
            return transcriber_file.transcribe_segment.spawn(
                filename=unique_filename,
                timestamp_offset=timestamp_offset,
                language=language,
            ).get()
        finally:
            remove_uploads([unique_filename])

    return web_app
|
|
|
|
|
|
class NoStdStreams:
    """Context manager that silences ``sys.stdout`` and ``sys.stderr``.

    On entry both streams are flushed and pointed at ``os.devnull``. On exit
    they are restored, even if the ``with`` body raised. Exceptions are not
    suppressed.

    The devnull handle is opened on entry and closed on exit. Constructing an
    instance therefore opens nothing, and the same instance can be entered
    again later. Do not nest the same instance. Because it swaps the
    process-wide ``sys.stdout``/``sys.stderr``, concurrent threads need
    external locking.
    """

    def __init__(self):
        self.devnull = None

    def __enter__(self):
        self._stdout, self._stderr = sys.stdout, sys.stderr
        self._stdout.flush()
        self._stderr.flush()
        # Explicit UTF-8 so non-ASCII output cannot raise UnicodeEncodeError
        # under an ASCII locale.
        self.devnull = open(os.devnull, "w", encoding="utf-8")
        sys.stdout, sys.stderr = self.devnull, self.devnull
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        sys.stdout, sys.stderr = self._stdout, self._stderr
        self.devnull.close()
        self.devnull = None
|