pass schema as dict

Gokul Mohanarangan
2023-08-17 21:26:20 +05:30
parent 9332870e83
commit 2d686da15c
4 changed files with 8 additions and 11 deletions


@@ -172,13 +172,16 @@ def web():
 class LLMRequest(BaseModel):
     prompt: str
-    schema: Optional[str] = None
+    schema: Optional[dict] = None


 @app.post("/llm", dependencies=[Depends(apikey_auth)])
 async def llm(
     req: LLMRequest,
 ):
-    func = llmstub.generate.spawn(prompt=req.prompt, schema=req.schema)
+    if req.schema:
+        func = llmstub.generate.spawn(prompt=req.prompt, schema=json.dumps(req.schema))
+    else:
+        func = llmstub.generate.spawn(prompt=req.prompt)
     result = func.get()
     return result
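
With this commit the /llm endpoint accepts schema as a JSON object rather than a pre-serialized string, and only serializes it (via json.dumps) when spawning the worker. A minimal client sketch of the new request shape; the host, port, and X-API-Key header name are assumptions, since the diff only shows that the route is protected by apikey_auth:

import httpx

# Hypothetical call against the updated /llm endpoint.
# URL and auth header name are assumptions, not shown in the commit.
payload = {
    "prompt": "Summarize the standup notes.",
    # schema is now a JSON object in the request body, not a JSON string
    "schema": {
        "type": "object",
        "properties": {"summary": {"type": "string"}},
        "required": ["summary"],
    },
}
response = httpx.post(
    "http://localhost:8000/llm",
    json=payload,
    headers={"X-API-Key": "..."},
)
print(response.json())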


@@ -1,5 +1,3 @@
-import json
-
 import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
@@ -18,7 +16,7 @@ class BananaLLM(LLM):
     async def _generate(self, prompt: str, schema: dict | None, **kwargs):
         json_payload = {"prompt": prompt}
         if schema:
-            json_payload["schema"] = json.dumps(schema)
+            json_payload["schema"] = schema
         async with httpx.AsyncClient() as client:
             response = await retry(client.post)(
                 settings.LLM_URL,
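
Dropping the json.dumps call matters because httpx already serializes the json= payload once; dumping the schema first double-encoded it, so the backend received the schema as a quoted JSON string instead of an object. A small sketch of the difference in the request body (the schema value is illustrative); the same fix is applied to ModalLLM and OobaboogaLLM below:

import json

schema = {"type": "object"}

# Before this commit: the schema was dumped to a string first, so the
# body carried it double-encoded as a quoted JSON string.
before = json.dumps({"prompt": "...", "schema": json.dumps(schema)})
# -> {"prompt": "...", "schema": "{\"type\": \"object\"}"}

# After this commit: the dict goes in as-is and httpx's json= kwarg
# serializes the whole payload once.
after = json.dumps({"prompt": "...", "schema": schema})
# -> {"prompt": "...", "schema": {"type": "object"}}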


@@ -1,5 +1,3 @@
-import json
-
 import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
@@ -28,7 +26,7 @@ class ModalLLM(LLM):
     async def _generate(self, prompt: str, schema: dict | None, **kwargs):
         json_payload = {"prompt": prompt}
         if schema:
-            json_payload["schema"] = json.dumps(schema)
+            json_payload["schema"] = schema
         async with httpx.AsyncClient() as client:
             response = await retry(client.post)(
                 self.llm_url,


@@ -1,5 +1,3 @@
-import json
-
 import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
@@ -9,7 +7,7 @@ class OobaboogaLLM(LLM):
     async def _generate(self, prompt: str, schema: dict | None, **kwargs):
         json_payload = {"prompt": prompt}
         if schema:
-            json_payload["schema"] = json.dumps(schema)
+            json_payload["schema"] = schema
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 settings.LLM_URL,