From 2e48f89fdc0a8e9df82fe36864e3fe53935695e5 Mon Sep 17 00:00:00 2001
From: Gokul Mohanarangan
Date: Thu, 17 Aug 2023 09:33:59 +0530
Subject: [PATCH] add comments and log

---
 server/gpu/modal/reflector_llm.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/server/gpu/modal/reflector_llm.py b/server/gpu/modal/reflector_llm.py
index 2f96e330..21306763 100644
--- a/server/gpu/modal/reflector_llm.py
+++ b/server/gpu/modal/reflector_llm.py
@@ -113,7 +113,9 @@ class LLM:
     @method()
     def generate(self, prompt: str, schema: str = None):
         print(f"Generate {prompt=}")
+        # If a schema is given, conform to schema
         if schema:
+            print(f"Schema {schema=}")
             import ast
 
             import jsonformer
@@ -123,16 +125,17 @@ class LLM:
                 prompt=prompt,
                 max_string_token_length=self.gen_cfg.max_new_tokens)
             response = jsonformer_llm()
-            print(f"Generated {response=}")
-            return {"text": response}
+        else:
+            # If no schema, perform prompt only generation
 
-        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
+            # tokenize prompt
+            input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
             self.model.device
-        )
-        output = self.model.generate(input_ids, generation_config=self.gen_cfg)
+            )
+            output = self.model.generate(input_ids, generation_config=self.gen_cfg)
 
-        # decode output
-        response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True)
+            # decode output
+            response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True)
         print(f"Generated {response=}")
         return {"text": response}
 
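Note for reviewers: below is a minimal, self-contained sketch of the control flow this patch touches, runnable outside Modal. It is an illustration under stated assumptions, not the file's actual implementation: "gpt2" stands in for whatever model reflector_llm.py loads, the GenerationConfig values are placeholders, and parsing the schema string with ast.literal_eval is inferred from the "import ast" context line (the parse itself is not in these hunks).

# Sketch of generate()'s two paths: jsonformer-constrained vs. plain generation.
# Assumptions: "gpt2" as a stand-in model; schema parsed via ast.literal_eval.
import ast

import jsonformer
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
gen_cfg = GenerationConfig(max_new_tokens=64)


def generate(prompt: str, schema: str = None):
    # If a schema is given, conform the output to that JSON schema
    if schema:
        jsonformer_llm = jsonformer.Jsonformer(
            model=model,
            tokenizer=tokenizer,
            json_schema=ast.literal_eval(schema),  # assumption: schema is a Python-literal string
            prompt=prompt,
            max_string_token_length=gen_cfg.max_new_tokens,
        )
        response = jsonformer_llm()  # returns a dict matching the schema
    else:
        # If no schema, perform prompt-only generation
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
        output = model.generate(input_ids, generation_config=gen_cfg)
        response = tokenizer.decode(output[0].cpu(), skip_special_tokens=True)
    print(f"Generated {response=}")
    return {"text": response}


# Schema path yields a schema-shaped dict; plain path yields free-form text.
print(generate("Name a fruit.", schema="{'type': 'object', 'properties': {'fruit': {'type': 'string'}}}"))
print(generate("Name a fruit."))

The sketch also shows why the patch's else-branch restructuring matters: before it, the schema branch returned early and the plain-generation code ran unconditionally; after it, the two paths share a single print and return at the end.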