add comments and log

This commit is contained in:
Gokul Mohanarangan
2023-08-17 09:33:59 +05:30
parent eb13a7bd64
commit 2e48f89fdc


@@ -113,7 +113,9 @@ class LLM:
    @method()
    def generate(self, prompt: str, schema: str = None):
        print(f"Generate {prompt=}")
        # If a schema is given, conform to schema
        if schema:
            print(f"Schema {schema=}")
            import ast
            import jsonformer
@@ -123,16 +125,17 @@ class LLM:
                prompt=prompt,
                max_string_token_length=self.gen_cfg.max_new_tokens)
            response = jsonformer_llm()
            print(f"Generated {response=}")
            return {"text": response}
        else:
            # If no schema, perform prompt only generation
            # tokenize prompt
            input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
                self.model.device
            )
            output = self.model.generate(input_ids, generation_config=self.gen_cfg)
            # decode output
            response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True)
            print(f"Generated {response=}")
            return {"text": response}