add comments and log

This commit is contained in:
Gokul Mohanarangan
2023-08-17 09:33:59 +05:30
parent eb13a7bd64
commit 2e48f89fdc

View File

@@ -113,7 +113,9 @@ class LLM:
@method() @method()
def generate(self, prompt: str, schema: str = None): def generate(self, prompt: str, schema: str = None):
print(f"Generate {prompt=}") print(f"Generate {prompt=}")
# If a schema is given, conform to schema
if schema: if schema:
print(f"Schema {schema=}")
import ast import ast
import jsonformer import jsonformer
@@ -123,16 +125,17 @@ class LLM:
prompt=prompt, prompt=prompt,
max_string_token_length=self.gen_cfg.max_new_tokens) max_string_token_length=self.gen_cfg.max_new_tokens)
response = jsonformer_llm() response = jsonformer_llm()
print(f"Generated {response=}") else:
return {"text": response} # If no schema, perform prompt only generation
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to( # tokenize prompt
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
self.model.device self.model.device
) )
output = self.model.generate(input_ids, generation_config=self.gen_cfg) output = self.model.generate(input_ids, generation_config=self.gen_cfg)
# decode output # decode output
response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True) response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True)
print(f"Generated {response=}") print(f"Generated {response=}")
return {"text": response} return {"text": response}