diff --git a/server/gpu/modal/reflector_transcriber.py b/server/gpu/modal/reflector_transcriber.py
index 996a73d7..e97a90d4 100644
--- a/server/gpu/modal/reflector_transcriber.py
+++ b/server/gpu/modal/reflector_transcriber.py
@@ -14,32 +14,36 @@
 WHISPER_MODEL: str = "large-v2"
 WHISPER_COMPUTE_TYPE: str = "float16"
 WHISPER_NUM_WORKERS: int = 1
-MODEL_DIR = "/model"
-
 # Translation Model
 TRANSLATION_MODEL = "facebook/m2m100_418M"
-TRANSLATION_MODEL_DIR = "translation"
+
+MODEL_DIR = "model"
 
 stub = Stub(name="reflector-transtest")
 
 
-def download_whisper():
+def download_whisper(cache_dir: str = None):
     from faster_whisper.utils import download_model
 
     print("Downloading Whisper model")
-    download_model(WHISPER_MODEL, cache_dir=MODEL_DIR)
+    download_model(WHISPER_MODEL, cache_dir=cache_dir)
     print("Whisper model downloaded")
 
 
-def download_translation_model():
+def download_translation_model(cache_dir: str = None):
     from huggingface_hub import snapshot_download
 
     print("Downloading Translation model")
     ignore_patterns = ["*.ot"]
-    snapshot_download(TRANSLATION_MODEL, local_dir=MODEL_DIR, ignore_patterns=ignore_patterns)
+    snapshot_download(TRANSLATION_MODEL, cache_dir=cache_dir, ignore_patterns=ignore_patterns)
     print("Translation model downloaded")
 
 
+def download_models():
+    download_whisper(cache_dir=MODEL_DIR)
+    download_translation_model(cache_dir=MODEL_DIR)
+
+
 def migrate_cache_llm():
     """
     XXX The cache for model files in Transformers v4.22.0 has been updated.
@@ -99,10 +103,13 @@ class Whisper:
             num_workers=WHISPER_NUM_WORKERS,
         )
         self.translation_model = M2M100ForConditionalGeneration.from_pretrained(
-            TRANSLATION_MODEL_DIR
+            TRANSLATION_MODEL,
+            cache_dir=MODEL_DIR
         ).to(self.device)
-        self.translation_tokenizer = M2M100Tokenizer.from_pretrained(TRANSLATION_MODEL)
-
+        self.translation_tokenizer = M2M100Tokenizer.from_pretrained(
+            TRANSLATION_MODEL,
+            cache_dir=MODEL_DIR
+        )
 
     @method()
     def warmup(self):
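
Note on the local_dir -> cache_dir switch: snapshot_download(..., cache_dir=...) writes files into the standard Hugging Face cache layout (models--facebook--m2m100_418M/snapshots/<revision>/...) instead of copying them flat into the target directory, which is what lets from_pretrained(TRANSLATION_MODEL, cache_dir=MODEL_DIR) resolve the model by repo id at load time without a second download. A minimal sketch of the resulting flow outside Modal; the standalone script framing is illustrative and not part of this PR:

    # Sketch: pre-download the translation model into a shared cache dir,
    # then load model and tokenizer from that same cache by repo id.
    from huggingface_hub import snapshot_download
    from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

    MODEL_DIR = "model"
    TRANSLATION_MODEL = "facebook/m2m100_418M"

    # With cache_dir (unlike local_dir), files land in the HF cache layout:
    # model/models--facebook--m2m100_418M/snapshots/<revision>/...
    snapshot_download(TRANSLATION_MODEL, cache_dir=MODEL_DIR, ignore_patterns=["*.ot"])

    # from_pretrained looks up the repo id in the same cache_dir, so no
    # network access is needed once the snapshot above has run.
    model = M2M100ForConditionalGeneration.from_pretrained(
        TRANSLATION_MODEL, cache_dir=MODEL_DIR
    )
    tokenizer = M2M100Tokenizer.from_pretrained(TRANSLATION_MODEL, cache_dir=MODEL_DIR)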