From 0cdd7037fbf4bb93dcf2d77cf9567d041d2ce71f Mon Sep 17 00:00:00 2001
From: Gokul Mohanarangan
Date: Wed, 16 Aug 2023 14:03:25 +0530
Subject: [PATCH] wrap JSONFormer around LLM

---
 server/gpu/modal/reflector_llm.py | 29 ++++++++++++++++++++---------
 1 file changed, 20 insertions(+), 9 deletions(-)

diff --git a/server/gpu/modal/reflector_llm.py b/server/gpu/modal/reflector_llm.py
index bf6f4cf5..315ff785 100644
--- a/server/gpu/modal/reflector_llm.py
+++ b/server/gpu/modal/reflector_llm.py
@@ -10,7 +10,7 @@ from modal import Image, method, Stub, asgi_app, Secret
 
 # LLM
 LLM_MODEL: str = "lmsys/vicuna-13b-v1.5"
-LLM_LOW_CPU_MEM_USAGE: bool = False
+LLM_LOW_CPU_MEM_USAGE: bool = True
 LLM_TORCH_DTYPE: str = "bfloat16"
 LLM_MAX_NEW_TOKENS: int = 300
 
@@ -49,6 +49,8 @@ llm_image = (
         "torch",
         "sentencepiece",
         "protobuf",
+        "jsonformer==0.12.0",
+        "accelerate==0.21.0",
         "einops==0.6.1",
         "hf-transfer~=0.1",
         "huggingface_hub==0.16.4",
@@ -81,6 +83,7 @@ class LLM:
 
         # generation configuration
         print("Instance llm generation config")
+        # JSONFormer doesn't yet support generation configs, but keeping for future usage
        model.config.max_new_tokens = LLM_MAX_NEW_TOKENS
         gen_cfg = GenerationConfig.from_model_config(model.config)
         gen_cfg.max_new_tokens = LLM_MAX_NEW_TOKENS
@@ -97,6 +100,13 @@ class LLM:
         self.model = model
         self.tokenizer = tokenizer
         self.gen_cfg = gen_cfg
+        self.json_schema = {
+            "type": "object",
+            "properties": {
+                "title": {"type": "string"},
+                "summary": {"type": "string"},
+            },
+        }
 
     def __exit__(self, *args):
         print("Exit llm")
@@ -109,16 +119,17 @@ class LLM:
     @method()
     def generate(self, prompt: str):
         print(f"Generate {prompt=}")
-        # tokenize prompt
-        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
-            self.model.device
-        )
-        output = self.model.generate(input_ids, generation_config=self.gen_cfg)
+        import jsonformer
+        import json
 
-        # decode output
-        response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True)
+        jsonformer_llm = jsonformer.Jsonformer(model=self.model,
+                                               tokenizer=self.tokenizer,
+                                               json_schema=self.json_schema,
+                                               prompt=prompt,
+                                               max_string_token_length=self.gen_cfg.max_new_tokens)
+        response = jsonformer_llm()
         print(f"Generated {response=}")
-        return {"text": response}
+        return {"text": json.dumps(response)}
 
 
 # -------------------------------------------------------------------
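
Reviewer note (not part of the patch): below is a minimal local sketch of the Jsonformer call this change introduces, assuming a small stand-in model ("gpt2") and a hypothetical prompt rather than the Vicuna-13B model and Modal plumbing above. Jsonformer drives decoding field by field against the schema, so generate() now returns schema-conforming JSON instead of free-form text.

    # Sketch only: mirrors the generate() change above, outside Modal.
    # "gpt2" and the prompt below are stand-ins, not what the service uses.
    import json

    from jsonformer import Jsonformer
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # Same schema the patch stores on the LLM class.
    json_schema = {
        "type": "object",
        "properties": {
            "title": {"type": "string"},
            "summary": {"type": "string"},
        },
    }

    prompt = "Give a title and a summary for the following transcript: ..."

    # Jsonformer generates each field under the schema's constraints,
    # so the result is already a dict and needs no post-hoc parsing.
    jsonformer_llm = Jsonformer(
        model=model,
        tokenizer=tokenizer,
        json_schema=json_schema,
        prompt=prompt,
        max_string_token_length=300,
    )
    response = jsonformer_llm()
    print(json.dumps(response))  # e.g. {"title": "...", "summary": "..."}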