commit 012390d0aa
parent df078f7bd6
Author: Gokul Mohanarangan
Date:   2023-08-30 10:43:51 +05:30

2 changed files with 36 additions and 8 deletions


@@ -38,7 +38,7 @@ def migrate_cache_llm():
     from transformers.utils.hub import move_cache
     print("Moving LLM cache")
-    move_cache()
+    move_cache(cache_dir=IMAGE_MODEL_DIR)
     print("LLM cache moved")


@@ -13,19 +13,40 @@ from pydantic import BaseModel
 WHISPER_MODEL: str = "large-v2"
 WHISPER_COMPUTE_TYPE: str = "float16"
 WHISPER_NUM_WORKERS: int = 1
 WHISPER_CACHE_DIR: str = "/cache/whisper"
+MODEL_DIR = "/model"
+
+# Translation Model
+TRANSLATION_MODEL = "facebook/m2m100_418M"
 
-stub = Stub(name="reflector-transcriber")
+stub = Stub(name="reflector-transtest")
 
 
-def download_whisper():
+def download_models():
     from faster_whisper.utils import download_model
+    from huggingface_hub import snapshot_download
 
-    download_model(WHISPER_MODEL, local_files_only=False)
+    print("Downloading Whisper model")
+    download_model(WHISPER_MODEL)
+    print("Whisper model downloaded")
+
+    print("Downloading Translation model")
+    ignore_patterns = ["*.ot"]
+    snapshot_download(TRANSLATION_MODEL, cache_dir=MODEL_DIR, ignore_patterns=ignore_patterns)
+    print("Translation model downloaded")
+
+
+def migrate_cache_llm():
+    """
+    XXX The cache for model files in Transformers v4.22.0 has been updated.
+    Migrating your old cache. This is a one-time only operation. You can
+    interrupt this and resume the migration later on by calling
+    `transformers.utils.move_cache()`.
+    """
+    from transformers.utils.hub import move_cache
+
+    print("Moving LLM cache")
+    move_cache()
+    print("LLM cache moved")
 
 
 whisper_image = (
     Image.debian_slim(python_version="3.10.8")
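The renamed `download_models` now fetches both models in one build step; `ignore_patterns=["*.ot"]` tells `snapshot_download` to skip the legacy Rust weight files (`rust_model.ot`) so they don't bloat the image. A minimal sketch of the same downloads run outside Modal, reusing the constants from the hunk above:

# Sketch: pre-fetch both models, as download_models does at image-build time.
from faster_whisper.utils import download_model
from huggingface_hub import snapshot_download

WHISPER_MODEL = "large-v2"
TRANSLATION_MODEL = "facebook/m2m100_418M"
MODEL_DIR = "/model"

download_model(WHISPER_MODEL)  # CTranslate2 weights for faster-whisper
snapshot_download(
    TRANSLATION_MODEL,
    cache_dir=MODEL_DIR,
    ignore_patterns=["*.ot"],  # skip rust_model.ot to keep the snapshot small
)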
@@ -38,7 +59,8 @@ whisper_image = (
         "sentencepiece",
         "protobuf",
     )
-    .run_function(download_whisper)
+    .run_function(download_models)
+    .run_function(migrate_cache_llm)
    .env(
        {
            "LD_LIBRARY_PATH": (
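Chaining both functions with `.run_function(...)` means Modal executes them at image build time, so the weights are baked into the image once instead of being fetched on every container cold start; ordering also matters, since `migrate_cache_llm` reshuffles files that `download_models` has already written. A condensed sketch of the pattern, assuming the functions from the hunk above and an illustrative, abbreviated package list:

# Hedged sketch of the build chain; download_models and migrate_cache_llm
# are the functions defined in the previous hunk.
from modal import Image, Stub

stub = Stub(name="reflector-transtest")

whisper_image = (
    Image.debian_slim(python_version="3.10.8")
    .pip_install("faster-whisper", "transformers", "sentencepiece", "protobuf")
    .run_function(download_models)     # runs at build time; weights baked in
    .run_function(migrate_cache_llm)   # one-time cache-layout migration
)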
@@ -69,8 +91,14 @@ class Whisper:
             compute_type=WHISPER_COMPUTE_TYPE,
             num_workers=WHISPER_NUM_WORKERS,
         )
-        self.translation_model = M2M100ForConditionalGeneration.from_pretrained(TRANSLATION_MODEL).to(self.device)
-        self.translation_tokenizer = M2M100Tokenizer.from_pretrained(TRANSLATION_MODEL)
+        self.translation_model = M2M100ForConditionalGeneration.from_pretrained(
+            TRANSLATION_MODEL,
+            cache_dir=TRANSCRIPTION_MODEL_DIR
+        ).to(self.device)
+        self.translation_tokenizer = M2M100Tokenizer.from_pretrained(
+            TRANSLATION_MODEL,
+            cache_dir=TRANSCRIPTION_MODEL_DIR
+        )
 
     @method()
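Note the constructor caches under `TRANSCRIPTION_MODEL_DIR`, a constant presumably defined elsewhere in the file. For reference, a self-contained sketch of how an M2M100 model/tokenizer pair like the one loaded above is typically driven for a single translation (language codes and sample text are illustrative, not from this commit):

# Hedged usage sketch of the M2M100 translation pair; "fr" -> "en" and the
# input sentence are illustrative choices, not part of the commit.
import torch
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

TRANSLATION_MODEL = "facebook/m2m100_418M"
device = "cuda" if torch.cuda.is_available() else "cpu"

model = M2M100ForConditionalGeneration.from_pretrained(TRANSLATION_MODEL).to(device)
tokenizer = M2M100Tokenizer.from_pretrained(TRANSLATION_MODEL)

tokenizer.src_lang = "fr"  # source language of the input text
encoded = tokenizer("Bonjour le monde", return_tensors="pt").to(device)
generated = model.generate(
    **encoded,
    forced_bos_token_id=tokenizer.get_lang_id("en"),  # force English output
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))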