diff --git a/server/gpu/modal/reflector_transcriber.py b/server/gpu/modal/reflector_transcriber.py
index 8641f12f..996a73d7 100644
--- a/server/gpu/modal/reflector_transcriber.py
+++ b/server/gpu/modal/reflector_transcriber.py
@@ -23,14 +23,23 @@ TRANSLATION_MODEL_DIR = "translation"
 
 stub = Stub(name="reflector-transtest")
 
-def download_models():
+def download_whisper():
     from faster_whisper.utils import download_model
-    from huggingface_hub import snapshot_download
 
     print("Downloading Whisper model")
-    download_model(WHISPER_MODEL)
+    download_model(WHISPER_MODEL, cache_dir=MODEL_DIR)
     print("Whisper model downloaded")
 
+
+def download_translation_model():
+    from huggingface_hub import snapshot_download
+
+    print("Downloading Translation model")
+    ignore_patterns = ["*.ot"]
+    snapshot_download(TRANSLATION_MODEL, local_dir=MODEL_DIR, ignore_patterns=ignore_patterns)
+    print("Translation model downloaded")
+
+
 def migrate_cache_llm():
     """
     XXX The cache for model files in Transformers v4.22.0 has been updated.
@@ -44,13 +53,6 @@ def migrate_cache_llm():
     move_cache()
     print("LLM cache moved")
 
-def download_translation_model():
-    from huggingface_hub import snapshot_download
-
-    print("Downloading Translation model")
-    ignore_patterns = ["*.ot"]
-    snapshot_download(TRANSLATION_MODEL, cache_dir=MODEL_DIR, ignore_patterns=ignore_patterns)
-    print("Translation model downloaded")
-
+
 whisper_image = (
     Image.debian_slim(python_version="3.10.8")