mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
add comments and log
This commit is contained in:
@@ -113,7 +113,9 @@ class LLM:
|
|||||||
@method()
|
@method()
|
||||||
def generate(self, prompt: str, schema: str = None):
|
def generate(self, prompt: str, schema: str = None):
|
||||||
print(f"Generate {prompt=}")
|
print(f"Generate {prompt=}")
|
||||||
|
# If a schema is given, conform to schema
|
||||||
if schema:
|
if schema:
|
||||||
|
print(f"Schema {schema=}")
|
||||||
import ast
|
import ast
|
||||||
import jsonformer
|
import jsonformer
|
||||||
|
|
||||||
@@ -123,16 +125,17 @@ class LLM:
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
max_string_token_length=self.gen_cfg.max_new_tokens)
|
max_string_token_length=self.gen_cfg.max_new_tokens)
|
||||||
response = jsonformer_llm()
|
response = jsonformer_llm()
|
||||||
print(f"Generated {response=}")
|
else:
|
||||||
return {"text": response}
|
# If no schema, perform prompt only generation
|
||||||
|
|
||||||
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
|
# tokenize prompt
|
||||||
|
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
|
||||||
self.model.device
|
self.model.device
|
||||||
)
|
)
|
||||||
output = self.model.generate(input_ids, generation_config=self.gen_cfg)
|
output = self.model.generate(input_ids, generation_config=self.gen_cfg)
|
||||||
|
|
||||||
# decode output
|
# decode output
|
||||||
response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True)
|
response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True)
|
||||||
print(f"Generated {response=}")
|
print(f"Generated {response=}")
|
||||||
return {"text": response}
|
return {"text": response}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user