update to use cache dir

Gokul Mohanarangan
2023-09-05 14:28:48 +05:30
parent 6b84bbb4f6
commit e613157fd6


@@ -14,32 +14,36 @@ WHISPER_MODEL: str = "large-v2"
 WHISPER_COMPUTE_TYPE: str = "float16"
 WHISPER_NUM_WORKERS: int = 1
-MODEL_DIR = "/model"
 # Translation Model
 TRANSLATION_MODEL = "facebook/m2m100_418M"
-TRANSLATION_MODEL_DIR = "translation"
+MODEL_DIR = "model"
 stub = Stub(name="reflector-transtest")
-def download_whisper():
+def download_whisper(cache_dir: str = None):
     from faster_whisper.utils import download_model
     print("Downloading Whisper model")
-    download_model(WHISPER_MODEL, cache_dir=MODEL_DIR)
+    download_model(WHISPER_MODEL, cache_dir=cache_dir)
     print("Whisper model downloaded")
-def download_translation_model():
+def download_translation_model(cache_dir: str = None):
     from huggingface_hub import snapshot_download
     print("Downloading Translation model")
     ignore_patterns = ["*.ot"]
-    snapshot_download(TRANSLATION_MODEL, local_dir=MODEL_DIR, ignore_patterns=ignore_patterns)
+    snapshot_download(TRANSLATION_MODEL, cache_dir=cache_dir, ignore_patterns=ignore_patterns)
     print("Translation model downloaded")
+def download_models():
+    download_whisper(cache_dir=MODEL_DIR)
+    download_translation_model(cache_dir=MODEL_DIR)
 def migrate_cache_llm():
     """
     XXX The cache for model files in Transformers v4.22.0 has been updated.
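
The new download_models() entry point ties the refactor together: both download helpers now take cache_dir as a parameter, and download_models() points them at the shared MODEL_DIR. The image build that would invoke it is not part of this hunk; below is a minimal sketch, assuming Modal's Image.run_function hook and a guessed package list (neither is taken from this diff):

from modal import Image

# Hypothetical wiring (not shown in this commit): run download_models once at
# image-build time so the Whisper and M2M100 weights are baked into the image
# under MODEL_DIR instead of being fetched on every cold start.
image = (
    Image.debian_slim()
    .pip_install("faster-whisper", "transformers", "huggingface_hub", "torch")
    .run_function(download_models)
)
# The image would then be attached where the stub defines its functions or
# classes, for example @stub.cls(image=image, gpu="A10G").
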
@@ -99,10 +103,13 @@ class Whisper:
             num_workers=WHISPER_NUM_WORKERS,
         )
         self.translation_model = M2M100ForConditionalGeneration.from_pretrained(
-            TRANSLATION_MODEL_DIR
+            TRANSLATION_MODEL,
+            cache_dir=MODEL_DIR
         ).to(self.device)
-        self.translation_tokenizer = M2M100Tokenizer.from_pretrained(TRANSLATION_MODEL)
+        self.translation_tokenizer = M2M100Tokenizer.from_pretrained(
+            TRANSLATION_MODEL,
+            cache_dir=MODEL_DIR
+        )
     @method()
     def warmup(self):
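
Passing cache_dir=MODEL_DIR on the loading side lines up with the download side because snapshot_download() and transformers' from_pretrained() share the Hugging Face hub cache layout (the Transformers v4.22 change referenced in migrate_cache_llm), so weights fetched at build time are reused rather than downloaded again. A minimal standalone sketch of the loading path with the Modal class stripped away (constants are copied from the diff; the device handling is an assumption):

import torch
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

MODEL_DIR = "model"
TRANSLATION_MODEL = "facebook/m2m100_418M"

# Resolve the same hub-style cache that snapshot_download() populated; the
# weights are re-downloaded only if the snapshot is missing from MODEL_DIR.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = M2M100ForConditionalGeneration.from_pretrained(
    TRANSLATION_MODEL, cache_dir=MODEL_DIR
).to(device)
tokenizer = M2M100Tokenizer.from_pretrained(TRANSLATION_MODEL, cache_dir=MODEL_DIR)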