mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 12:19:06 +00:00
Update all modal deployments and change seamless configuration due to changes in src repo (#353)
* update all modal deployments and change seamless configuration due to change in src repo * add fixture
This commit is contained in:
@@ -55,15 +55,15 @@ def configure_seamless_m4t():
|
||||
|
||||
import yaml
|
||||
|
||||
ASSETS_DIR: str = "./seamless_communication/src/seamless_communication/assets/cards"
|
||||
CARDS_DIR: str = "./seamless_communication/src/seamless_communication/cards"
|
||||
|
||||
with open(f"{ASSETS_DIR}/seamlessM4T_{SEAMLESSM4T_MODEL_SIZE}.yaml", "r") as file:
|
||||
with open(f"{CARDS_DIR}/seamlessM4T_{SEAMLESSM4T_MODEL_SIZE}.yaml", "r") as file:
|
||||
model_yaml_data = yaml.load(file, Loader=yaml.FullLoader)
|
||||
with open(f"{ASSETS_DIR}/vocoder_36langs.yaml", "r") as file:
|
||||
with open(f"{CARDS_DIR}/vocoder_36langs.yaml", "r") as file:
|
||||
vocoder_yaml_data = yaml.load(file, Loader=yaml.FullLoader)
|
||||
with open(f"{ASSETS_DIR}/unity_nllb-100.yaml", "r") as file:
|
||||
with open(f"{CARDS_DIR}/unity_nllb-100.yaml", "r") as file:
|
||||
unity_100_yaml_data = yaml.load(file, Loader=yaml.FullLoader)
|
||||
with open(f"{ASSETS_DIR}/unity_nllb-200.yaml", "r") as file:
|
||||
with open(f"{CARDS_DIR}/unity_nllb-200.yaml", "r") as file:
|
||||
unity_200_yaml_data = yaml.load(file, Loader=yaml.FullLoader)
|
||||
|
||||
model_dir = f"{SEAMLESS_MODEL_DIR}/models--facebook--seamless-m4t-{SEAMLESSM4T_MODEL_SIZE}/snapshots"
|
||||
@@ -87,18 +87,18 @@ def configure_seamless_m4t():
|
||||
os.getcwd(), model_dir, latest_model_version, tokenizer_name
|
||||
)
|
||||
|
||||
model_yaml_data["checkpoint"] = f"file:/{model_path}"
|
||||
vocoder_yaml_data["checkpoint"] = f"file:/{vocoder_path}"
|
||||
unity_100_yaml_data["tokenizer"] = f"file:/{tokenizer_path}"
|
||||
unity_200_yaml_data["tokenizer"] = f"file:/{tokenizer_path}"
|
||||
model_yaml_data["checkpoint"] = f"file://{model_path}"
|
||||
vocoder_yaml_data["checkpoint"] = f"file://{vocoder_path}"
|
||||
unity_100_yaml_data["tokenizer"] = f"file://{tokenizer_path}"
|
||||
unity_200_yaml_data["tokenizer"] = f"file://{tokenizer_path}"
|
||||
|
||||
with open(f"{ASSETS_DIR}/seamlessM4T_{SEAMLESSM4T_MODEL_SIZE}.yaml", "w") as file:
|
||||
with open(f"{CARDS_DIR}/seamlessM4T_{SEAMLESSM4T_MODEL_SIZE}.yaml", "w") as file:
|
||||
yaml.dump(model_yaml_data, file)
|
||||
with open(f"{ASSETS_DIR}/vocoder_36langs.yaml", "w") as file:
|
||||
with open(f"{CARDS_DIR}/vocoder_36langs.yaml", "w") as file:
|
||||
yaml.dump(vocoder_yaml_data, file)
|
||||
with open(f"{ASSETS_DIR}/unity_nllb-100.yaml", "w") as file:
|
||||
with open(f"{CARDS_DIR}/unity_nllb-100.yaml", "w") as file:
|
||||
yaml.dump(unity_100_yaml_data, file)
|
||||
with open(f"{ASSETS_DIR}/unity_nllb-200.yaml", "w") as file:
|
||||
with open(f"{CARDS_DIR}/unity_nllb-200.yaml", "w") as file:
|
||||
yaml.dump(unity_200_yaml_data, file)
|
||||
|
||||
|
||||
@@ -144,7 +144,7 @@ transcriber_image = (
|
||||
class Translator:
|
||||
def __enter__(self):
|
||||
import torch
|
||||
from seamless_communication.models.inference.translator import Translator
|
||||
from seamless_communication.inference.translator import Translator
|
||||
|
||||
self.lock = threading.Lock()
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
@@ -363,14 +363,15 @@ class Translator:
|
||||
@method()
|
||||
def translate_text(self, text: str, source_language: str, target_language: str):
|
||||
with self.lock:
|
||||
translated_text, _, _ = self.translator.predict(
|
||||
translation_result, _ = self.translator.predict(
|
||||
text,
|
||||
"t2tt",
|
||||
src_lang=self.get_seamless_lang_code(source_language),
|
||||
tgt_lang=self.get_seamless_lang_code(target_language),
|
||||
ngram_filtering=True,
|
||||
unit_generation_ngram_filtering=True,
|
||||
)
|
||||
return {"text": {source_language: text, target_language: str(translated_text)}}
|
||||
translated_text = str(translation_result[0])
|
||||
return {"text": {source_language: text, target_language: translated_text}}
|
||||
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
@@ -70,6 +70,7 @@ async def test_transcript_rtc_and_websocket(
|
||||
dummy_storage,
|
||||
fake_mp3_upload,
|
||||
ensure_casing,
|
||||
nltk,
|
||||
appserver,
|
||||
sentence_tokenize,
|
||||
):
|
||||
@@ -227,6 +228,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
|
||||
dummy_storage,
|
||||
fake_mp3_upload,
|
||||
ensure_casing,
|
||||
nltk,
|
||||
appserver,
|
||||
sentence_tokenize,
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user