From 2d686da15c98e83731602dbc878e67633baf0316 Mon Sep 17 00:00:00 2001
From: Gokul Mohanarangan
Date: Thu, 17 Aug 2023 21:26:20 +0530
Subject: [PATCH] pass schema as dict

---
 server/gpu/modal/reflector_llm.py     | 7 +++++--
 server/reflector/llm/llm_banana.py    | 4 +---
 server/reflector/llm/llm_modal.py     | 4 +---
 server/reflector/llm/llm_oobabooga.py | 4 +---
 4 files changed, 8 insertions(+), 11 deletions(-)

diff --git a/server/gpu/modal/reflector_llm.py b/server/gpu/modal/reflector_llm.py
index 10cf4772..fd8a4aae 100644
--- a/server/gpu/modal/reflector_llm.py
+++ b/server/gpu/modal/reflector_llm.py
@@ -172,13 +172,16 @@ def web():
 
     class LLMRequest(BaseModel):
         prompt: str
-        schema: Optional[str] = None
+        schema: Optional[dict] = None
 
     @app.post("/llm", dependencies=[Depends(apikey_auth)])
     async def llm(
        req: LLMRequest,
     ):
-        func = llmstub.generate.spawn(prompt=req.prompt, schema=req.schema)
+        if req.schema:
+            func = llmstub.generate.spawn(prompt=req.prompt, schema=json.dumps(req.schema))
+        else:
+            func = llmstub.generate.spawn(prompt=req.prompt)
         result = func.get()
         return result
 
diff --git a/server/reflector/llm/llm_banana.py b/server/reflector/llm/llm_banana.py
index 07119613..56fc0e69 100644
--- a/server/reflector/llm/llm_banana.py
+++ b/server/reflector/llm/llm_banana.py
@@ -1,5 +1,3 @@
-import json
-
 import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
@@ -18,7 +16,7 @@ class BananaLLM(LLM):
     async def _generate(self, prompt: str, schema: dict | None, **kwargs):
         json_payload = {"prompt": prompt}
         if schema:
-            json_payload["schema"] = json.dumps(schema)
+            json_payload["schema"] = schema
         async with httpx.AsyncClient() as client:
             response = await retry(client.post)(
                 settings.LLM_URL,
diff --git a/server/reflector/llm/llm_modal.py b/server/reflector/llm/llm_modal.py
index ea9ff152..ce0de02a 100644
--- a/server/reflector/llm/llm_modal.py
+++ b/server/reflector/llm/llm_modal.py
@@ -1,5 +1,3 @@
-import json
-
 import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
@@ -28,7 +26,7 @@ class ModalLLM(LLM):
     async def _generate(self, prompt: str, schema: dict | None, **kwargs):
         json_payload = {"prompt": prompt}
         if schema:
-            json_payload["schema"] = json.dumps(schema)
+            json_payload["schema"] = schema
         async with httpx.AsyncClient() as client:
             response = await retry(client.post)(
                 self.llm_url,
diff --git a/server/reflector/llm/llm_oobabooga.py b/server/reflector/llm/llm_oobabooga.py
index 6c5a68ec..411014c5 100644
--- a/server/reflector/llm/llm_oobabooga.py
+++ b/server/reflector/llm/llm_oobabooga.py
@@ -1,5 +1,3 @@
-import json
-
 import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
@@ -9,7 +7,7 @@ class OobaboogaLLM(LLM):
     async def _generate(self, prompt: str, schema: dict | None, **kwargs):
         json_payload = {"prompt": prompt}
         if schema:
-            json_payload["schema"] = json.dumps(schema)
+            json_payload["schema"] = schema
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 settings.LLM_URL,
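
Notes (not part of the commit): after this change, callers send "schema" as a
JSON object instead of a pre-serialized string. A minimal sketch of the new
request shape against the /llm route follows; the deployment URL and the auth
header format are hypothetical placeholders, since the patch only shows that
the route is guarded by apikey_auth:

    import httpx

    # Hypothetical deployment URL and API key header -- substitute your own.
    LLM_URL = "https://example.modal.run/llm"
    HEADERS = {"Authorization": "Bearer <api-key>"}

    payload = {
        "prompt": "Summarize the meeting notes.",
        # schema is now a plain dict; httpx serializes the whole payload once
        "schema": {
            "type": "object",
            "properties": {"summary": {"type": "string"}},
        },
    }

    response = httpx.post(LLM_URL, headers=HEADERS, json=payload, timeout=60)
    print(response.json())

Only the Modal web endpoint re-serializes with json.dumps(), because the
underlying llmstub.generate still takes the schema as a string; the three
client wrappers (banana, modal, oobabooga) now forward the dict untouched.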