Mirror of https://github.com/Monadical-SAS/reflector.git (synced 2025-12-20 20:29:06 +00:00)
update to use cache dir
@@ -14,32 +14,36 @@ WHISPER_MODEL: str = "large-v2"
 WHISPER_COMPUTE_TYPE: str = "float16"
 WHISPER_NUM_WORKERS: int = 1
 
-MODEL_DIR = "/model"
-
 # Translation Model
 TRANSLATION_MODEL = "facebook/m2m100_418M"
-TRANSLATION_MODEL_DIR = "translation"
+
+MODEL_DIR = "model"
 
 stub = Stub(name="reflector-transtest")
 
 
-def download_whisper():
+def download_whisper(cache_dir: str = None):
     from faster_whisper.utils import download_model
 
     print("Downloading Whisper model")
-    download_model(WHISPER_MODEL, cache_dir=MODEL_DIR)
+    download_model(WHISPER_MODEL, cache_dir=cache_dir)
     print("Whisper model downloaded")
 
 
-def download_translation_model():
+def download_translation_model(cache_dir: str = None):
     from huggingface_hub import snapshot_download
 
     print("Downloading Translation model")
     ignore_patterns = ["*.ot"]
-    snapshot_download(TRANSLATION_MODEL, local_dir=MODEL_DIR, ignore_patterns=ignore_patterns)
+    snapshot_download(TRANSLATION_MODEL, cache_dir=cache_dir, ignore_patterns=ignore_patterns)
     print("Translation model downloaded")
 
 
+def download_models():
+    download_whisper(cache_dir=MODEL_DIR)
+    download_translation_model(cache_dir=MODEL_DIR)
+
+
 def migrate_cache_llm():
     """
     XXX The cache for model files in Transformers v4.22.0 has been updated.
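The first hunk replaces the hard-coded /model path with a cache_dir parameter on both download helpers, and download_models() supplies MODEL_DIR as the shared location. The sketch below is not part of the commit: it is a minimal standalone script, with the constants copied from the diff and a local download_models stand-in, showing how the same two download calls can be pointed at an explicit cache directory and sanity-checked.

# Not part of the commit: standalone sketch exercising the same download calls
# as the refactored helpers, with the cache directory passed in explicitly.
import os

# Constants mirrored from the diff above.
WHISPER_MODEL = "large-v2"
TRANSLATION_MODEL = "facebook/m2m100_418M"
MODEL_DIR = "model"


def download_models(cache_dir: str = MODEL_DIR) -> None:
    from faster_whisper.utils import download_model
    from huggingface_hub import snapshot_download

    # Whisper weights are fetched into the shared cache directory.
    download_model(WHISPER_MODEL, cache_dir=cache_dir)
    # Translation weights go into the same directory; *.ot files are skipped,
    # matching the ignore_patterns used in the diff.
    snapshot_download(TRANSLATION_MODEL, cache_dir=cache_dir, ignore_patterns=["*.ot"])


if __name__ == "__main__":
    download_models()
    # Quick sanity check that the cache directory was populated.
    print(sorted(os.listdir(MODEL_DIR)))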
@@ -99,10 +103,13 @@ class Whisper:
             num_workers=WHISPER_NUM_WORKERS,
         )
         self.translation_model = M2M100ForConditionalGeneration.from_pretrained(
-            TRANSLATION_MODEL_DIR
+            TRANSLATION_MODEL,
+            cache_dir=MODEL_DIR
         ).to(self.device)
-        self.translation_tokenizer = M2M100Tokenizer.from_pretrained(TRANSLATION_MODEL)
+        self.translation_tokenizer = M2M100Tokenizer.from_pretrained(
+            TRANSLATION_MODEL,
+            cache_dir=MODEL_DIR
+        )
 
     @method()
     def warmup(self):
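The second hunk applies the same idea on the loading side: both from_pretrained calls now resolve the translation model and tokenizer against MODEL_DIR instead of a separately downloaded local directory. Below is a minimal standalone sketch, not taken from the repository: the constants mirror the diff, and the translation call at the end is ordinary Hugging Face M2M100 usage added here only to show that the cached weights load and run.

# Not part of the commit: standalone sketch of the cache-aware loading pattern
# from the hunk above, plus a standard M2M100 translation call.
import torch
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

TRANSLATION_MODEL = "facebook/m2m100_418M"  # mirrored from the diff
MODEL_DIR = "model"                         # mirrored from the diff

device = "cuda" if torch.cuda.is_available() else "cpu"

# Model weights and tokenizer files are both resolved against the same cache
# directory, matching the two from_pretrained calls in the hunk above.
model = M2M100ForConditionalGeneration.from_pretrained(
    TRANSLATION_MODEL, cache_dir=MODEL_DIR
).to(device)
tokenizer = M2M100Tokenizer.from_pretrained(TRANSLATION_MODEL, cache_dir=MODEL_DIR)

# Standard M2M100 usage (not from the commit): translate English to French.
tokenizer.src_lang = "en"
encoded = tokenizer("The meeting starts at noon.", return_tensors="pt").to(device)
generated = model.generate(**encoded, forced_bos_token_id=tokenizer.get_lang_id("fr"))
print(tokenizer.batch_decode(generated, skip_special_tokens=True))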