Mirror of https://github.com/Monadical-SAS/reflector.git (synced 2025-12-20 20:29:06 +00:00)
Merge pull request #216 from Monadical-SAS/llm-modal
Download and load LLMs from cache
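The change, in short: model weights are fetched into a fixed cache directory while the Modal image is built, and the same directory is passed as cache_dir when the models are loaded at runtime, so containers start without re-downloading anything. A minimal sketch of that pattern follows; it is not part of the diff, and the model id and paths are placeholders rather than the repo's actual values.

# Sketch of the download-once / load-from-cache pattern used in this PR.
# CACHE_DIR and MODEL_NAME are illustrative placeholders.
from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM, AutoTokenizer

CACHE_DIR = "/root/llm_models"     # baked into the image at build time
MODEL_NAME = "facebook/opt-125m"   # placeholder model id


def download():
    # Populates the Hugging Face cache layout under CACHE_DIR.
    snapshot_download(MODEL_NAME, cache_dir=CACHE_DIR)


def load():
    # With the same cache_dir, from_pretrained resolves the files already
    # present in the image instead of downloading them again.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
    return tokenizer, model


if __name__ == "__main__":
    download()
    load()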
@@ -7,6 +7,7 @@ import json
import os
from typing import Optional

import modal
from modal import Image, Secret, Stub, asgi_app, method

# LLM
@@ -15,7 +16,7 @@ LLM_LOW_CPU_MEM_USAGE: bool = True
LLM_TORCH_DTYPE: str = "bfloat16"
LLM_MAX_NEW_TOKENS: int = 300

-IMAGE_MODEL_DIR = "/model"
+IMAGE_MODEL_DIR = "/root/llm_models"

stub = Stub(name="reflector-llm")

@@ -24,7 +25,7 @@ def download_llm():
    from huggingface_hub import snapshot_download

    print("Downloading LLM model")
-    snapshot_download(LLM_MODEL, local_dir=IMAGE_MODEL_DIR)
+    snapshot_download(LLM_MODEL, cache_dir=IMAGE_MODEL_DIR)
    print("LLM model downloaded")

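The local_dir to cache_dir switch above is the core of the fix: local_dir copies the snapshot into a flat directory that must later be loaded by path, while cache_dir fills the standard Hugging Face hub cache layout (models--org--name/snapshots/...), which from_pretrained(model_id, cache_dir=...) can resolve by model id. Roughly, as a sketch with a placeholder model id:

# Illustrative comparison of the two snapshot_download modes (placeholder id).
from huggingface_hub import snapshot_download

MODEL_ID = "facebook/opt-125m"  # stand-in; the repo's LLM_MODEL is defined elsewhere

# local_dir: plain copy of the repo files; must later be loaded by *path*.
snapshot_download(MODEL_ID, local_dir="/model")

# cache_dir: standard hub cache layout; can later be loaded by *model id*
# via from_pretrained(MODEL_ID, cache_dir="/root/llm_models").
snapshot_download(MODEL_ID, cache_dir="/root/llm_models")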
@@ -38,7 +39,7 @@ def migrate_cache_llm():
    from transformers.utils.hub import move_cache

    print("Moving LLM cache")
-    move_cache()
+    move_cache(cache_dir=IMAGE_MODEL_DIR, new_cache_dir=IMAGE_MODEL_DIR)
    print("LLM cache moved")

@@ -77,9 +78,10 @@ class LLM:

        print("Instance llm model")
        model = AutoModelForCausalLM.from_pretrained(
-            IMAGE_MODEL_DIR,
+            LLM_MODEL,
            torch_dtype=getattr(torch, LLM_TORCH_DTYPE),
            low_cpu_mem_usage=LLM_LOW_CPU_MEM_USAGE,
+            cache_dir=IMAGE_MODEL_DIR
        )

        # generation configuration
@@ -91,7 +93,10 @@ class LLM:

        # load tokenizer
        print("Instance llm tokenizer")
-        tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
+        tokenizer = AutoTokenizer.from_pretrained(
+            LLM_MODEL,
+            cache_dir=IMAGE_MODEL_DIR
+        )

        # move model to gpu
        print("Move llm model to GPU")
@@ -6,6 +6,7 @@ Reflector GPU backend - transcriber
import os
import tempfile

import modal
from modal import Image, Secret, Stub, asgi_app, method
from pydantic import BaseModel

@@ -13,18 +14,55 @@ from pydantic import BaseModel
WHISPER_MODEL: str = "large-v2"
WHISPER_COMPUTE_TYPE: str = "float16"
WHISPER_NUM_WORKERS: int = 1
-WHISPER_CACHE_DIR: str = "/cache/whisper"

# Translation Model
TRANSLATION_MODEL = "facebook/m2m100_418M"

+IMAGE_MODEL_DIR = "/root/transcription_models"
+
stub = Stub(name="reflector-transcriber")


-def download_whisper():
+def download_whisper(cache_dir: str | None = None):
    from faster_whisper.utils import download_model

-    download_model(WHISPER_MODEL, local_files_only=False)
+    print("Downloading Whisper model")
+    download_model(WHISPER_MODEL, cache_dir=cache_dir)
+    print("Whisper model downloaded")
+
+
+def download_translation_model(cache_dir: str | None = None):
+    from huggingface_hub import snapshot_download
+
+    print("Downloading Translation model")
+    ignore_patterns = ["*.ot"]
+    snapshot_download(
+        TRANSLATION_MODEL,
+        cache_dir=cache_dir,
+        ignore_patterns=ignore_patterns
+    )
+    print("Translation model downloaded")
+
+
+def download_models():
+    print(f"Downloading models to {IMAGE_MODEL_DIR=}")
+    download_whisper(cache_dir=IMAGE_MODEL_DIR)
+    download_translation_model(cache_dir=IMAGE_MODEL_DIR)
+    print(f"Model downloads complete.")
+
+
+def migrate_cache_llm():
+    """
+    XXX The cache for model files in Transformers v4.22.0 has been updated.
+    Migrating your old cache. This is a one-time only operation. You can
+    interrupt this and resume the migration later on by calling
+    `transformers.utils.move_cache()`.
+    """
+    from transformers.utils.hub import move_cache
+
+    print("Moving LLM cache")
+    move_cache(cache_dir=IMAGE_MODEL_DIR, new_cache_dir=IMAGE_MODEL_DIR)
+    print("LLM cache moved")


whisper_image = (
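migrate_cache_llm wraps transformers' one-time cache migration; the docstring quotes the warning transformers emits when it finds a pre-v4.22 cache. As a standalone sketch, running the migration against the image's model directory looks like this:

# Sketch: migrate an old-format Transformers cache in place (one-time operation).
from transformers.utils.hub import move_cache

IMAGE_MODEL_DIR = "/root/transcription_models"  # same constant as in the diff

# Rewrites pre-v4.22 cache entries into the current layout; once the cache is
# already in the new format there is nothing left to move.
move_cache(cache_dir=IMAGE_MODEL_DIR, new_cache_dir=IMAGE_MODEL_DIR)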
@@ -37,8 +75,10 @@ whisper_image = (
        "transformers",
        "sentencepiece",
        "protobuf",
+        "huggingface_hub==0.16.4",
    )
-    .run_function(download_whisper)
+    .run_function(download_models)
+    .run_function(migrate_cache_llm)
    .env(
        {
            "LD_LIBRARY_PATH": (
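The image build is where the caching pays off: .run_function(download_models) and .run_function(migrate_cache_llm) execute while the Modal image is built, so the downloaded weights are baked into the image layer instead of being fetched on container start. A condensed sketch of that build chain, with the package list and function body as stand-ins, using the older modal.Stub API this file already imports:

# Sketch: baking model downloads into a Modal image at build time.
from modal import Image, Stub

stub = Stub(name="example-bake-models")


def download_models():
    # The real function calls snapshot_download / download_model with a cache_dir.
    print("downloading models into the image layer")


image = (
    Image.debian_slim()
    .pip_install("transformers", "huggingface_hub")
    .run_function(download_models)  # runs once, during the image build
)


@stub.function(image=image)
def serve():
    # At runtime the cache directory is already populated inside the image.
    print("models available without re-downloading")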
@@ -68,10 +108,16 @@ class Whisper:
            device=self.device,
            compute_type=WHISPER_COMPUTE_TYPE,
            num_workers=WHISPER_NUM_WORKERS,
+            download_root=IMAGE_MODEL_DIR
        )
-        self.translation_model = M2M100ForConditionalGeneration.from_pretrained(TRANSLATION_MODEL).to(self.device)
-        self.translation_tokenizer = M2M100Tokenizer.from_pretrained(TRANSLATION_MODEL)
+        self.translation_model = M2M100ForConditionalGeneration.from_pretrained(
+            TRANSLATION_MODEL,
+            cache_dir=IMAGE_MODEL_DIR
+        ).to(self.device)
+        self.translation_tokenizer = M2M100Tokenizer.from_pretrained(
+            TRANSLATION_MODEL,
+            cache_dir=IMAGE_MODEL_DIR
+        )

    @method()
    def warmup(self):
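At runtime both models are then pointed at the directory populated during the build: faster-whisper via download_root, and the M2M100 translation model via cache_dir. A stripped-down sketch of that load path, outside the Modal class and with device handling simplified:

# Sketch: loading Whisper and the translation model from the baked-in cache.
# IMAGE_MODEL_DIR and the model names mirror constants from the diff.
from faster_whisper import WhisperModel
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

IMAGE_MODEL_DIR = "/root/transcription_models"
WHISPER_MODEL = "large-v2"
TRANSLATION_MODEL = "facebook/m2m100_418M"

# faster-whisper looks under download_root, which the image build already
# filled via download_whisper(cache_dir=IMAGE_MODEL_DIR).
whisper = WhisperModel(
    WHISPER_MODEL,
    device="cuda",
    compute_type="float16",
    download_root=IMAGE_MODEL_DIR,
)

# The translation model and tokenizer resolve from the same cache directory.
translation_model = M2M100ForConditionalGeneration.from_pretrained(
    TRANSLATION_MODEL, cache_dir=IMAGE_MODEL_DIR
)
translation_tokenizer = M2M100Tokenizer.from_pretrained(
    TRANSLATION_MODEL, cache_dir=IMAGE_MODEL_DIR
)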