diff --git a/server/gpu/modal/reflector_llm.py b/server/gpu/modal/reflector_llm.py
index 10cf4772..fd8a4aae 100644
--- a/server/gpu/modal/reflector_llm.py
+++ b/server/gpu/modal/reflector_llm.py
@@ -172,13 +172,16 @@ def web():
 
     class LLMRequest(BaseModel):
         prompt: str
-        schema: Optional[str] = None
+        schema: Optional[dict] = None
 
     @app.post("/llm", dependencies=[Depends(apikey_auth)])
     async def llm(
        req: LLMRequest,
    ):
-        func = llmstub.generate.spawn(prompt=req.prompt, schema=req.schema)
+        if req.schema:
+            func = llmstub.generate.spawn(prompt=req.prompt, schema=json.dumps(req.schema))
+        else:
+            func = llmstub.generate.spawn(prompt=req.prompt)
         result = func.get()
         return result
 
diff --git a/server/reflector/llm/llm_banana.py b/server/reflector/llm/llm_banana.py
index 07119613..56fc0e69 100644
--- a/server/reflector/llm/llm_banana.py
+++ b/server/reflector/llm/llm_banana.py
@@ -1,5 +1,3 @@
-import json
-
 import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
@@ -18,7 +16,7 @@ class BananaLLM(LLM):
     async def _generate(self, prompt: str, schema: dict | None, **kwargs):
         json_payload = {"prompt": prompt}
         if schema:
-            json_payload["schema"] = json.dumps(schema)
+            json_payload["schema"] = schema
         async with httpx.AsyncClient() as client:
             response = await retry(client.post)(
                 settings.LLM_URL,
diff --git a/server/reflector/llm/llm_modal.py b/server/reflector/llm/llm_modal.py
index ea9ff152..ce0de02a 100644
--- a/server/reflector/llm/llm_modal.py
+++ b/server/reflector/llm/llm_modal.py
@@ -1,5 +1,3 @@
-import json
-
 import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
@@ -28,7 +26,7 @@ class ModalLLM(LLM):
     async def _generate(self, prompt: str, schema: dict | None, **kwargs):
         json_payload = {"prompt": prompt}
         if schema:
-            json_payload["schema"] = json.dumps(schema)
+            json_payload["schema"] = schema
         async with httpx.AsyncClient() as client:
             response = await retry(client.post)(
                 self.llm_url,
diff --git a/server/reflector/llm/llm_oobabooga.py b/server/reflector/llm/llm_oobabooga.py
index 6c5a68ec..411014c5 100644
--- a/server/reflector/llm/llm_oobabooga.py
+++ b/server/reflector/llm/llm_oobabooga.py
@@ -1,5 +1,3 @@
-import json
-
 import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
@@ -9,7 +7,7 @@ class OobaboogaLLM(LLM):
     async def _generate(self, prompt: str, schema: dict | None, **kwargs):
         json_payload = {"prompt": prompt}
         if schema:
-            json_payload["schema"] = json.dumps(schema)
+            json_payload["schema"] = schema
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 settings.LLM_URL,