diff --git a/server/gpu/modal/reflector_llm.py b/server/gpu/modal/reflector_llm.py index 1a3f77d6..9e20ff00 100644 --- a/server/gpu/modal/reflector_llm.py +++ b/server/gpu/modal/reflector_llm.py @@ -7,6 +7,7 @@ import json import os from typing import Optional +import modal from modal import Image, Secret, Stub, asgi_app, method # LLM @@ -15,7 +16,7 @@ LLM_LOW_CPU_MEM_USAGE: bool = True LLM_TORCH_DTYPE: str = "bfloat16" LLM_MAX_NEW_TOKENS: int = 300 -IMAGE_MODEL_DIR = "/model" +IMAGE_MODEL_DIR = "/root/llm_models" stub = Stub(name="reflector-llm") @@ -24,7 +25,7 @@ def download_llm(): from huggingface_hub import snapshot_download print("Downloading LLM model") - snapshot_download(LLM_MODEL, local_dir=IMAGE_MODEL_DIR) + snapshot_download(LLM_MODEL, cache_dir=IMAGE_MODEL_DIR) print("LLM model downloaded") @@ -38,7 +39,7 @@ def migrate_cache_llm(): from transformers.utils.hub import move_cache print("Moving LLM cache") - move_cache() + move_cache(cache_dir=IMAGE_MODEL_DIR, new_cache_dir=IMAGE_MODEL_DIR) print("LLM cache moved") @@ -77,9 +78,10 @@ class LLM: print("Instance llm model") model = AutoModelForCausalLM.from_pretrained( - IMAGE_MODEL_DIR, + LLM_MODEL, torch_dtype=getattr(torch, LLM_TORCH_DTYPE), low_cpu_mem_usage=LLM_LOW_CPU_MEM_USAGE, + cache_dir=IMAGE_MODEL_DIR ) # generation configuration @@ -91,7 +93,10 @@ class LLM: # load tokenizer print("Instance llm tokenizer") - tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL) + tokenizer = AutoTokenizer.from_pretrained( + LLM_MODEL, + cache_dir=IMAGE_MODEL_DIR + ) # move model to gpu print("Move llm model to GPU") diff --git a/server/gpu/modal/reflector_transcriber.py b/server/gpu/modal/reflector_transcriber.py index f06706c8..ff4caff7 100644 --- a/server/gpu/modal/reflector_transcriber.py +++ b/server/gpu/modal/reflector_transcriber.py @@ -6,6 +6,7 @@ Reflector GPU backend - transcriber import os import tempfile +import modal from modal import Image, Secret, Stub, asgi_app, method from pydantic import BaseModel @@ -13,18 +14,55 @@ from pydantic import BaseModel WHISPER_MODEL: str = "large-v2" WHISPER_COMPUTE_TYPE: str = "float16" WHISPER_NUM_WORKERS: int = 1 -WHISPER_CACHE_DIR: str = "/cache/whisper" # Translation Model TRANSLATION_MODEL = "facebook/m2m100_418M" +IMAGE_MODEL_DIR = "/root/transcription_models" + stub = Stub(name="reflector-transcriber") -def download_whisper(): +def download_whisper(cache_dir: str | None = None): from faster_whisper.utils import download_model - download_model(WHISPER_MODEL, local_files_only=False) + print("Downloading Whisper model") + download_model(WHISPER_MODEL, cache_dir=cache_dir) + print("Whisper model downloaded") + + +def download_translation_model(cache_dir: str | None = None): + from huggingface_hub import snapshot_download + + print("Downloading Translation model") + ignore_patterns = ["*.ot"] + snapshot_download( + TRANSLATION_MODEL, + cache_dir=cache_dir, + ignore_patterns=ignore_patterns + ) + print("Translation model downloaded") + + +def download_models(): + print(f"Downloading models to {IMAGE_MODEL_DIR=}") + download_whisper(cache_dir=IMAGE_MODEL_DIR) + download_translation_model(cache_dir=IMAGE_MODEL_DIR) + print(f"Model downloads complete.") + + +def migrate_cache_llm(): + """ + XXX The cache for model files in Transformers v4.22.0 has been updated. + Migrating your old cache. This is a one-time only operation. You can + interrupt this and resume the migration later on by calling + `transformers.utils.move_cache()`. + """ + from transformers.utils.hub import move_cache + + print("Moving LLM cache") + move_cache(cache_dir=IMAGE_MODEL_DIR, new_cache_dir=IMAGE_MODEL_DIR) + print("LLM cache moved") whisper_image = ( @@ -37,8 +75,10 @@ whisper_image = ( "transformers", "sentencepiece", "protobuf", + "huggingface_hub==0.16.4", ) - .run_function(download_whisper) + .run_function(download_models) + .run_function(migrate_cache_llm) .env( { "LD_LIBRARY_PATH": ( @@ -68,10 +108,16 @@ class Whisper: device=self.device, compute_type=WHISPER_COMPUTE_TYPE, num_workers=WHISPER_NUM_WORKERS, + download_root=IMAGE_MODEL_DIR + ) + self.translation_model = M2M100ForConditionalGeneration.from_pretrained( + TRANSLATION_MODEL, + cache_dir=IMAGE_MODEL_DIR + ).to(self.device) + self.translation_tokenizer = M2M100Tokenizer.from_pretrained( + TRANSLATION_MODEL, + cache_dir=IMAGE_MODEL_DIR ) - self.translation_model = M2M100ForConditionalGeneration.from_pretrained(TRANSLATION_MODEL).to(self.device) - self.translation_tokenizer = M2M100Tokenizer.from_pretrained(TRANSLATION_MODEL) - @method() def warmup(self):