make schema optional argument

This commit is contained in:
Gokul Mohanarangan
2023-08-17 09:23:14 +05:30
parent 5f79e04642
commit eb13a7bd64
9 changed files with 48 additions and 37 deletions

View File

@@ -1,4 +1,5 @@
import json
from typing import Union
import httpx
from reflector.llm.base import LLM
@@ -25,10 +26,10 @@ class ModalLLM(LLM):
)
response.raise_for_status()
async def _generate(self, prompt: str, **kwargs):
async def _generate(self, prompt: str, schema: Union[str | None], **kwargs):
json_payload = {"prompt": prompt}
if "schema" in kwargs:
json_payload["schema"] = json.dumps(kwargs["schema"])
if schema:
json_payload["schema"] = json.dumps(schema)
async with httpx.AsyncClient() as client:
response = await retry(client.post)(
self.llm_url,
@@ -39,7 +40,7 @@ class ModalLLM(LLM):
)
response.raise_for_status()
text = response.json()["text"]
if "schema" not in json_payload:
if not schema:
text = text[len(prompt) :]
return text
@@ -54,13 +55,12 @@ if __name__ == "__main__":
result = await llm.generate("Hello, my name is", logger=logger)
print(result)
kwargs = {
"schema": {
"type": "object",
"properties": {"name": {"type": "string"}},
}
schema = {
"type": "object",
"properties": {"name": {"type": "string"}},
}
result = await llm.generate("Hello, my name is", kwargs=kwargs, logger=logger)
result = await llm.generate("Hello, my name is", schema=schema, logger=logger)
print(result)
import asyncio