deployment fix (#364)

This commit is contained in:
projects-g
2024-06-20 12:07:28 +05:30
committed by GitHub
parent 76e9842fa8
commit 06b0abaf62
6 changed files with 3 additions and 3 deletions

View File

@@ -0,0 +1,81 @@
# Reflector GPU implementation - Transcription and LLM
This repository holds an API for the GPU implementation of the Reflector API service,
and uses [Modal.com](https://modal.com):
- `reflector_llm.py` - LLM API
- `reflector_transcriber.py` - Transcription API
## Modal.com deployment
Create a Modal secret and name it `reflector-gpu`.
It should contain a `REFLECTOR_APIKEY` environment variable with a value of your choice.
The deployment is done using the [Modal.com](https://modal.com) service.
```
$ modal deploy reflector_transcriber.py
...
└── 🔨 Created web => https://xxxx--reflector-transcriber-web.modal.run
$ modal deploy reflector_llm.py
...
└── 🔨 Created web => https://xxxx--reflector-llm-web.modal.run
```
Then, in your Reflector API configuration `.env`, you can set these keys:
```
TRANSCRIPT_BACKEND=modal
TRANSCRIPT_URL=https://xxxx--reflector-transcriber-web.modal.run
TRANSCRIPT_MODAL_API_KEY=REFLECTOR_APIKEY
LLM_BACKEND=modal
LLM_URL=https://xxxx--reflector-llm-web.modal.run
LLM_MODAL_API_KEY=REFLECTOR_APIKEY
```
## API
Authentication must be passed with the `Authorization` header, using the `bearer` scheme.
```
Authorization: bearer <REFLECTOR_APIKEY>
```
### LLM
`POST /llm`
**request**
```
{
  "prompt": "xxx"
}
```
The request also accepts two optional fields: `gen_schema` (a JSON schema the generated output must conform to) and `gen_cfg` (a generation configuration).
**response**
```
{
  "text": "xxx completed"
}
```
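For reference, a minimal sketch of calling the endpoint with Python and `requests`; the URL and API key below are placeholders for your own deployment:
```
import requests

LLM_URL = "https://xxxx--reflector-llm-web.modal.run"  # your deployed URL
API_KEY = "..."  # the REFLECTOR_APIKEY value from the modal secret

response = requests.post(
    f"{LLM_URL}/llm",
    headers={"Authorization": f"bearer {API_KEY}"},
    json={"prompt": "xxx"},
    timeout=300,
)
response.raise_for_status()
print(response.json()["text"])
```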
### Transcription
`POST /transcribe`
**request** (multipart/form-data)
- `file` - audio file
- `source_language` - language code of the audio (e.g. `en`, defaults to `en`)
- `timestamp` - optional start offset in seconds added to the word timings (defaults to `0.0`)
**response**
```
{
  "text": {"en": "xxx"},
  "words": [
    {"text": "xxx", "start": 0.0, "end": 1.0}
  ]
}
```
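For reference, a minimal sketch of uploading a file to the endpoint with Python and `requests`; the URL, API key and file name are placeholders for your own deployment:
```
import requests

TRANSCRIPT_URL = "https://xxxx--reflector-transcriber-web.modal.run"  # your deployed URL
API_KEY = "..."  # the REFLECTOR_APIKEY value from the modal secret

with open("audio.mp3", "rb") as f:
    response = requests.post(
        f"{TRANSCRIPT_URL}/transcribe",
        headers={"Authorization": f"bearer {API_KEY}"},
        files={"file": ("audio.mp3", f, "audio/mpeg")},
        data={"source_language": "en"},
        timeout=300,
    )
response.raise_for_status()
print(response.json())
```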

View File

@@ -0,0 +1,188 @@
"""
Reflector GPU backend - diarizer
===================================
"""
import os
import modal.gpu
from modal import Image, Secret, Stub, asgi_app, method
from pydantic import BaseModel
PYANNOTE_MODEL_NAME: str = "pyannote/speaker-diarization-3.0"
MODEL_DIR = "/root/diarization_models"
stub = Stub(name="reflector-diarizer")
def migrate_cache_llm():
"""
XXX The cache for model files in Transformers v4.22.0 has been updated.
Migrating your old cache. This is a one-time only operation. You can
interrupt this and resume the migration later on by calling
`transformers.utils.move_cache()`.
"""
from transformers.utils.hub import move_cache
print("Moving LLM cache")
move_cache(cache_dir=MODEL_DIR, new_cache_dir=MODEL_DIR)
print("LLM cache moved")
def download_pyannote_audio():
from pyannote.audio import Pipeline
Pipeline.from_pretrained(
"pyannote/speaker-diarization-3.0",
cache_dir=MODEL_DIR,
use_auth_token=os.environ["HF_TOKEN"]
)
diarizer_image = (
Image.debian_slim(python_version="3.10.8")
.pip_install(
"pyannote.audio",
"requests",
"onnx",
"torchaudio",
"onnxruntime-gpu",
"torch==2.0.0",
"transformers==4.34.0",
"sentencepiece",
"protobuf",
"numpy",
"huggingface_hub",
"hf-transfer"
)
.run_function(migrate_cache_llm)
.run_function(download_pyannote_audio, secrets=[modal.Secret.from_name("my-huggingface-secret")])
.env(
{
"LD_LIBRARY_PATH": (
"/usr/local/lib/python3.10/site-packages/nvidia/cudnn/lib/:"
"/opt/conda/lib/python3.10/site-packages/nvidia/cublas/lib/"
)
}
)
)
@stub.cls(
gpu=modal.gpu.A100(memory=40),
timeout=60 * 30,
container_idle_timeout=60,
allow_concurrent_inputs=1,
image=diarizer_image,
secrets=[modal.Secret.from_name("my-huggingface-secret")],
)
class Diarizer:
def __enter__(self):
import torch
from pyannote.audio import Pipeline
self.use_gpu = torch.cuda.is_available()
self.device = "cuda" if self.use_gpu else "cpu"
self.diarization_pipeline = Pipeline.from_pretrained(
"pyannote/speaker-diarization-3.0",
cache_dir=MODEL_DIR
)
self.diarization_pipeline.to(torch.device(self.device))
@method()
def diarize(
self,
audio_data: bytes,
audio_suffix: str,
timestamp: float
):
import tempfile
import torchaudio
with tempfile.NamedTemporaryFile("wb+", suffix=f".{audio_suffix}") as fp:
fp.write(audio_data)
fp.flush()  # make sure the full file is on disk before loading it
print("Diarizing audio")
waveform, sample_rate = torchaudio.load(fp.name)
diarization = self.diarization_pipeline({"waveform": waveform, "sample_rate": sample_rate})
words = []
for diarization_segment, _, speaker in diarization.itertracks(yield_label=True):
words.append(
{
"start": round(timestamp + diarization_segment.start, 3),
"end": round(timestamp + diarization_segment.end, 3),
"speaker": int(speaker[-2:])
}
)
print("Diarization complete")
return {
"diarization": words
}
# -------------------------------------------------------------------
# Web API
# -------------------------------------------------------------------
@stub.function(
timeout=60 * 10,
container_idle_timeout=60 * 3,
allow_concurrent_inputs=40,
secrets=[
Secret.from_name("reflector-gpu"),
],
image=diarizer_image
)
@asgi_app()
def web():
import requests
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
diarizerstub = Diarizer()
app = FastAPI()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
headers={"WWW-Authenticate": "Bearer"},
)
def validate_audio_file(audio_file_url: str):
# Check if the audio file exists
response = requests.head(audio_file_url, allow_redirects=True)
if response.status_code == 404:
raise HTTPException(
status_code=response.status_code,
detail="The audio file does not exist."
)
class DiarizationResponse(BaseModel):
result: dict
@app.post("/diarize", dependencies=[Depends(apikey_auth), Depends(validate_audio_file)])
def diarize(
audio_file_url: str,
timestamp: float = 0.0
) -> dict:
# Currently the uploaded files are in mp3 format
audio_suffix = "mp3"
print("Downloading audio file")
response = requests.get(audio_file_url, allow_redirects=True)
print("Audio file downloaded successfully")
func = diarizerstub.diarize.spawn(
audio_data=response.content,
audio_suffix=audio_suffix,
timestamp=timestamp
)
result = func.get()
return result
return app

View File

@@ -0,0 +1,206 @@
"""
Reflector GPU backend - LLM
===========================
"""
import json
import os
import threading
from typing import Optional
import modal
from modal import Image, Secret, Stub, asgi_app, method
# LLM
LLM_MODEL: str = "lmsys/vicuna-13b-v1.5"
LLM_LOW_CPU_MEM_USAGE: bool = True
LLM_TORCH_DTYPE: str = "bfloat16"
LLM_MAX_NEW_TOKENS: int = 300
IMAGE_MODEL_DIR = "/root/llm_models"
stub = Stub(name="reflector-llm")
def download_llm():
from huggingface_hub import snapshot_download
print("Downloading LLM model")
snapshot_download(LLM_MODEL, cache_dir=IMAGE_MODEL_DIR)
print("LLM model downloaded")
def migrate_cache_llm():
"""
XXX The cache for model files in Transformers v4.22.0 has been updated.
Migrating your old cache. This is a one-time only operation. You can
interrupt this and resume the migration later on by calling
`transformers.utils.move_cache()`.
"""
from transformers.utils.hub import move_cache
print("Moving LLM cache")
move_cache(cache_dir=IMAGE_MODEL_DIR, new_cache_dir=IMAGE_MODEL_DIR)
print("LLM cache moved")
llm_image = (
Image.debian_slim(python_version="3.10.8")
.apt_install("git")
.pip_install(
"transformers",
"torch",
"sentencepiece",
"protobuf",
"jsonformer==0.12.0",
"accelerate==0.21.0",
"einops==0.6.1",
"hf-transfer~=0.1",
"huggingface_hub==0.16.4"
)
.env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
.run_function(download_llm)
.run_function(migrate_cache_llm)
)
@stub.cls(
gpu="A100",
timeout=60 * 5,
container_idle_timeout=60 * 5,
allow_concurrent_inputs=15,
image=llm_image,
)
class LLM:
def __enter__(self):
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
print("Instance llm model")
model = AutoModelForCausalLM.from_pretrained(
LLM_MODEL,
torch_dtype=getattr(torch, LLM_TORCH_DTYPE),
low_cpu_mem_usage=LLM_LOW_CPU_MEM_USAGE,
cache_dir=IMAGE_MODEL_DIR,
local_files_only=True
)
# JSONFormer doesn't yet support generation configs
print("Instance llm generation config")
model.config.max_new_tokens = LLM_MAX_NEW_TOKENS
# generation configuration
gen_cfg = GenerationConfig.from_model_config(model.config)
gen_cfg.max_new_tokens = LLM_MAX_NEW_TOKENS
# load tokenizer
print("Instance llm tokenizer")
tokenizer = AutoTokenizer.from_pretrained(
LLM_MODEL,
cache_dir=IMAGE_MODEL_DIR,
local_files_only=True
)
# move model to gpu
print("Move llm model to GPU")
model = model.cuda()
print("Warmup llm done")
self.model = model
self.tokenizer = tokenizer
self.gen_cfg = gen_cfg
self.GenerationConfig = GenerationConfig
self.lock = threading.Lock()
def __exit__(self, *args):
print("Exit llm")
@method()
def generate(self, prompt: str, gen_schema: str | None, gen_cfg: str | None) -> dict:
"""
Perform a generation action using the LLM
"""
print(f"Generate {prompt=}")
if gen_cfg:
gen_cfg = self.GenerationConfig.from_dict(json.loads(gen_cfg))
else:
gen_cfg = self.gen_cfg
# If a gen_schema is given, conform to gen_schema
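# serialize access to the model: generation is not safe to run concurrently on one GPU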
with self.lock:
if gen_schema:
import jsonformer
print(f"Schema {gen_schema=}")
jsonformer_llm = jsonformer.Jsonformer(
model=self.model,
tokenizer=self.tokenizer,
json_schema=json.loads(gen_schema),
prompt=prompt,
max_string_token_length=gen_cfg.max_new_tokens
)
response = jsonformer_llm()
else:
# If no gen_schema, perform prompt only generation
# tokenize prompt
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
self.model.device
)
output = self.model.generate(input_ids, generation_config=gen_cfg)
# decode output
response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True)
response = response[len(prompt):]
print(f"Generated {response=}")
return {"text": response}
# -------------------------------------------------------------------
# Web API
# -------------------------------------------------------------------
@stub.function(
container_idle_timeout=60 * 10,
timeout=60 * 5,
allow_concurrent_inputs=45,
secrets=[
Secret.from_name("reflector-gpu"),
],
)
@asgi_app()
def web():
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel
llmstub = LLM()
app = FastAPI()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
headers={"WWW-Authenticate": "Bearer"},
)
class LLMRequest(BaseModel):
prompt: str
gen_schema: Optional[dict] = None
gen_cfg: Optional[dict] = None
@app.post("/llm", dependencies=[Depends(apikey_auth)])
def llm(
req: LLMRequest,
):
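# the Modal method expects plain JSON strings for the optional schema / generation config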
gen_schema = json.dumps(req.gen_schema) if req.gen_schema else None
gen_cfg = json.dumps(req.gen_cfg) if req.gen_cfg else None
func = llmstub.generate.spawn(prompt=req.prompt, gen_schema=gen_schema, gen_cfg=gen_cfg)
result = func.get()
return result
return app

View File

@@ -0,0 +1,214 @@
"""
Reflector GPU backend - LLM (zephyr)
====================================
"""
import json
import os
import threading
from typing import Optional
import modal
from modal import Image, Secret, Stub, asgi_app, method
# LLM
LLM_MODEL: str = "HuggingFaceH4/zephyr-7b-alpha"
LLM_LOW_CPU_MEM_USAGE: bool = True
LLM_TORCH_DTYPE: str = "bfloat16"
LLM_MAX_NEW_TOKENS: int = 300
IMAGE_MODEL_DIR = "/root/llm_models/zephyr"
stub = Stub(name="reflector-llm-zephyr")
def download_llm():
from huggingface_hub import snapshot_download
print("Downloading LLM model")
snapshot_download(LLM_MODEL, cache_dir=IMAGE_MODEL_DIR)
print("LLM model downloaded")
def migrate_cache_llm():
"""
XXX The cache for model files in Transformers v4.22.0 has been updated.
Migrating your old cache. This is a one-time only operation. You can
interrupt this and resume the migration later on by calling
`transformers.utils.move_cache()`.
"""
from transformers.utils.hub import move_cache
print("Moving LLM cache")
move_cache(cache_dir=IMAGE_MODEL_DIR, new_cache_dir=IMAGE_MODEL_DIR)
print("LLM cache moved")
llm_image = (
Image.debian_slim(python_version="3.10.8")
.apt_install("git")
.pip_install(
"transformers==4.34.0",
"torch",
"sentencepiece",
"protobuf",
"jsonformer==0.12.0",
"accelerate==0.21.0",
"einops==0.6.1",
"hf-transfer~=0.1",
"huggingface_hub==0.16.4"
)
.env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
.run_function(download_llm)
.run_function(migrate_cache_llm)
)
@stub.cls(
gpu="A10G",
timeout=60 * 5,
container_idle_timeout=60 * 5,
allow_concurrent_inputs=10,
image=llm_image,
)
class LLM:
def __enter__(self):
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
print("Instance llm model")
model = AutoModelForCausalLM.from_pretrained(
LLM_MODEL,
torch_dtype=getattr(torch, LLM_TORCH_DTYPE),
low_cpu_mem_usage=LLM_LOW_CPU_MEM_USAGE,
cache_dir=IMAGE_MODEL_DIR,
local_files_only=True
)
# JSONFormer doesn't yet support generation configs
print("Instance llm generation config")
model.config.max_new_tokens = LLM_MAX_NEW_TOKENS
# generation configuration
gen_cfg = GenerationConfig.from_model_config(model.config)
gen_cfg.max_new_tokens = LLM_MAX_NEW_TOKENS
# load tokenizer
print("Instance llm tokenizer")
tokenizer = AutoTokenizer.from_pretrained(
LLM_MODEL,
cache_dir=IMAGE_MODEL_DIR,
local_files_only=True
)
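# use EOS as the pad token so padding and end-of-sequence are consistent during generation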
gen_cfg.pad_token_id = tokenizer.eos_token_id
gen_cfg.eos_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id
# move model to gpu
print("Move llm model to GPU")
model = model.cuda()
print("Warmup llm done")
self.model = model
self.tokenizer = tokenizer
self.gen_cfg = gen_cfg
self.GenerationConfig = GenerationConfig
self.lock = threading.Lock()
def __exit__(self, *args):
print("Exit llm")
@method()
def generate(self, prompt: str, gen_schema: str | None, gen_cfg: str | None) -> dict:
"""
Perform a generation action using the LLM
"""
print(f"Generate {prompt=}")
if gen_cfg:
gen_cfg = self.GenerationConfig.from_dict(json.loads(gen_cfg))
gen_cfg.pad_token_id = self.tokenizer.eos_token_id
gen_cfg.eos_token_id = self.tokenizer.eos_token_id
else:
gen_cfg = self.gen_cfg
# If a gen_schema is given, conform to gen_schema
with self.lock:
if gen_schema:
import jsonformer
print(f"Schema {gen_schema=}")
jsonformer_llm = jsonformer.Jsonformer(
model=self.model,
tokenizer=self.tokenizer,
json_schema=json.loads(gen_schema),
prompt=prompt,
max_string_token_length=gen_cfg.max_new_tokens
)
response = jsonformer_llm()
else:
# If no gen_schema, perform prompt only generation
# tokenize prompt
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
self.model.device
)
output = self.model.generate(input_ids, generation_config=gen_cfg)
# decode output
response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True)
response = response[len(prompt):]
response = {
"long_summary": response
}
print(f"Generated {response=}")
return {"text": response}
# -------------------------------------------------------------------
# Web API
# -------------------------------------------------------------------
@stub.function(
container_idle_timeout=60 * 10,
timeout=60 * 5,
allow_concurrent_inputs=30,
secrets=[
Secret.from_name("reflector-gpu"),
],
)
@asgi_app()
def web():
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel
llmstub = LLM()
app = FastAPI()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
headers={"WWW-Authenticate": "Bearer"},
)
class LLMRequest(BaseModel):
prompt: str
gen_schema: Optional[dict] = None
gen_cfg: Optional[dict] = None
@app.post("/llm", dependencies=[Depends(apikey_auth)])
def llm(
req: LLMRequest,
):
gen_schema = json.dumps(req.gen_schema) if req.gen_schema else None
gen_cfg = json.dumps(req.gen_cfg) if req.gen_cfg else None
func = llmstub.generate.spawn(prompt=req.prompt, gen_schema=gen_schema, gen_cfg=gen_cfg)
result = func.get()
return result
return app

View File

@@ -0,0 +1,203 @@
"""
Reflector GPU backend - transcriber
===================================
"""
import os
import tempfile
import threading
from modal import Image, Secret, Stub, asgi_app, method
from pydantic import BaseModel
# Whisper
WHISPER_MODEL: str = "large-v2"
WHISPER_COMPUTE_TYPE: str = "float16"
WHISPER_NUM_WORKERS: int = 1
WHISPER_MODEL_DIR = "/root/transcription_models"
stub = Stub(name="reflector-transcriber")
def download_whisper():
from faster_whisper.utils import download_model
print("Downloading Whisper model")
download_model(WHISPER_MODEL, cache_dir=WHISPER_MODEL_DIR)
print("Whisper model downloaded")
def migrate_cache_llm():
"""
XXX The cache for model files in Transformers v4.22.0 has been updated.
Migrating your old cache. This is a one-time only operation. You can
interrupt this and resume the migration later on by calling
`transformers.utils.move_cache()`.
"""
from transformers.utils.hub import move_cache
print("Moving LLM cache")
move_cache(cache_dir=WHISPER_MODEL_DIR, new_cache_dir=WHISPER_MODEL_DIR)
print("LLM cache moved")
transcriber_image = (
Image.debian_slim(python_version="3.10.8")
.apt_install("git")
.apt_install("wget")
.apt_install("libsndfile-dev")
.pip_install(
"faster-whisper",
"requests",
"torch",
"transformers==4.34.0",
"sentencepiece",
"protobuf",
"huggingface_hub==0.16.4",
"gitpython",
"torchaudio",
"fairseq2",
"pyyaml",
"hf-transfer~=0.1"
)
.run_function(download_whisper)
.run_function(migrate_cache_llm)
.env(
{
"LD_LIBRARY_PATH": (
"/usr/local/lib/python3.10/site-packages/nvidia/cudnn/lib/:"
"/opt/conda/lib/python3.10/site-packages/nvidia/cublas/lib/"
)
}
)
)
@stub.cls(
gpu="A10G",
timeout=60 * 5,
container_idle_timeout=60 * 5,
allow_concurrent_inputs=6,
image=transcriber_image,
)
class Transcriber:
def __enter__(self):
import faster_whisper
import torch
self.lock = threading.Lock()
self.use_gpu = torch.cuda.is_available()
self.device = "cuda" if self.use_gpu else "cpu"
self.model = faster_whisper.WhisperModel(
WHISPER_MODEL,
device=self.device,
compute_type=WHISPER_COMPUTE_TYPE,
num_workers=WHISPER_NUM_WORKERS,
download_root=WHISPER_MODEL_DIR,
local_files_only=True
)
@method()
def transcribe_segment(
self,
audio_data: bytes,
audio_suffix: str,
source_language: str,
timestamp: float = 0
):
with tempfile.NamedTemporaryFile("wb+", suffix=f".{audio_suffix}") as fp:
fp.write(audio_data)
fp.flush()  # ensure all bytes are written before whisper reads the file
with self.lock:
segments, _ = self.model.transcribe(
fp.name,
language=source_language,
beam_size=5,
word_timestamps=True,
vad_filter=True,
vad_parameters={"min_silence_duration_ms": 500},
)
multilingual_transcript = {}
transcript_source_lang = ""
words = []
if segments:
segments = list(segments)
for segment in segments:
transcript_source_lang += segment.text
for word in segment.words:
words.append(
{
"text": word.word,
"start": round(timestamp + word.start, 3),
"end": round(timestamp + word.end, 3),
}
)
multilingual_transcript[source_language] = transcript_source_lang
return {
"text": multilingual_transcript,
"words": words
}
# -------------------------------------------------------------------
# Web API
# -------------------------------------------------------------------
@stub.function(
container_idle_timeout=60,
timeout=60,
allow_concurrent_inputs=40,
secrets=[
Secret.from_name("reflector-gpu"),
],
)
@asgi_app()
def web():
from fastapi import Body, Depends, FastAPI, HTTPException, UploadFile, status
from fastapi.security import OAuth2PasswordBearer
from typing_extensions import Annotated
transcriberstub = Transcriber()
app = FastAPI()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
supported_audio_file_types = ["wav", "mp3", "ogg", "flac"]
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
headers={"WWW-Authenticate": "Bearer"},
)
class TranscriptResponse(BaseModel):
result: dict
@app.post("/transcribe", dependencies=[Depends(apikey_auth)])
def transcribe(
file: UploadFile,
source_language: Annotated[str, Body(...)] = "en",
timestamp: Annotated[float, Body()] = 0.0
) -> dict:
audio_data = file.file.read()
audio_suffix = file.filename.split(".")[-1]
assert audio_suffix in supported_audio_file_types
func = transcriberstub.transcribe_segment.spawn(
audio_data=audio_data,
audio_suffix=audio_suffix,
source_language=source_language,
timestamp=timestamp
)
result = func.get()
return result
return app

View File

@@ -0,0 +1,427 @@
"""
Reflector GPU backend - translator
===================================
"""
import os
import threading
from modal import Image, Secret, Stub, asgi_app, method
from pydantic import BaseModel
# Seamless M4T
SEAMLESSM4T_MODEL_SIZE: str = "medium"
SEAMLESSM4T_MODEL_CARD_NAME: str = f"seamlessM4T_{SEAMLESSM4T_MODEL_SIZE}"
SEAMLESSM4T_VOCODER_CARD_NAME: str = "vocoder_36langs"
HF_SEAMLESS_M4T_REPO: str = f"facebook/seamless-m4t-{SEAMLESSM4T_MODEL_SIZE}"
HF_SEAMLESS_M4T_VOCODER_REPO: str = "facebook/seamless-m4t-vocoder"
SEAMLESS_GIT_REPO: str = "https://github.com/facebookresearch/seamless_communication.git"
SEAMLESS_MODEL_DIR: str = "m4t"
stub = Stub(name="reflector-translator")
def install_seamless_communication():
import os
import subprocess
initial_dir = os.getcwd()
# run through a shell so the >> redirection into known_hosts actually happens
subprocess.run(
"mkdir -p ~/.ssh && ssh-keyscan -t rsa github.com >> ~/.ssh/known_hosts",
shell=True,
)
subprocess.run(["rm", "-rf", "seamless_communication"])
subprocess.run(["git", "clone", SEAMLESS_GIT_REPO, "./seamless_communication"])
os.chdir("seamless_communication")
subprocess.run(["pip", "install", "-e", "."])
os.chdir(initial_dir)
def download_seamlessm4t_model():
from huggingface_hub import snapshot_download
print("Downloading Transcriber model & tokenizer")
snapshot_download(HF_SEAMLESS_M4T_REPO, cache_dir=SEAMLESS_MODEL_DIR)
print("Transcriber model & tokenizer downloaded")
print("Downloading vocoder weights")
snapshot_download(HF_SEAMLESS_M4T_VOCODER_REPO, cache_dir=SEAMLESS_MODEL_DIR)
print("Vocoder weights downloaded")
def configure_seamless_m4t():
import os
import yaml
CARDS_DIR: str = "./seamless_communication/src/seamless_communication/cards"
with open(f"{CARDS_DIR}/seamlessM4T_{SEAMLESSM4T_MODEL_SIZE}.yaml", "r") as file:
model_yaml_data = yaml.load(file, Loader=yaml.FullLoader)
with open(f"{CARDS_DIR}/vocoder_36langs.yaml", "r") as file:
vocoder_yaml_data = yaml.load(file, Loader=yaml.FullLoader)
with open(f"{CARDS_DIR}/unity_nllb-100.yaml", "r") as file:
unity_100_yaml_data = yaml.load(file, Loader=yaml.FullLoader)
with open(f"{CARDS_DIR}/unity_nllb-200.yaml", "r") as file:
unity_200_yaml_data = yaml.load(file, Loader=yaml.FullLoader)
model_dir = f"{SEAMLESS_MODEL_DIR}/models--facebook--seamless-m4t-{SEAMLESSM4T_MODEL_SIZE}/snapshots"
available_model_versions = os.listdir(model_dir)
latest_model_version = sorted(available_model_versions)[-1]
model_name = f"multitask_unity_{SEAMLESSM4T_MODEL_SIZE}.pt"
model_path = os.path.join(os.getcwd(), model_dir, latest_model_version, model_name)
vocoder_dir = (
f"{SEAMLESS_MODEL_DIR}/models--facebook--seamless-m4t-vocoder/snapshots"
)
available_vocoder_versions = os.listdir(vocoder_dir)
latest_vocoder_version = sorted(available_vocoder_versions)[-1]
vocoder_name = "vocoder_36langs.pt"
vocoder_path = os.path.join(
os.getcwd(), vocoder_dir, latest_vocoder_version, vocoder_name
)
tokenizer_name = "tokenizer.model"
tokenizer_path = os.path.join(
os.getcwd(), model_dir, latest_model_version, tokenizer_name
)
model_yaml_data["checkpoint"] = f"file://{model_path}"
vocoder_yaml_data["checkpoint"] = f"file://{vocoder_path}"
unity_100_yaml_data["tokenizer"] = f"file://{tokenizer_path}"
unity_200_yaml_data["tokenizer"] = f"file://{tokenizer_path}"
with open(f"{CARDS_DIR}/seamlessM4T_{SEAMLESSM4T_MODEL_SIZE}.yaml", "w") as file:
yaml.dump(model_yaml_data, file)
with open(f"{CARDS_DIR}/vocoder_36langs.yaml", "w") as file:
yaml.dump(vocoder_yaml_data, file)
with open(f"{CARDS_DIR}/unity_nllb-100.yaml", "w") as file:
yaml.dump(unity_100_yaml_data, file)
with open(f"{CARDS_DIR}/unity_nllb-200.yaml", "w") as file:
yaml.dump(unity_200_yaml_data, file)
translator_image = (
Image.debian_slim(python_version="3.10.8")
.apt_install("git")
.apt_install("wget")
.apt_install("libsndfile-dev")
.pip_install(
"requests",
"torch",
"transformers==4.34.0",
"sentencepiece",
"protobuf",
"huggingface_hub==0.16.4",
"gitpython",
"torchaudio",
"fairseq2",
"pyyaml",
"hf-transfer~=0.1",
)
.run_function(install_seamless_communication)
.run_function(download_seamlessm4t_model)
.run_function(configure_seamless_m4t)
.env(
{
"LD_LIBRARY_PATH": (
"/usr/local/lib/python3.10/site-packages/nvidia/cudnn/lib/:"
"/opt/conda/lib/python3.10/site-packages/nvidia/cublas/lib/"
)
}
)
)
@stub.cls(
gpu="A10G",
timeout=60 * 5,
container_idle_timeout=60 * 5,
allow_concurrent_inputs=4,
image=translator_image,
)
class Translator:
def __enter__(self):
import torch
from seamless_communication.inference.translator import Translator
self.lock = threading.Lock()
self.use_gpu = torch.cuda.is_available()
self.device = "cuda" if self.use_gpu else "cpu"
self.translator = Translator(
SEAMLESSM4T_MODEL_CARD_NAME,
SEAMLESSM4T_VOCODER_CARD_NAME,
torch.device(self.device),
dtype=torch.float32,
)
@method()
def warmup(self):
return {"status": "ok"}
def get_seamless_lang_code(self, lang_code: str):
"""
The codes for SeamlessM4T is different from regular standards.
For ex, French is "fra" and not "fr".
"""
# TODO: Enhance with complete list of lang codes
seamless_lang_code = {
# Afrikaans
'af': 'afr',
# Amharic
'am': 'amh',
# Modern Standard Arabic
'ar': 'arb',
# Moroccan Arabic
'ary': 'ary',
# Egyptian Arabic
'arz': 'arz',
# Assamese
'as': 'asm',
# North Azerbaijani
'az': 'azj',
# Belarusian
'be': 'bel',
# Bengali
'bn': 'ben',
# Bosnian
'bs': 'bos',
# Bulgarian
'bg': 'bul',
# Catalan
'ca': 'cat',
# Cebuano
'ceb': 'ceb',
# Czech
'cs': 'ces',
# Central Kurdish
'ku': 'ckb',
# Mandarin Chinese
'cmn': 'cmn_Hant',
# Welsh
'cy': 'cym',
# Danish
'da': 'dan',
# German
'de': 'deu',
# Greek
'el': 'ell',
# English
'en': 'eng',
# Estonian
'et': 'est',
# Basque
'eu': 'eus',
# Finnish
'fi': 'fin',
# French
'fr': 'fra',
# Irish
'ga': 'gle',
# West Central Oromo,
'gaz': 'gaz',
# Galician
'gl': 'glg',
# Gujarati
'gu': 'guj',
# Hebrew
'he': 'heb',
# Hindi
'hi': 'hin',
# Croatian
'hr': 'hrv',
# Hungarian
'hu': 'hun',
# Armenian
'hy': 'hye',
# Igbo
'ig': 'ibo',
# Indonesian
'id': 'ind',
# Icelandic
'is': 'isl',
# Italian
'it': 'ita',
# Javanese
'jv': 'jav',
# Japanese
'ja': 'jpn',
# Kannada
'kn': 'kan',
# Georgian
'ka': 'kat',
# Kazakh
'kk': 'kaz',
# Halh Mongolian
'khk': 'khk',
# Khmer
'km': 'khm',
# Kyrgyz
'ky': 'kir',
# Korean
'ko': 'kor',
# Lao
'lo': 'lao',
# Lithuanian
'lt': 'lit',
# Ganda
'lg': 'lug',
# Luo
'luo': 'luo',
# Standard Latvian
'lv': 'lvs',
# Maithili
'mai': 'mai',
# Malayalam
'ml': 'mal',
# Marathi
'mr': 'mar',
# Macedonian
'mk': 'mkd',
# Maltese
'mt': 'mlt',
# Meitei
'mni': 'mni',
# Burmese
'my': 'mya',
# Dutch
'nl': 'nld',
# Norwegian Nynorsk
'nn': 'nno',
# Norwegian Bokmål
'nb': 'nob',
# Nepali
'ne': 'npi',
# Nyanja
'ny': 'nya',
# Odia
'or': 'ory',
# Punjabi
'pa': 'pan',
# Southern Pashto
'pbt': 'pbt',
# Western Persian
'pes': 'pes',
# Polish
'pl': 'pol',
# Portuguese
'pt': 'por',
# Romanian
'ro': 'ron',
# Russian
'ru': 'rus',
# Slovak
'sk': 'slk',
# Slovenian
'sl': 'slv',
# Shona
'sn': 'sna',
# Sindhi
'sd': 'snd',
# Somali
'so': 'som',
# Spanish
'es': 'spa',
# Serbian
'sr': 'srp',
# Swedish
'sv': 'swe',
# Swahili
'sw': 'swh',
# Tamil
'ta': 'tam',
# Telugu
'te': 'tel',
# Tajik
'tg': 'tgk',
# Tagalog
'tl': 'tgl',
# Thai
'th': 'tha',
# Turkish
'tr': 'tur',
# Ukrainian
'uk': 'ukr',
# Urdu
'ur': 'urd',
# Northern Uzbek
'uz': 'uzn',
# Vietnamese
'vi': 'vie',
# Yoruba
'yo': 'yor',
# Cantonese
'yue': 'yue',
# Standard Malay
'ms': 'zsm',
# Zulu
'zu': 'zul'
}
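# unknown codes fall back to English ("eng")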
return seamless_lang_code.get(lang_code, "eng")
@method()
def translate_text(self, text: str, source_language: str, target_language: str):
with self.lock:
translation_result, _ = self.translator.predict(
text,
"t2tt",
src_lang=self.get_seamless_lang_code(source_language),
tgt_lang=self.get_seamless_lang_code(target_language),
unit_generation_ngram_filtering=True,
)
translated_text = str(translation_result[0])
return {"text": {source_language: text, target_language: translated_text}}
# -------------------------------------------------------------------
# Web API
# -------------------------------------------------------------------
@stub.function(
container_idle_timeout=60,
timeout=60,
allow_concurrent_inputs=40,
secrets=[
Secret.from_name("reflector-gpu"),
],
)
@asgi_app()
def web():
from fastapi import Body, Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from typing_extensions import Annotated
translatorstub = Translator()
app = FastAPI()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
headers={"WWW-Authenticate": "Bearer"},
)
class TranslateResponse(BaseModel):
result: dict
@app.post("/translate", dependencies=[Depends(apikey_auth)])
def translate(
text: str,
source_language: Annotated[str, Body(...)] = "en",
target_language: Annotated[str, Body(...)] = "fr",
) -> dict:
func = translatorstub.translate_text.spawn(
text=text,
source_language=source_language,
target_language=target_language,
)
result = func.get()
return result
return app