change model download

This commit is contained in:
Gokul Mohanarangan
2023-08-30 13:00:42 +05:30
parent 012390d0aa
commit 61e24969e4

View File

@@ -18,6 +18,7 @@ MODEL_DIR = "/model"
# Translation Model
TRANSLATION_MODEL = "facebook/m2m100_418M"
TRANSLATION_MODEL_DIR = "translation"
stub = Stub(name="reflector-transtest")
@@ -30,11 +31,6 @@ def download_models():
download_model(WHISPER_MODEL)
print("Whisper model downloaded")
print("Downloading Translation model")
ignore_patterns = ["*.ot"]
snapshot_download(TRANSLATION_MODEL, cache_dir=MODEL_DIR, ignore_patterns=ignore_patterns)
print("Translation model downloaded")
def migrate_cache_llm():
"""
XXX The cache for model files in Transformers v4.22.0 has been updated.
@@ -48,6 +44,14 @@ def migrate_cache_llm():
move_cache()
print("LLM cache moved")
def download_translation_model():
    """Download the translation model snapshot from the Hugging Face Hub.

    Fetches the repository named by ``TRANSLATION_MODEL`` with
    ``snapshot_download``, passing ``MODEL_DIR`` as the cache directory and
    ``"*.ot"`` as the only ignore pattern. Progress is reported on stdout
    before and after the download.
    """
    # Function-scope import, kept exactly where the original placed it.
    from huggingface_hub import snapshot_download

    print("Downloading Translation model")
    snapshot_download(
        TRANSLATION_MODEL,
        cache_dir=MODEL_DIR,
        ignore_patterns=["*.ot"],
    )
    print("Translation model downloaded")
whisper_image = (
Image.debian_slim(python_version="3.10.8")
.apt_install("git")
@@ -58,6 +62,7 @@ whisper_image = (
"transformers",
"sentencepiece",
"protobuf",
"huggingface_hub==0.16.4",
)
.run_function(download_models)
.run_function(migrate_cache_llm)
@@ -92,13 +97,9 @@ class Whisper:
num_workers=WHISPER_NUM_WORKERS,
)
self.translation_model = M2M100ForConditionalGeneration.from_pretrained(
TRANSLATION_MODEL,
cache_dir=TRANSCRIPTION_MODEL_DIR
TRANSLATION_MODEL_DIR
).to(self.device)
self.translation_tokenizer = M2M100Tokenizer.from_pretrained(
TRANSLATION_MODEL,
cache_dir=TRANSCRIPTION_MODEL_DIR
)
self.translation_tokenizer = M2M100Tokenizer.from_pretrained(TRANSLATION_MODEL)
@method()