fix loading shards from local cache (#313)
@@ -81,7 +81,8 @@ class LLM:
             LLM_MODEL,
             torch_dtype=getattr(torch, LLM_TORCH_DTYPE),
             low_cpu_mem_usage=LLM_LOW_CPU_MEM_USAGE,
-            cache_dir=IMAGE_MODEL_DIR
+            cache_dir=IMAGE_MODEL_DIR,
+            local_files_only=True
         )

         # JSONFormer doesn't yet support generation configs
@@ -96,7 +97,8 @@ class LLM:
         print("Instance llm tokenizer")
         tokenizer = AutoTokenizer.from_pretrained(
             LLM_MODEL,
-            cache_dir=IMAGE_MODEL_DIR
+            cache_dir=IMAGE_MODEL_DIR,
+            local_files_only=True
         )

         # move model to gpu
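For context, a minimal sketch of the loading pattern this diff establishes. The model name and cache path below are illustrative stand-ins for the repo's LLM_MODEL and IMAGE_MODEL_DIR, which are defined elsewhere; passing local_files_only=True makes transformers resolve the config, tokenizer files, and every checkpoint shard from cache_dir alone, raising an error rather than silently falling back to the Hugging Face Hub.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL = "HuggingFaceH4/zephyr-7b-beta"  # assumption; the diff uses LLM_MODEL
CACHE_DIR = "/root/llm_models"          # illustrative; the diff uses IMAGE_MODEL_DIR

# Everything, including sharded weights (model-00001-of-0000N.safetensors, ...),
# must already be present under CACHE_DIR; nothing is fetched from the Hub.
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    cache_dir=CACHE_DIR,
    local_files_only=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    MODEL,
    cache_dir=CACHE_DIR,
    local_files_only=True,
)
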
@@ -17,7 +17,7 @@ LLM_LOW_CPU_MEM_USAGE: bool = True
 LLM_TORCH_DTYPE: str = "bfloat16"
 LLM_MAX_NEW_TOKENS: int = 300

-IMAGE_MODEL_DIR = "/root/llm_models"
+IMAGE_MODEL_DIR = "/root/llm_models/zephyr"

 stub = Stub(name="reflector-llm-zephyr")

@@ -81,7 +81,8 @@ class LLM:
             LLM_MODEL,
             torch_dtype=getattr(torch, LLM_TORCH_DTYPE),
             low_cpu_mem_usage=LLM_LOW_CPU_MEM_USAGE,
-            cache_dir=IMAGE_MODEL_DIR
+            cache_dir=IMAGE_MODEL_DIR,
+            local_files_only=True
         )

         # JSONFormer doesn't yet support generation configs
@@ -96,7 +97,8 @@ class LLM:
         print("Instance llm tokenizer")
         tokenizer = AutoTokenizer.from_pretrained(
             LLM_MODEL,
-            cache_dir=IMAGE_MODEL_DIR
+            cache_dir=IMAGE_MODEL_DIR,
+            local_files_only=True
         )
         gen_cfg.pad_token_id = tokenizer.eos_token_id
         gen_cfg.eos_token_id = tokenizer.eos_token_id
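The zephyr deployment additionally moves its cache into a model-specific directory. For local_files_only=True to succeed there, the weights have to be placed in that directory ahead of time; a hypothetical pre-download step (not part of this commit, and assuming the container image is built with network access) could look like:

from transformers import AutoModelForCausalLM, AutoTokenizer

LLM_MODEL = "HuggingFaceH4/zephyr-7b-beta"   # assumption; defined elsewhere in the repo
IMAGE_MODEL_DIR = "/root/llm_models/zephyr"  # the new per-model cache path from the diff

# Run once at image-build time, online; afterwards the runtime code can read
# the same cache_dir offline with local_files_only=True.
AutoModelForCausalLM.from_pretrained(LLM_MODEL, cache_dir=IMAGE_MODEL_DIR)
AutoTokenizer.from_pretrained(LLM_MODEL, cache_dir=IMAGE_MODEL_DIR)
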
@@ -95,7 +95,8 @@ class Transcriber:
             device=self.device,
             compute_type=WHISPER_COMPUTE_TYPE,
             num_workers=WHISPER_NUM_WORKERS,
-            download_root=WHISPER_MODEL_DIR
+            download_root=WHISPER_MODEL_DIR,
+            local_files_only=True
         )

     @method()
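faster-whisper's WhisperModel exposes the same flag. A minimal sketch under assumed values (the model size, device, and download_root below are illustrative; the real ones come from the WHISPER_* settings and self.device in the diff):

from faster_whisper import WhisperModel

# With local_files_only=True the CTranslate2 model is loaded from
# download_root only; a missing model raises instead of re-downloading.
model = WhisperModel(
    "large-v2",                            # assumption: actual model name
    device="cuda",                         # assumption: self.device in the diff
    compute_type="float16",
    num_workers=1,
    download_root="/root/whisper_models",  # assumption: WHISPER_MODEL_DIR
    local_files_only=True,
)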