Mirror of https://github.com/Monadical-SAS/reflector.git, synced 2025-12-20 20:29:06 +00:00
persistent model storage
@@ -7,6 +7,7 @@ import json
 import os
 from typing import Optional
 
+import modal
 from modal import Image, Secret, Stub, asgi_app, method
 
 # LLM
@@ -15,16 +16,17 @@ LLM_LOW_CPU_MEM_USAGE: bool = True
 LLM_TORCH_DTYPE: str = "bfloat16"
 LLM_MAX_NEW_TOKENS: int = 300
 
-IMAGE_MODEL_DIR = "/model"
+IMAGE_MODEL_DIR = "/root/llm_models"
+volume = modal.NetworkFileSystem.persisted("reflector-llm-models")
 
-stub = Stub(name="reflector-llm")
+stub = Stub(name="reflector-llmtest1")
 
 
 def download_llm():
     from huggingface_hub import snapshot_download
 
     print("Downloading LLM model")
-    snapshot_download(LLM_MODEL, local_dir=IMAGE_MODEL_DIR)
+    snapshot_download(LLM_MODEL, cache_dir=IMAGE_MODEL_DIR)
     print("LLM model downloaded")
 
 
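The hunk above is the core of the change: a named, persisted NetworkFileSystem outlives individual containers, and switching snapshot_download from local_dir to cache_dir writes the standard Hugging Face cache layout (models--org--name/...) onto it, which the cache_dir arguments added in later hunks can then resolve against. A minimal sketch of the pattern, not from the commit; warm_cache and the LLM_MODEL placeholder are illustrative:

import modal

LLM_MODEL = "org/model-name"  # placeholder; the real id is defined earlier in the file
IMAGE_MODEL_DIR = "/root/llm_models"  # mount point inside the container
volume = modal.NetworkFileSystem.persisted("reflector-llm-models")
stub = modal.Stub(name="example")


@stub.function(network_file_systems={IMAGE_MODEL_DIR: volume})
def warm_cache():
    from huggingface_hub import snapshot_download

    # cache_dir keeps the hub cache layout, so later from_pretrained(...,
    # cache_dir=IMAGE_MODEL_DIR) calls can find the files; local_dir would
    # have copied them flat, which a cache_dir lookup cannot resolve.
    snapshot_download(LLM_MODEL, cache_dir=IMAGE_MODEL_DIR)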
@@ -38,7 +40,7 @@ def migrate_cache_llm():
     from transformers.utils.hub import move_cache
 
     print("Moving LLM cache")
-    move_cache(cache_dir=IMAGE_MODEL_DIR)
+    move_cache()
     print("LLM cache moved")
 
 
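Dropping the cache_dir argument makes sense here: migrate_cache_llm runs at image-build time (see the next hunk), when the volume is not mounted and IMAGE_MODEL_DIR holds nothing, so move_cache() now migrates only the default transformers cache. A hedged sketch of the two call forms, assuming the utility's signature move_cache(cache_dir=None, new_cache_dir=None, token=None):

from transformers.utils.hub import move_cache

move_cache()                               # migrate the default HF cache location
move_cache(cache_dir="/root/llm_models")   # migrate an explicit cache directory instead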
@@ -58,7 +60,6 @@ llm_image = (
     )
     .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
-    .run_function(download_llm)
     .run_function(migrate_cache_llm)
 )
 
 
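Removing .run_function(download_llm) means the weights are no longer baked into the image; the first download now happens at runtime onto the volume. A reconstruction of the resulting build chain, with the base image and pip package list assumed since the diff shows only the tail:

llm_image = (
    Image.debian_slim()  # assumed base; not visible in the diff
    .pip_install(
        "torch", "transformers", "huggingface_hub", "hf_transfer"  # assumed list
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})  # fast parallel downloads; needs hf_transfer installed
    .run_function(migrate_cache_llm)
)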
@@ -68,6 +69,7 @@ llm_image = (
     container_idle_timeout=60 * 5,
     concurrency_limit=2,
     image=llm_image,
+    network_file_systems={IMAGE_MODEL_DIR: volume},
 )
 class LLM:
     def __enter__(self):
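network_file_systems={IMAGE_MODEL_DIR: volume} mounts the persisted volume into every container that runs the class, so the cache_dir arguments below all point at shared, durable storage. A sketch of the full decorator, under the assumption of a GPU argument the diff does not show:

@stub.cls(
    gpu="A10G",                                      # assumption; not in the diff
    container_idle_timeout=60 * 5,
    concurrency_limit=2,
    image=llm_image,
    network_file_systems={IMAGE_MODEL_DIR: volume},  # mount the persisted cache
)
class LLM:
    ...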
@@ -77,9 +79,10 @@ class LLM:
 
         print("Instance llm model")
         model = AutoModelForCausalLM.from_pretrained(
-            IMAGE_MODEL_DIR,
+            LLM_MODEL,
             torch_dtype=getattr(torch, LLM_TORCH_DTYPE),
             low_cpu_mem_usage=LLM_LOW_CPU_MEM_USAGE,
+            cache_dir=IMAGE_MODEL_DIR
         )
 
         # generation configuration
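The first positional argument changes meaning here: IMAGE_MODEL_DIR told from_pretrained to load from a local directory baked into the image, while LLM_MODEL plus cache_dir resolves the hub id through the cache on the volume, downloading only on a miss. Side by side, as a sketch:

# before: load from a directory inside the image (fails if the files are absent)
model = AutoModelForCausalLM.from_pretrained(IMAGE_MODEL_DIR)

# after: resolve the hub id, caching on the mounted volume; the download
# happens once, then every later container start is a cache hit
model = AutoModelForCausalLM.from_pretrained(LLM_MODEL, cache_dir=IMAGE_MODEL_DIR)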
@@ -91,7 +94,10 @@ class LLM:
 
         # load tokenizer
         print("Instance llm tokenizer")
-        tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
+        tokenizer = AutoTokenizer.from_pretrained(
+            LLM_MODEL,
+            cache_dir=IMAGE_MODEL_DIR
+        )
 
         # move model to gpu
         print("Move llm model to GPU")
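The tokenizer gets the same treatment, so the net effect of the commit is: the first cold start populates /root/llm_models over the network, and every subsequent container reuses it. A rough, purely illustrative way to observe this from inside __enter__:

import os
import time

t0 = time.time()
model = AutoModelForCausalLM.from_pretrained(LLM_MODEL, cache_dir=IMAGE_MODEL_DIR)
tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL, cache_dir=IMAGE_MODEL_DIR)
print(f"load took {time.time() - t0:.1f}s")  # minutes when cold, seconds when warm
print(os.listdir(IMAGE_MODEL_DIR))           # expect models--<org>--<name> cache entries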