Mirror of https://github.com/Monadical-SAS/reflector.git, synced 2025-12-20 20:29:06 +00:00
Feature additions (#210)
* initial
* add LLM features
* update LLM logic
* update llm functions: change control flow
* add generation config
* update return types
* update processors and tests
* update rtc_offer
* revert new title processor change
* fix unit tests
* add comments and fix HTTP 500
* adjust prompt
* test with reflector app
* revert new event for final title
* update
* move onus onto processors
* move onus onto processors
* stash
* add provision for gen config
* dynamically pack the LLM input using context length
* tune final summary params
* update consolidated class structures
* update consolidated class structures
* update precommit
* add broadcast processors
* working baseline
* Organize LLMParams
* minor fixes
* minor fixes
* minor fixes
* fix unit tests
* fix unit tests
* fix unit tests
* update tests
* update tests
* edit pipeline response events
* update summary return types
* configure tests
* alembic db migration
* change LLM response flow
* edit main llm functions
* edit main llm functions
* change llm name and gen cf
* Update transcript_topic_detector.py
* PR review comments
* checkpoint before db event migration
* update DB migration of past events
* update DB migration of past events
* edit LLM classes
* Delete unwanted file
* remove List typing
* remove List typing
* update oobabooga API call
* topic enhancements
* update UI event handling
* move ensure_casing to llm base
* update tests
* update tests
@@ -55,7 +55,7 @@ llm_image = (
         "accelerate==0.21.0",
         "einops==0.6.1",
         "hf-transfer~=0.1",
-        "huggingface_hub==0.16.4",
+        "huggingface_hub==0.16.4"
     )
     .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
     .run_function(download_llm)
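For context, the chain in this hunk is Modal's image-builder API. A minimal sketch of how an image like llm_image is put together; the debian_slim base and the body of download_llm are assumptions, while the package pins, the env var, and the run_function(download_llm) call come from this diff:

    from modal import Image

    def download_llm():
        # Assumed helper: pre-fetch model weights at image build time so
        # containers start with a warm HF cache (real body lives in the repo).
        ...

    llm_image = (
        Image.debian_slim()
        .pip_install(
            "accelerate==0.21.0",
            "einops==0.6.1",
            "hf-transfer~=0.1",
            "huggingface_hub==0.16.4"
        )
        .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
        .run_function(download_llm)
    )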
@@ -73,8 +73,7 @@ llm_image = (
 class LLM:
     def __enter__(self):
         import torch
-        from transformers import AutoModelForCausalLM, AutoTokenizer
-        from transformers.generation import GenerationConfig
+        from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

         print("Instance llm model")
         model = AutoModelForCausalLM.from_pretrained(
@@ -84,10 +83,11 @@ class LLM:
             cache_dir=IMAGE_MODEL_DIR
         )

-        # generation configuration
-        # JSONFormer doesn't yet support generation configs
         print("Instance llm generation config")
+        # JSONFormer doesn't yet support generation configs, but keeping for future usage
         model.config.max_new_tokens = LLM_MAX_NEW_TOKENS
+
+        # generation configuration
         gen_cfg = GenerationConfig.from_model_config(model.config)
         gen_cfg.max_new_tokens = LLM_MAX_NEW_TOKENS

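The two config paths above are worth spelling out: from_model_config seeds a default GenerationConfig from the model, while the per-request override added later in this diff arrives as a JSON string and goes through from_dict. A standalone sketch using only the transformers API (the concrete values are illustrative):

    import json
    from transformers import GenerationConfig

    # Service-side default, standing in for the gen_cfg built in __enter__
    # (LLM_MAX_NEW_TOKENS is a constant defined elsewhere in the file).
    default_cfg = GenerationConfig(max_new_tokens=256)

    # Caller-supplied override, serialized exactly as generate() expects it.
    payload = '{"max_new_tokens": 64, "temperature": 0.7, "do_sample": true}'
    override = GenerationConfig.from_dict(json.loads(payload))
    assert override.max_new_tokens == 64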
@@ -106,6 +106,7 @@ class LLM:
         self.model = model
         self.tokenizer = tokenizer
         self.gen_cfg = gen_cfg
+        self.GenerationConfig = GenerationConfig

     def __exit__(self, *args):
         print("Exit llm")
@@ -116,34 +117,44 @@ class LLM:
         return {"status": "ok"}

     @method()
-    def generate(self, prompt: str, schema: str = None):
+    def generate(self, prompt: str, gen_schema: str | None, gen_cfg: str | None) -> dict:
         """
         Perform a generation action using the LLM
         """
         print(f"Generate {prompt=}")
-        # If a schema is given, conform to schema
-        if schema:
-            print(f"Schema {schema=}")
+        if gen_cfg:
+            gen_cfg = self.GenerationConfig.from_dict(json.loads(gen_cfg))
+        else:
+            gen_cfg = self.gen_cfg
+
+        # If a gen_schema is given, conform to gen_schema
+        if gen_schema:
             import jsonformer

-            jsonformer_llm = jsonformer.Jsonformer(model=self.model,
-                                                   tokenizer=self.tokenizer,
-                                                   json_schema=json.loads(schema),
-                                                   prompt=prompt,
-                                                   max_string_token_length=self.gen_cfg.max_new_tokens)
+            print(f"Schema {gen_schema=}")
+            jsonformer_llm = jsonformer.Jsonformer(
+                model=self.model,
+                tokenizer=self.tokenizer,
+                json_schema=json.loads(gen_schema),
+                prompt=prompt,
+                max_string_token_length=gen_cfg.max_new_tokens
+            )
             response = jsonformer_llm()
         else:
-            # If no schema, perform prompt only generation
+            # If no gen_schema, perform prompt only generation

             # tokenize prompt
             input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
                 self.model.device
             )
-            output = self.model.generate(input_ids, generation_config=self.gen_cfg)
+            output = self.model.generate(input_ids, generation_config=gen_cfg)

             # decode output
             response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True)
             response = response[len(prompt):]
         print(f"Generated {response=}")
         return {"text": response}
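The reworked generate now resolves its generation config first (request-supplied gen_cfg or the instance default), then branches on gen_schema; stashing GenerationConfig on self in __enter__ is what lets it call self.GenerationConfig.from_dict without re-importing transformers. A hedged sketch of the constrained branch using the jsonformer library directly; the small model and the schema are illustrative, not what the service ships:

    import json
    import jsonformer
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Illustrative model; the service uses the model baked into llm_image.
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    gen_schema = json.dumps(
        {"type": "object", "properties": {"title": {"type": "string"}}}
    )

    # As in the gen_schema branch above: output is forced to match the JSON
    # schema instead of free-running text generation.
    structured = jsonformer.Jsonformer(
        model=model,
        tokenizer=tokenizer,
        json_schema=json.loads(gen_schema),
        prompt="Propose a short title for this meeting.",
        max_string_token_length=64,
    )()
    print(structured)  # e.g. {"title": "..."}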
 # -------------------------------------------------------------------
 # Web API
 # -------------------------------------------------------------------
@@ -160,7 +171,7 @@ class LLM:
 def web():
     from fastapi import Depends, FastAPI, HTTPException, status
     from fastapi.security import OAuth2PasswordBearer
-    from pydantic import BaseModel, Field
+    from pydantic import BaseModel

     llmstub = LLM()
@@ -177,16 +188,16 @@ def web():

     class LLMRequest(BaseModel):
         prompt: str
-        schema_: Optional[dict] = Field(None, alias="schema")
+        gen_schema: Optional[dict] = None
+        gen_cfg: Optional[dict] = None

     @app.post("/llm", dependencies=[Depends(apikey_auth)])
     async def llm(
         req: LLMRequest,
     ):
-        if req.schema_:
-            func = llmstub.generate.spawn(prompt=req.prompt, schema=json.dumps(req.schema_))
-        else:
-            func = llmstub.generate.spawn(prompt=req.prompt)
+        gen_schema = json.dumps(req.gen_schema) if req.gen_schema else None
+        gen_cfg = json.dumps(req.gen_cfg) if req.gen_cfg else None
+        func = llmstub.generate.spawn(prompt=req.prompt, gen_schema=gen_schema, gen_cfg=gen_cfg)
         result = func.get()
         return result
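Put together, LLMRequest and the reworked handler let a caller pass either or both optional fields in one POST. A hypothetical client call; the deployment URL and API key are placeholders, and only the route and field names come from this diff:

    import requests

    resp = requests.post(
        "https://example--reflector-llm.modal.run/llm",  # placeholder URL
        headers={"Authorization": "Bearer YOUR_API_KEY"},  # apikey_auth expects a bearer token
        json={
            "prompt": "Give this meeting a short, descriptive title.",
            "gen_schema": {
                "type": "object",
                "properties": {"title": {"type": "string"}},
            },
            "gen_cfg": {"max_new_tokens": 64},
        },
    )
    print(resp.json())  # {"text": ...}; dict-shaped when gen_schema was sent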