mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
Upgrade modal apps
This commit is contained in:
@@ -3,13 +3,14 @@ Reflector GPU backend - LLM
|
||||
===========================
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
import modal
|
||||
from modal import Image, Secret, App, asgi_app, method, enter, exit
|
||||
from modal import App, Image, Secret, asgi_app, enter, exit, method
|
||||
|
||||
# LLM
|
||||
LLM_MODEL: str = "HuggingFaceH4/zephyr-7b-alpha"
|
||||
@@ -56,7 +57,7 @@ llm_image = (
|
||||
"accelerate==0.21.0",
|
||||
"einops==0.6.1",
|
||||
"hf-transfer~=0.1",
|
||||
"huggingface_hub==0.16.4"
|
||||
"huggingface_hub==0.16.4",
|
||||
)
|
||||
.env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
|
||||
.run_function(download_llm)
|
||||
@@ -67,7 +68,7 @@ llm_image = (
|
||||
@app.cls(
|
||||
gpu="A10G",
|
||||
timeout=60 * 5,
|
||||
container_idle_timeout=60 * 5,
|
||||
scaledown_window=60 * 5,
|
||||
allow_concurrent_inputs=10,
|
||||
image=llm_image,
|
||||
)
|
||||
@@ -83,7 +84,7 @@ class LLM:
|
||||
torch_dtype=getattr(torch, LLM_TORCH_DTYPE),
|
||||
low_cpu_mem_usage=LLM_LOW_CPU_MEM_USAGE,
|
||||
cache_dir=IMAGE_MODEL_DIR,
|
||||
local_files_only=True
|
||||
local_files_only=True,
|
||||
)
|
||||
|
||||
# JSONFormer doesn't yet support generation configs
|
||||
@@ -97,9 +98,7 @@ class LLM:
|
||||
# load tokenizer
|
||||
print("Instance llm tokenizer")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
LLM_MODEL,
|
||||
cache_dir=IMAGE_MODEL_DIR,
|
||||
local_files_only=True
|
||||
LLM_MODEL, cache_dir=IMAGE_MODEL_DIR, local_files_only=True
|
||||
)
|
||||
gen_cfg.pad_token_id = tokenizer.eos_token_id
|
||||
gen_cfg.eos_token_id = tokenizer.eos_token_id
|
||||
@@ -122,7 +121,9 @@ class LLM:
|
||||
print("Exit llm")
|
||||
|
||||
@method()
|
||||
def generate(self, prompt: str, gen_schema: str | None, gen_cfg: str | None) -> dict:
|
||||
def generate(
|
||||
self, prompt: str, gen_schema: str | None, gen_cfg: str | None
|
||||
) -> dict:
|
||||
"""
|
||||
Perform a generation action using the LLM
|
||||
"""
|
||||
@@ -145,7 +146,7 @@ class LLM:
|
||||
tokenizer=self.tokenizer,
|
||||
json_schema=json.loads(gen_schema),
|
||||
prompt=prompt,
|
||||
max_string_token_length=gen_cfg.max_new_tokens
|
||||
max_string_token_length=gen_cfg.max_new_tokens,
|
||||
)
|
||||
response = jsonformer_llm()
|
||||
else:
|
||||
@@ -158,21 +159,22 @@ class LLM:
|
||||
output = self.model.generate(input_ids, generation_config=gen_cfg)
|
||||
|
||||
# decode output
|
||||
response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True)
|
||||
response = response[len(prompt):]
|
||||
response = {
|
||||
"long_summary": response
|
||||
}
|
||||
response = self.tokenizer.decode(
|
||||
output[0].cpu(), skip_special_tokens=True
|
||||
)
|
||||
response = response[len(prompt) :]
|
||||
response = {"long_summary": response}
|
||||
print(f"Generated {response=}")
|
||||
return {"text": response}
|
||||
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Web API
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
|
||||
@app.function(
|
||||
container_idle_timeout=60 * 10,
|
||||
scaledown_window=60 * 10,
|
||||
timeout=60 * 5,
|
||||
allow_concurrent_inputs=30,
|
||||
secrets=[
|
||||
@@ -205,11 +207,13 @@ def web():
|
||||
|
||||
@app.post("/llm", dependencies=[Depends(apikey_auth)])
|
||||
def llm(
|
||||
req: LLMRequest,
|
||||
req: LLMRequest,
|
||||
):
|
||||
gen_schema = json.dumps(req.gen_schema) if req.gen_schema else None
|
||||
gen_cfg = json.dumps(req.gen_cfg) if req.gen_cfg else None
|
||||
func = llmstub.generate.spawn(prompt=req.prompt, gen_schema=gen_schema, gen_cfg=gen_cfg)
|
||||
func = llmstub.generate.spawn(
|
||||
prompt=req.prompt, gen_schema=gen_schema, gen_cfg=gen_cfg
|
||||
)
|
||||
result = func.get()
|
||||
return result
|
||||
|
||||
|
||||
Reference in New Issue
Block a user