Mirror of https://github.com/Monadical-SAS/reflector.git (synced 2025-12-20 20:29:06 +00:00)
add comments and log
This commit is contained in:
@@ -113,7 +113,9 @@ class LLM:
|
||||
@method()
|
||||
def generate(self, prompt: str, schema: str = None):
|
||||
print(f"Generate {prompt=}")
|
||||
# If a schema is given, conform to schema
|
||||
if schema:
|
||||
print(f"Schema {schema=}")
|
||||
import ast
|
||||
import jsonformer
|
||||
|
||||
@@ -123,16 +125,17 @@ class LLM:
|
||||
prompt=prompt,
|
||||
max_string_token_length=self.gen_cfg.max_new_tokens)
|
||||
response = jsonformer_llm()
|
||||
print(f"Generated {response=}")
|
||||
return {"text": response}
|
||||
else:
|
||||
# If no schema, perform prompt only generation
|
||||
|
||||
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
|
||||
# tokenize prompt
|
||||
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
|
||||
self.model.device
|
||||
)
|
||||
output = self.model.generate(input_ids, generation_config=self.gen_cfg)
|
||||
)
|
||||
output = self.model.generate(input_ids, generation_config=self.gen_cfg)
|
||||
|
||||
# decode output
|
||||
response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True)
|
||||
# decode output
|
||||
response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True)
|
||||
print(f"Generated {response=}")
|
||||
return {"text": response}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user