gpu: update format + list of country 2 to 3

This commit is contained in:
2023-10-13 23:33:37 +02:00
parent 6c1869b79a
commit 9269db74c0

View File

@@ -26,8 +26,11 @@ stub = Stub(name="reflector-translator")
def install_seamless_communication():
import os
import subprocess
initial_dir = os.getcwd()
subprocess.run(["ssh-keyscan", "-t", "rsa", "github.com", ">>", "~/.ssh/known_hosts"])
subprocess.run(
["ssh-keyscan", "-t", "rsa", "github.com", ">>", "~/.ssh/known_hosts"]
)
subprocess.run(["rm", "-rf", "seamless_communication"])
subprocess.run(["git", "clone", SEAMLESS_GITEPO, "." + "/seamless_communication"])
os.chdir("seamless_communication")
@@ -54,13 +57,13 @@ def configure_seamless_m4t():
ASSETS_DIR: str = "./seamless_communication/src/seamless_communication/assets/cards"
with open(f'{ASSETS_DIR}/seamlessM4T_{SEAMLESSM4T_MODEL_SIZE}.yaml', 'r') as file:
with open(f"{ASSETS_DIR}/seamlessM4T_{SEAMLESSM4T_MODEL_SIZE}.yaml", "r") as file:
model_yaml_data = yaml.load(file, Loader=yaml.FullLoader)
with open(f'{ASSETS_DIR}/vocoder_36langs.yaml', 'r') as file:
with open(f"{ASSETS_DIR}/vocoder_36langs.yaml", "r") as file:
vocoder_yaml_data = yaml.load(file, Loader=yaml.FullLoader)
with open(f'{ASSETS_DIR}/unity_nllb-100.yaml', 'r') as file:
with open(f"{ASSETS_DIR}/unity_nllb-100.yaml", "r") as file:
unity_100_yaml_data = yaml.load(file, Loader=yaml.FullLoader)
with open(f'{ASSETS_DIR}/unity_nllb-200.yaml', 'r') as file:
with open(f"{ASSETS_DIR}/unity_nllb-200.yaml", "r") as file:
unity_200_yaml_data = yaml.load(file, Loader=yaml.FullLoader)
model_dir = f"{SEAMLESS_MODEL_DIR}/models--facebook--seamless-m4t-{SEAMLESSM4T_MODEL_SIZE}/snapshots"
@@ -69,27 +72,33 @@ def configure_seamless_m4t():
model_name = f"multitask_unity_{SEAMLESSM4T_MODEL_SIZE}.pt"
model_path = os.path.join(os.getcwd(), model_dir, latest_model_version, model_name)
vocoder_dir = f"{SEAMLESS_MODEL_DIR}/models--facebook--seamless-m4t-vocoder/snapshots"
vocoder_dir = (
f"{SEAMLESS_MODEL_DIR}/models--facebook--seamless-m4t-vocoder/snapshots"
)
available_vocoder_versions = os.listdir(vocoder_dir)
latest_vocoder_version = sorted(available_vocoder_versions)[-1]
vocoder_name = "vocoder_36langs.pt"
vocoder_path = os.path.join(os.getcwd(), vocoder_dir, latest_vocoder_version, vocoder_name)
vocoder_path = os.path.join(
os.getcwd(), vocoder_dir, latest_vocoder_version, vocoder_name
)
tokenizer_name = "tokenizer.model"
tokenizer_path = os.path.join(os.getcwd(), model_dir, latest_model_version, tokenizer_name)
tokenizer_path = os.path.join(
os.getcwd(), model_dir, latest_model_version, tokenizer_name
)
model_yaml_data['checkpoint'] = f"file:/{model_path}"
vocoder_yaml_data['checkpoint'] = f"file:/{vocoder_path}"
unity_100_yaml_data['tokenizer'] = f"file:/{tokenizer_path}"
unity_200_yaml_data['tokenizer'] = f"file:/{tokenizer_path}"
model_yaml_data["checkpoint"] = f"file:/{model_path}"
vocoder_yaml_data["checkpoint"] = f"file:/{vocoder_path}"
unity_100_yaml_data["tokenizer"] = f"file:/{tokenizer_path}"
unity_200_yaml_data["tokenizer"] = f"file:/{tokenizer_path}"
with open(f'{ASSETS_DIR}/seamlessM4T_{SEAMLESSM4T_MODEL_SIZE}.yaml', 'w') as file:
with open(f"{ASSETS_DIR}/seamlessM4T_{SEAMLESSM4T_MODEL_SIZE}.yaml", "w") as file:
yaml.dump(model_yaml_data, file)
with open(f'{ASSETS_DIR}/vocoder_36langs.yaml', 'w') as file:
with open(f"{ASSETS_DIR}/vocoder_36langs.yaml", "w") as file:
yaml.dump(vocoder_yaml_data, file)
with open(f'{ASSETS_DIR}/unity_nllb-100.yaml', 'w') as file:
with open(f"{ASSETS_DIR}/unity_nllb-100.yaml", "w") as file:
yaml.dump(unity_100_yaml_data, file)
with open(f'{ASSETS_DIR}/unity_nllb-200.yaml', 'w') as file:
with open(f"{ASSETS_DIR}/unity_nllb-200.yaml", "w") as file:
yaml.dump(unity_200_yaml_data, file)
@@ -109,7 +118,7 @@ transcriber_image = (
"torchaudio",
"fairseq2",
"pyyaml",
"hf-transfer~=0.1"
"hf-transfer~=0.1",
)
.run_function(install_seamless_communication)
.run_function(download_seamlessm4t_model)
@@ -144,7 +153,7 @@ class Translator:
SEAMLESSM4T_MODEL_CARD_NAME,
SEAMLESSM4T_VOCODER_CARD_NAME,
torch.device(self.device),
dtype=torch.float32
dtype=torch.float32,
)
@method()
@@ -158,32 +167,210 @@ class Translator:
"""
# TODO: Enhance with complete list of lang codes
seamless_lang_code = {
# Amharic
"am": "amh",
# Modern Standard Arabic
"ar": "arb",
# Moroccan Arabic
# (No 2-letter code)
# Egyptian Arabic
# (No 2-letter code)
# Assamese
"as": "asm",
# North Azerbaijani
"az": "azj",
# Belarusian
"be": "bel",
# Bengali
"bn": "ben",
# Bosnian
"bs": "bos",
# Bulgarian
"bg": "bul",
# Catalan
"ca": "cat",
# Cebuano
"ceb": "ceb",
# Czech
"cs": "ces",
# Central Kurdish
"ckb": "ckb",
# Mandarin Chinese (Simplified)
"zh": "cmn",
# Mandarin Chinese (Traditional)
# (No separate 2-letter code)
# Welsh
"cy": "cym",
# Danish
"da": "dan",
# German
"de": "deu",
# Greek
"el": "ell",
# English
"en": "eng",
"fr": "fra"
# Estonian
"et": "est",
# Basque
"eu": "eus",
# Finnish
"fi": "fin",
# French
"fr": "fra",
# West Central Oromo
# (No 2-letter code)
# Irish
"ga": "gle",
# Galician
"gl": "glg",
# Gujarati
"gu": "guj",
# Hebrew
"he": "heb",
# Hindi
"hi": "hin",
# Croatian
"hr": "hrv",
# Hungarian
"hu": "hun",
# Armenian
"hy": "hye",
# Igbo
"ig": "ibo",
# Indonesian
"id": "ind",
# Icelandic
"is": "isl",
# Italian
"it": "ita",
# Javanese
"jv": "jav",
# Japanese
"ja": "jpn",
# Kannada
"kn": "kan",
# Georgian
"ka": "kat",
# Kazakh
"kk": "kaz",
# Halh Mongolian
# (No 2-letter code)
# Khmer
"km": "khm",
# Kyrgyz
"ky": "kir",
# Korean
"ko": "kor",
# Lao
"lo": "lao",
# Lithuanian
"lt": "lit",
# Ganda
"lg": "lug",
# Luo
"luo": "luo",
# Standard Latvian
"lv": "lvs",
# Maithili
# (No 2-letter code)
# Malayalam
"ml": "mal",
# Marathi
"mr": "mar",
# Macedonian
"mk": "mkd",
# Maltese
"mt": "mlt",
# Meitei
# (No 2-letter code)
# Burmese
"my": "mya",
# Dutch
"nl": "nld",
# Norwegian Nynorsk
"nn": "nno",
# Norwegian Bokmål
"nb": "nob",
# Nepali
"ne": "npi",
# Nyanja
"ny": "nya",
# Odia
"or": "ory",
# Punjabi
"pa": "pan",
# Southern Pashto
# (No 2-letter code)
# Western Persian
"fa": "pes",
# Polish
"pl": "pol",
# Portuguese
"pt": "por",
# Romanian
"ro": "ron",
# Russian
"ru": "rus",
# Slovak
"sk": "slk",
# Slovenian
"sl": "slv",
# Shona
"sn": "sna",
# Sindhi
"sd": "snd",
# Somali
"so": "som",
# Spanish
"es": "spa",
# Serbian
"sr": "srp",
# Swedish
"sv": "swe",
# Swahili
"sw": "swh",
# Tamil
"ta": "tam",
# Telugu
"te": "tel",
# Tajik
"tg": "tgk",
# Tagalog
"tl": "tgl",
# Thai
"th": "tha",
# Turkish
"tr": "tur",
# Ukrainian
"uk": "ukr",
# Urdu
"ur": "urd",
# Northern Uzbek
"uz": "uzn",
# Vietnamese
"vi": "vie",
# Yoruba
"yo": "yor",
# Cantonese
# (No separate 2-letter code)
# Zulu
"zu": "zul",
}
return seamless_lang_code.get(lang_code, "eng")
@method()
def translate_text(
self,
text: str,
source_language: str,
target_language: str
):
def translate_text(self, text: str, source_language: str, target_language: str):
with self.lock:
translated_text, _, _ = self.translator.predict(
text,
"t2tt",
src_lang=self.get_seamless_lang_code(source_language),
tgt_lang=self.get_seamless_lang_code(target_language),
ngram_filtering=True
ngram_filtering=True,
)
return {
"text": {
source_language: text,
target_language: str(translated_text)
}
}
return {"text": {source_language: text, target_language: str(translated_text)}}
# -------------------------------------------------------------------
# Web API
# -------------------------------------------------------------------