persistent model storage

This commit is contained in:
Gokul Mohanarangan
2023-09-08 00:22:38 +05:30
parent e613157fd6
commit 2bed312e64
2 changed files with 32 additions and 31 deletions

View File

@@ -6,6 +6,7 @@ Reflector GPU backend - transcriber
import os
import tempfile
import modal
from modal import Image, Secret, Stub, asgi_app, method
from pydantic import BaseModel
@@ -17,12 +18,13 @@ WHISPER_NUM_WORKERS: int = 1
# Translation Model
TRANSLATION_MODEL = "facebook/m2m100_418M"
MODEL_DIR = "model"
IMAGE_MODEL_DIR = "/root/transcription_models"
volume = modal.NetworkFileSystem.persisted("reflector-transcribe-models")
stub = Stub(name="reflector-transtest")
stub = Stub(name="reflector-transtest1")
def download_whisper(cache_dir: str = None):
def download_whisper(cache_dir: str | None = None):
from faster_whisper.utils import download_model
print("Downloading Whisper model")
@@ -30,32 +32,24 @@ def download_whisper(cache_dir: str = None):
print("Whisper model downloaded")
def download_translation_model(cache_dir: str = None):
def download_translation_model(cache_dir: str | None = None):
from huggingface_hub import snapshot_download
print("Downloading Translation model")
ignore_patterns = ["*.ot"]
snapshot_download(TRANSLATION_MODEL, cache_dir=cache_dir, ignore_patterns=ignore_patterns)
snapshot_download(
TRANSLATION_MODEL,
cache_dir=cache_dir,
ignore_patterns=ignore_patterns
)
print("Translation model downloaded")
def download_models():
download_whisper(cache_dir=MODEL_DIR)
download_translation_model(cache_dir=MODEL_DIR)
def migrate_cache_llm():
"""
XXX The cache for model files in Transformers v4.22.0 has been updated.
Migrating your old cache. This is a one-time only operation. You can
interrupt this and resume the migration later on by calling
`transformers.utils.move_cache()`.
"""
from transformers.utils.hub import move_cache
print("Moving LLM cache")
move_cache()
print("LLM cache moved")
print(f"Downloading models to {IMAGE_MODEL_DIR=}")
download_whisper(cache_dir=IMAGE_MODEL_DIR)
download_translation_model(cache_dir=IMAGE_MODEL_DIR)
print(f"Model downloads complete.")
whisper_image = (
@@ -71,7 +65,6 @@ whisper_image = (
"huggingface_hub==0.16.4",
)
.run_function(download_models)
.run_function(migrate_cache_llm)
.env(
{
"LD_LIBRARY_PATH": (
@@ -87,6 +80,7 @@ whisper_image = (
gpu="A10G",
container_idle_timeout=60,
image=whisper_image,
network_file_systems={IMAGE_MODEL_DIR: volume},
)
class Whisper:
def __enter__(self):
@@ -101,14 +95,15 @@ class Whisper:
device=self.device,
compute_type=WHISPER_COMPUTE_TYPE,
num_workers=WHISPER_NUM_WORKERS,
download_root=IMAGE_MODEL_DIR
)
self.translation_model = M2M100ForConditionalGeneration.from_pretrained(
TRANSLATION_MODEL,
cache_dir=MODEL_DIR
cache_dir=IMAGE_MODEL_DIR
).to(self.device)
self.translation_tokenizer = M2M100Tokenizer.from_pretrained(
TRANSLATION_MODEL,
cache_dir=MODEL_DIR
cache_dir=IMAGE_MODEL_DIR
)
@method()