diff --git a/server/gpu/modal/reflector_transcriber.py b/server/gpu/modal/reflector_transcriber.py
index 84b24bb7..8641f12f 100644
--- a/server/gpu/modal/reflector_transcriber.py
+++ b/server/gpu/modal/reflector_transcriber.py
@@ -18,6 +18,7 @@ MODEL_DIR = "/model"

 # Translation Model
 TRANSLATION_MODEL = "facebook/m2m100_418M"
+TRANSLATION_MODEL_DIR = "translation"

 stub = Stub(name="reflector-transtest")

@@ -30,11 +31,6 @@ def download_models():
     download_model(WHISPER_MODEL)
     print("Whisper model downloaded")

-    print("Downloading Translation model")
-    ignore_patterns = ["*.ot"]
-    snapshot_download(TRANSLATION_MODEL, cache_dir=MODEL_DIR, ignore_patterns=ignore_patterns)
-    print("Translation model downloaded")
-
 def migrate_cache_llm():
     """
     XXX The cache for model files in Transformers v4.22.0 has been updated.
@@ -48,6 +44,14 @@ def migrate_cache_llm():
     move_cache()
     print("LLM cache moved")

+def download_translation_model():
+    from huggingface_hub import snapshot_download
+
+    print("Downloading Translation model")
+    ignore_patterns = ["*.ot"]
+    snapshot_download(TRANSLATION_MODEL, cache_dir=MODEL_DIR, ignore_patterns=ignore_patterns)
+    print("Translation model downloaded")
+
 whisper_image = (
     Image.debian_slim(python_version="3.10.8")
     .apt_install("git")
@@ -58,6 +62,7 @@ whisper_image = (
         "transformers",
         "sentencepiece",
         "protobuf",
+        "huggingface_hub==0.16.4",
     )
     .run_function(download_models)
     .run_function(migrate_cache_llm)
@@ -92,13 +97,9 @@ class Whisper:
             num_workers=WHISPER_NUM_WORKERS,
         )
         self.translation_model = M2M100ForConditionalGeneration.from_pretrained(
-            TRANSLATION_MODEL,
-            cache_dir=TRANSCRIPTION_MODEL_DIR
+            TRANSLATION_MODEL_DIR
         ).to(self.device)
-        self.translation_tokenizer = M2M100Tokenizer.from_pretrained(
-            TRANSLATION_MODEL,
-            cache_dir=TRANSCRIPTION_MODEL_DIR
-        )
+        self.translation_tokenizer = M2M100Tokenizer.from_pretrained(TRANSLATION_MODEL)


     @method()