make schema optional for all LLMs

Gokul Mohanarangan
2023-08-16 22:37:20 +05:30
parent 976c0ab9a8
commit 5f79e04642
8 changed files with 79 additions and 37 deletions

View File

@@ -20,7 +20,7 @@ class LLM:
         Return an instance depending on the settings.
         Settings used:
-        - `LLM_BACKEND`: key of the backend, defaults to `oobagooda`
+        - `LLM_BACKEND`: key of the backend, defaults to `oobabooga`
         - `LLM_URL`: url of the backend
         """
         if name is None:
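
Note: a minimal sketch of how these settings might drive backend selection (the `LLM.get_instance` factory name and the URL value below are assumptions for illustration; only `LLM_BACKEND`, `LLM_URL`, and the registry keys appear in this commit):

    from reflector.llm.base import LLM
    from reflector.settings import settings

    # Assumed example values; LLM_BACKEND must match a key passed to LLM.register.
    settings.LLM_BACKEND = "oobabooga"
    settings.LLM_URL = "http://localhost:5000/api/v1/generate"

    # Hypothetical factory call: returns an instance of the registered backend class.
    llm = LLM.get_instance()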

View File

@@ -1,7 +1,9 @@
+import json
+import httpx
+
 from reflector.llm.base import LLM
 from reflector.settings import settings
 from reflector.utils.retry import retry
-import httpx


 class BananaLLM(LLM):
@@ -14,17 +16,21 @@ class BananaLLM(LLM):
         }

     async def _generate(self, prompt: str, **kwargs):
+        json_payload = {"prompt": prompt}
+        if "schema" in kwargs:
+            json_payload["schema"] = json.dumps(kwargs["schema"])
         async with httpx.AsyncClient() as client:
             response = await retry(client.post)(
                 settings.LLM_URL,
                 headers=self.headers,
-                json={"prompt": prompt},
+                json=json_payload,
                 timeout=self.timeout,
                 retry_timeout=300,  # as per their sdk
             )
             response.raise_for_status()
             text = response.json()["text"]
-            text = text[len(prompt) :]  # remove prompt
+            if "schema" not in json_payload:
+                text = text[len(prompt) :]
             return text
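
For reference, a short standalone sketch of the request body this change produces, with and without a schema (illustrative only, not part of the diff):

    import json

    # Without a schema the payload is unchanged and the prompt is still
    # stripped from the returned text, as before this commit.
    payload = {"prompt": "Hello, my name is"}

    # With a schema it is JSON-encoded into the payload; the backend text is
    # then returned as-is, presumably because the model emits only the
    # schema-constrained completion rather than echoing the prompt.
    schema = {"type": "object", "properties": {"name": {"type": "string"}}}
    payload = {"prompt": "Hello, my name is", "schema": json.dumps(schema)}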

View File

@@ -1,7 +1,9 @@
+import json
+import httpx
+
 from reflector.llm.base import LLM
 from reflector.settings import settings
 from reflector.utils.retry import retry
-import httpx


 class ModalLLM(LLM):
@@ -24,17 +26,21 @@ class ModalLLM(LLM):
         response.raise_for_status()

     async def _generate(self, prompt: str, **kwargs):
+        json_payload = {"prompt": prompt}
+        if "schema" in kwargs:
+            json_payload["schema"] = json.dumps(kwargs["schema"])
         async with httpx.AsyncClient() as client:
             response = await retry(client.post)(
                 self.llm_url,
                 headers=self.headers,
-                json={"prompt": prompt},
+                json=json_payload,
                 timeout=self.timeout,
                 retry_timeout=60 * 5,
             )
             response.raise_for_status()
             text = response.json()["text"]
-            text = text[len(prompt) :]  # remove prompt
+            if "schema" not in json_payload:
+                text = text[len(prompt) :]
             return text
@@ -48,6 +54,15 @@ if __name__ == "__main__":
         result = await llm.generate("Hello, my name is", logger=logger)
         print(result)

+        kwargs = {
+            "schema": {
+                "type": "object",
+                "properties": {"name": {"type": "string"}},
+            }
+        }
+        result = await llm.generate("Hello, my name is", kwargs=kwargs, logger=logger)
+        print(result)
+
     import asyncio

     asyncio.run(main())
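
Since a schema call now returns the backend text untouched, a caller would likely parse it as JSON. A hedged sketch of what could follow the schema call inside main() above (this parsing step is not part of the commit, and it assumes the backend honoured the schema):

    import json

    data = json.loads(result)  # result from the schema-constrained generate call
    print(data.get("name"))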

View File

@@ -1,18 +1,23 @@
+import json
+import httpx
+
 from reflector.llm.base import LLM
 from reflector.settings import settings
-import httpx


-class OobagoodaLLM(LLM):
+class OobaboogaLLM(LLM):
     async def _generate(self, prompt: str, **kwargs):
+        json_payload = {"prompt": prompt}
+        if "schema" in kwargs:
+            json_payload["schema"] = json.dumps(kwargs["schema"])
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 settings.LLM_URL,
                 headers={"Content-Type": "application/json"},
-                json={"prompt": prompt},
+                json=json_payload,
             )
             response.raise_for_status()
             return response.json()


-LLM.register("oobagooda", OobagoodaLLM)
+LLM.register("oobabooga", OobaboogaLLM)
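
LLM.register binds the "oobabooga" key to the backend class so that LLM_BACKEND can select it. A minimal sketch of how such a registry is commonly implemented (the actual reflector.llm.base implementation is not shown in this commit, and get_backend is an assumed helper name):

    class LLM:
        registry = {}

        @classmethod
        def register(cls, name, backend_cls):
            # Map a backend key (e.g. "oobabooga") to its implementing class.
            cls.registry[name] = backend_cls

        @classmethod
        def get_backend(cls, name):
            # Return the class registered under `name`; raises KeyError if unknown.
            return cls.registry[name]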