persistent model storage
Mirror of https://github.com/Monadical-SAS/reflector.git
[LLM app]

@@ -7,6 +7,7 @@ import json
 import os
 from typing import Optional

+import modal
 from modal import Image, Secret, Stub, asgi_app, method

 # LLM
@@ -15,16 +16,17 @@ LLM_LOW_CPU_MEM_USAGE: bool = True
 LLM_TORCH_DTYPE: str = "bfloat16"
 LLM_MAX_NEW_TOKENS: int = 300

-IMAGE_MODEL_DIR = "/model"
+IMAGE_MODEL_DIR = "/root/llm_models"
+volume = modal.NetworkFileSystem.persisted("reflector-llm-models")

-stub = Stub(name="reflector-llm")
+stub = Stub(name="reflector-llmtest1")


 def download_llm():
     from huggingface_hub import snapshot_download

     print("Downloading LLM model")
-    snapshot_download(LLM_MODEL, local_dir=IMAGE_MODEL_DIR)
+    snapshot_download(LLM_MODEL, cache_dir=IMAGE_MODEL_DIR)
     print("LLM model downloaded")
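This hunk is the heart of the change: a named NetworkFileSystem persists across containers and redeploys, so weights are fetched once rather than on every cold start, and snapshot_download switches from local_dir to cache_dir so the files land in the standard Hugging Face cache layout that from_pretrained can resolve later. A minimal sketch of the wiring, assuming the Stub-era Modal SDK used in this commit; the model id and helper name are illustrative:

import modal

IMAGE_MODEL_DIR = "/root/llm_models"

# Named, persisted network file system: shared by all containers of the
# app and kept across redeploys.
volume = modal.NetworkFileSystem.persisted("reflector-llm-models")

stub = modal.Stub(name="reflector-llmtest1")


@stub.function(network_file_systems={IMAGE_MODEL_DIR: volume})
def warm_cache():  # hypothetical helper, not part of the commit
    from huggingface_hub import snapshot_download

    # cache_dir preserves the hub cache layout (models--org--name/...),
    # which from_pretrained(..., cache_dir=...) can look up directly;
    # the old local_dir call produced a bare export instead.
    snapshot_download("org/model-name", cache_dir=IMAGE_MODEL_DIR)  # illustrative id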
@@ -38,7 +40,7 @@ def migrate_cache_llm():
     from transformers.utils.hub import move_cache

     print("Moving LLM cache")
-    move_cache(cache_dir=IMAGE_MODEL_DIR)
+    move_cache()
     print("LLM cache moved")
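For reference, move_cache() is the one-time cache migration Transformers introduced in v4.22. Dropping the argument points the migration at the default cache instead of the freshly populated IMAGE_MODEL_DIR cache, which is already in the new layout. Roughly:

# One-time, resumable migration of the default Hugging Face cache from the
# pre-v4.22 layout to the current one; a no-op when there is nothing to move.
from transformers.utils.hub import move_cache

move_cache()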
@@ -58,7 +60,6 @@ llm_image = (
     )
     .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
     .run_function(download_llm)
-    .run_function(migrate_cache_llm)
 )
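run_function executes download_llm once at image build time and snapshots the result into the image layers; since the download now writes the standard cache layout, the separate migration step has nothing left to do and is dropped. A sketch of the build chain, where the base image and dependency list are assumptions (only .env and .run_function appear in the diff):

llm_image = (
    modal.Image.debian_slim()
    .pip_install("torch", "transformers", "huggingface_hub", "hf-transfer")
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})  # fast parallel hub downloads
    .run_function(download_llm)               # runs once, during image build
)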
@@ -68,6 +69,7 @@ llm_image = (
     container_idle_timeout=60 * 5,
     concurrency_limit=2,
     image=llm_image,
+    network_file_systems={IMAGE_MODEL_DIR: volume},
 )
 class LLM:
     def __enter__(self):
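Mounting the volume on the serving class is the runtime half of the change: every container of LLM sees the same IMAGE_MODEL_DIR, so a freshly scheduled container reads weights from the shared volume instead of re-downloading them. Sketch of the decoration; the gpu kwarg is an assumption, since it sits outside this hunk:

@stub.cls(
    gpu="A10G",  # assumption: not visible in this hunk
    container_idle_timeout=60 * 5,
    concurrency_limit=2,
    image=llm_image,
    network_file_systems={IMAGE_MODEL_DIR: volume},
)
class LLM:
    def __enter__(self):
        ...  # model and tokenizer loads below resolve against the mount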
@@ -77,9 +79,10 @@ class LLM:

         print("Instance llm model")
         model = AutoModelForCausalLM.from_pretrained(
-            IMAGE_MODEL_DIR,
+            LLM_MODEL,
             torch_dtype=getattr(torch, LLM_TORCH_DTYPE),
             low_cpu_mem_usage=LLM_LOW_CPU_MEM_USAGE,
+            cache_dir=IMAGE_MODEL_DIR
         )

         # generation configuration
@@ -91,7 +94,10 @@ class LLM:

         # load tokenizer
         print("Instance llm tokenizer")
-        tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
+        tokenizer = AutoTokenizer.from_pretrained(
+            LLM_MODEL,
+            cache_dir=IMAGE_MODEL_DIR
+        )

         # move model to gpu
         print("Move llm model to GPU")
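Both loads now pass the hub model id plus cache_dir instead of a baked-in local path, so they resolve through the cache on the mounted volume: a pure disk read when the cache is warm, a download into the same volume when it is not. A self-contained sketch with an illustrative model id:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

LLM_MODEL = "org/model-name"          # illustrative id
IMAGE_MODEL_DIR = "/root/llm_models"

model = AutoModelForCausalLM.from_pretrained(
    LLM_MODEL,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    cache_dir=IMAGE_MODEL_DIR,
)
tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL, cache_dir=IMAGE_MODEL_DIR)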
[Transcriber app]

@@ -6,6 +6,7 @@ Reflector GPU backend - transcriber
 import os
 import tempfile

+import modal
 from modal import Image, Secret, Stub, asgi_app, method
 from pydantic import BaseModel
@@ -17,12 +18,13 @@ WHISPER_NUM_WORKERS: int = 1
 # Translation Model
 TRANSLATION_MODEL = "facebook/m2m100_418M"

-MODEL_DIR = "model"
+IMAGE_MODEL_DIR = "/root/transcription_models"
+volume = modal.NetworkFileSystem.persisted("reflector-transcribe-models")

-stub = Stub(name="reflector-transtest")
+stub = Stub(name="reflector-transtest1")


-def download_whisper(cache_dir: str = None):
+def download_whisper(cache_dir: str | None = None):
     from faster_whisper.utils import download_model

     print("Downloading Whisper model")
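Aside: `str | None` (PEP 604) needs Python 3.10+, but it is the honest annotation for a parameter that defaults to None; the old bare `str` did not admit the default. The download itself is delegated to faster_whisper, and since the call sits between the hunks shown, this is an assumption about its shape:

# Assumed body of download_whisper; "large-v2" stands in for WHISPER_MODEL.
from faster_whisper.utils import download_model

download_model("large-v2", cache_dir="/root/transcription_models")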
@@ -30,32 +32,24 @@ def download_whisper(cache_dir: str = None):
     print("Whisper model downloaded")


-def download_translation_model(cache_dir: str = None):
+def download_translation_model(cache_dir: str | None = None):
     from huggingface_hub import snapshot_download

     print("Downloading Translation model")
     ignore_patterns = ["*.ot"]
-    snapshot_download(TRANSLATION_MODEL, cache_dir=cache_dir, ignore_patterns=ignore_patterns)
+    snapshot_download(
+        TRANSLATION_MODEL,
+        cache_dir=cache_dir,
+        ignore_patterns=ignore_patterns
+    )
     print("Translation model downloaded")


 def download_models():
-    download_whisper(cache_dir=MODEL_DIR)
-    download_translation_model(cache_dir=MODEL_DIR)
-
-
-def migrate_cache_llm():
-    """
-    XXX The cache for model files in Transformers v4.22.0 has been updated.
-    Migrating your old cache. This is a one-time only operation. You can
-    interrupt this and resume the migration later on by calling
-    `transformers.utils.move_cache()`.
-    """
-    from transformers.utils.hub import move_cache
-
-    print("Moving LLM cache")
-    move_cache()
-    print("LLM cache moved")
+    print(f"Downloading models to {IMAGE_MODEL_DIR=}")
+    download_whisper(cache_dir=IMAGE_MODEL_DIR)
+    download_translation_model(cache_dir=IMAGE_MODEL_DIR)
+    print(f"Model downloads complete.")


 whisper_image = (
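The reflowed snapshot_download call keeps its semantics: ignore_patterns trims the snapshot, and "*.ot" matches rust_model.ot (Rust libtorch weights) that this service never loads. Runnable sketch:

from huggingface_hub import snapshot_download

snapshot_download(
    "facebook/m2m100_418M",
    cache_dir="/root/transcription_models",
    ignore_patterns=["*.ot"],  # skip rust_model.ot; never loaded at runtime
)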
@@ -71,7 +65,6 @@ whisper_image = (
         "huggingface_hub==0.16.4",
     )
     .run_function(download_models)
-    .run_function(migrate_cache_llm)
     .env(
         {
             "LD_LIBRARY_PATH": (
@@ -87,6 +80,7 @@ whisper_image = (
     gpu="A10G",
     container_idle_timeout=60,
     image=whisper_image,
+    network_file_systems={IMAGE_MODEL_DIR: volume},
 )
 class Whisper:
     def __enter__(self):
@@ -101,14 +95,15 @@ class Whisper:
             device=self.device,
             compute_type=WHISPER_COMPUTE_TYPE,
             num_workers=WHISPER_NUM_WORKERS,
+            download_root=IMAGE_MODEL_DIR
         )
         self.translation_model = M2M100ForConditionalGeneration.from_pretrained(
             TRANSLATION_MODEL,
-            cache_dir=MODEL_DIR
+            cache_dir=IMAGE_MODEL_DIR
         ).to(self.device)
         self.translation_tokenizer = M2M100Tokenizer.from_pretrained(
             TRANSLATION_MODEL,
-            cache_dir=MODEL_DIR
+            cache_dir=IMAGE_MODEL_DIR
         )

     @method()
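faster_whisper takes download_root where Transformers takes cache_dir; pointing both at the mounted volume gives the Whisper weights and the M2M100 translation model the same persistence. Sketch, with the model size and compute type as assumptions:

from faster_whisper import WhisperModel

model = WhisperModel(
    "large-v2",                 # assumption: actual size comes from config
    device="cuda",
    compute_type="float16",     # assumption: stands in for WHISPER_COMPUTE_TYPE
    num_workers=1,
    download_root="/root/transcription_models",
)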