diff --git a/server/gpu/modal/reflector_translator.py b/server/gpu/modal/reflector_translator.py index 6b035174..2986e002 100644 --- a/server/gpu/modal/reflector_translator.py +++ b/server/gpu/modal/reflector_translator.py @@ -26,8 +26,11 @@ stub = Stub(name="reflector-translator") def install_seamless_communication(): import os import subprocess + initial_dir = os.getcwd() - subprocess.run(["ssh-keyscan", "-t", "rsa", "github.com", ">>", "~/.ssh/known_hosts"]) + subprocess.run( + ["ssh-keyscan", "-t", "rsa", "github.com", ">>", "~/.ssh/known_hosts"] + ) subprocess.run(["rm", "-rf", "seamless_communication"]) subprocess.run(["git", "clone", SEAMLESS_GITEPO, "." + "/seamless_communication"]) os.chdir("seamless_communication") @@ -54,13 +57,13 @@ def configure_seamless_m4t(): ASSETS_DIR: str = "./seamless_communication/src/seamless_communication/assets/cards" - with open(f'{ASSETS_DIR}/seamlessM4T_{SEAMLESSM4T_MODEL_SIZE}.yaml', 'r') as file: + with open(f"{ASSETS_DIR}/seamlessM4T_{SEAMLESSM4T_MODEL_SIZE}.yaml", "r") as file: model_yaml_data = yaml.load(file, Loader=yaml.FullLoader) - with open(f'{ASSETS_DIR}/vocoder_36langs.yaml', 'r') as file: + with open(f"{ASSETS_DIR}/vocoder_36langs.yaml", "r") as file: vocoder_yaml_data = yaml.load(file, Loader=yaml.FullLoader) - with open(f'{ASSETS_DIR}/unity_nllb-100.yaml', 'r') as file: + with open(f"{ASSETS_DIR}/unity_nllb-100.yaml", "r") as file: unity_100_yaml_data = yaml.load(file, Loader=yaml.FullLoader) - with open(f'{ASSETS_DIR}/unity_nllb-200.yaml', 'r') as file: + with open(f"{ASSETS_DIR}/unity_nllb-200.yaml", "r") as file: unity_200_yaml_data = yaml.load(file, Loader=yaml.FullLoader) model_dir = f"{SEAMLESS_MODEL_DIR}/models--facebook--seamless-m4t-{SEAMLESSM4T_MODEL_SIZE}/snapshots" @@ -69,27 +72,33 @@ def configure_seamless_m4t(): model_name = f"multitask_unity_{SEAMLESSM4T_MODEL_SIZE}.pt" model_path = os.path.join(os.getcwd(), model_dir, latest_model_version, model_name) - vocoder_dir = f"{SEAMLESS_MODEL_DIR}/models--facebook--seamless-m4t-vocoder/snapshots" + vocoder_dir = ( + f"{SEAMLESS_MODEL_DIR}/models--facebook--seamless-m4t-vocoder/snapshots" + ) available_vocoder_versions = os.listdir(vocoder_dir) latest_vocoder_version = sorted(available_vocoder_versions)[-1] vocoder_name = "vocoder_36langs.pt" - vocoder_path = os.path.join(os.getcwd(), vocoder_dir, latest_vocoder_version, vocoder_name) + vocoder_path = os.path.join( + os.getcwd(), vocoder_dir, latest_vocoder_version, vocoder_name + ) tokenizer_name = "tokenizer.model" - tokenizer_path = os.path.join(os.getcwd(), model_dir, latest_model_version, tokenizer_name) + tokenizer_path = os.path.join( + os.getcwd(), model_dir, latest_model_version, tokenizer_name + ) - model_yaml_data['checkpoint'] = f"file:/{model_path}" - vocoder_yaml_data['checkpoint'] = f"file:/{vocoder_path}" - unity_100_yaml_data['tokenizer'] = f"file:/{tokenizer_path}" - unity_200_yaml_data['tokenizer'] = f"file:/{tokenizer_path}" + model_yaml_data["checkpoint"] = f"file:/{model_path}" + vocoder_yaml_data["checkpoint"] = f"file:/{vocoder_path}" + unity_100_yaml_data["tokenizer"] = f"file:/{tokenizer_path}" + unity_200_yaml_data["tokenizer"] = f"file:/{tokenizer_path}" - with open(f'{ASSETS_DIR}/seamlessM4T_{SEAMLESSM4T_MODEL_SIZE}.yaml', 'w') as file: + with open(f"{ASSETS_DIR}/seamlessM4T_{SEAMLESSM4T_MODEL_SIZE}.yaml", "w") as file: yaml.dump(model_yaml_data, file) - with open(f'{ASSETS_DIR}/vocoder_36langs.yaml', 'w') as file: + with open(f"{ASSETS_DIR}/vocoder_36langs.yaml", "w") as file: yaml.dump(vocoder_yaml_data, file) - with open(f'{ASSETS_DIR}/unity_nllb-100.yaml', 'w') as file: + with open(f"{ASSETS_DIR}/unity_nllb-100.yaml", "w") as file: yaml.dump(unity_100_yaml_data, file) - with open(f'{ASSETS_DIR}/unity_nllb-200.yaml', 'w') as file: + with open(f"{ASSETS_DIR}/unity_nllb-200.yaml", "w") as file: yaml.dump(unity_200_yaml_data, file) @@ -109,7 +118,7 @@ transcriber_image = ( "torchaudio", "fairseq2", "pyyaml", - "hf-transfer~=0.1" + "hf-transfer~=0.1", ) .run_function(install_seamless_communication) .run_function(download_seamlessm4t_model) @@ -144,7 +153,7 @@ class Translator: SEAMLESSM4T_MODEL_CARD_NAME, SEAMLESSM4T_VOCODER_CARD_NAME, torch.device(self.device), - dtype=torch.float32 + dtype=torch.float32, ) @method() @@ -158,32 +167,210 @@ class Translator: """ # TODO: Enhance with complete list of lang codes seamless_lang_code = { + # Amharic + "am": "amh", + # Modern Standard Arabic + "ar": "arb", + # Moroccan Arabic + # (No 2-letter code) + # Egyptian Arabic + # (No 2-letter code) + # Assamese + "as": "asm", + # North Azerbaijani + "az": "azj", + # Belarusian + "be": "bel", + # Bengali + "bn": "ben", + # Bosnian + "bs": "bos", + # Bulgarian + "bg": "bul", + # Catalan + "ca": "cat", + # Cebuano + "ceb": "ceb", + # Czech + "cs": "ces", + # Central Kurdish + "ckb": "ckb", + # Mandarin Chinese (Simplified) + "zh": "cmn", + # Mandarin Chinese (Traditional) + # (No separate 2-letter code) + # Welsh + "cy": "cym", + # Danish + "da": "dan", + # German + "de": "deu", + # Greek + "el": "ell", + # English "en": "eng", - "fr": "fra" + # Estonian + "et": "est", + # Basque + "eu": "eus", + # Finnish + "fi": "fin", + # French + "fr": "fra", + # West Central Oromo + # (No 2-letter code) + # Irish + "ga": "gle", + # Galician + "gl": "glg", + # Gujarati + "gu": "guj", + # Hebrew + "he": "heb", + # Hindi + "hi": "hin", + # Croatian + "hr": "hrv", + # Hungarian + "hu": "hun", + # Armenian + "hy": "hye", + # Igbo + "ig": "ibo", + # Indonesian + "id": "ind", + # Icelandic + "is": "isl", + # Italian + "it": "ita", + # Javanese + "jv": "jav", + # Japanese + "ja": "jpn", + # Kannada + "kn": "kan", + # Georgian + "ka": "kat", + # Kazakh + "kk": "kaz", + # Halh Mongolian + # (No 2-letter code) + # Khmer + "km": "khm", + # Kyrgyz + "ky": "kir", + # Korean + "ko": "kor", + # Lao + "lo": "lao", + # Lithuanian + "lt": "lit", + # Ganda + "lg": "lug", + # Luo + "luo": "luo", + # Standard Latvian + "lv": "lvs", + # Maithili + # (No 2-letter code) + # Malayalam + "ml": "mal", + # Marathi + "mr": "mar", + # Macedonian + "mk": "mkd", + # Maltese + "mt": "mlt", + # Meitei + # (No 2-letter code) + # Burmese + "my": "mya", + # Dutch + "nl": "nld", + # Norwegian Nynorsk + "nn": "nno", + # Norwegian Bokmål + "nb": "nob", + # Nepali + "ne": "npi", + # Nyanja + "ny": "nya", + # Odia + "or": "ory", + # Punjabi + "pa": "pan", + # Southern Pashto + # (No 2-letter code) + # Western Persian + "fa": "pes", + # Polish + "pl": "pol", + # Portuguese + "pt": "por", + # Romanian + "ro": "ron", + # Russian + "ru": "rus", + # Slovak + "sk": "slk", + # Slovenian + "sl": "slv", + # Shona + "sn": "sna", + # Sindhi + "sd": "snd", + # Somali + "so": "som", + # Spanish + "es": "spa", + # Serbian + "sr": "srp", + # Swedish + "sv": "swe", + # Swahili + "sw": "swh", + # Tamil + "ta": "tam", + # Telugu + "te": "tel", + # Tajik + "tg": "tgk", + # Tagalog + "tl": "tgl", + # Thai + "th": "tha", + # Turkish + "tr": "tur", + # Ukrainian + "uk": "ukr", + # Urdu + "ur": "urd", + # Northern Uzbek + "uz": "uzn", + # Vietnamese + "vi": "vie", + # Yoruba + "yo": "yor", + # Cantonese + # (No separate 2-letter code) + # Zulu + "zu": "zul", } return seamless_lang_code.get(lang_code, "eng") @method() - def translate_text( - self, - text: str, - source_language: str, - target_language: str - ): + def translate_text(self, text: str, source_language: str, target_language: str): with self.lock: translated_text, _, _ = self.translator.predict( text, "t2tt", src_lang=self.get_seamless_lang_code(source_language), tgt_lang=self.get_seamless_lang_code(target_language), - ngram_filtering=True + ngram_filtering=True, ) - return { - "text": { - source_language: text, - target_language: str(translated_text) - } - } + return {"text": {source_language: text, target_language: str(translated_text)}} + + # ------------------------------------------------------------------- # Web API # ------------------------------------------------------------------- @@ -222,9 +409,9 @@ def web(): @app.post("/translate", dependencies=[Depends(apikey_auth)]) def translate( - text: str, - source_language: Annotated[str, Body(...)] = "en", - target_language: Annotated[str, Body(...)] = "fr", + text: str, + source_language: Annotated[str, Body(...)] = "en", + target_language: Annotated[str, Body(...)] = "fr", ) -> TranslateResponse: func = translatorstub.translate_text.spawn( text=text,