# File: reflector/server/reflector/llm/llm_modal.py
# Last commit: a98a9853be "PR review comments" by Gokul Mohanarangan,
#              2023-08-17 14:42:45 +05:30
# Size: 68 lines, 1.9 KiB (Python)

import json
import httpx
from reflector.llm.base import LLM
from reflector.settings import settings
from reflector.utils.retry import retry
class ModalLLM(LLM):
    """LLM backend that sends prompts to an HTTP service at ``settings.LLM_URL``.

    Two endpoints are derived from the configured base URL: ``/llm`` for
    generation and ``/warmup`` for warming the service up. Every request
    carries a bearer token taken from ``settings.LLM_MODAL_API_KEY``.
    """

    # Timeout, in seconds, applied to the single warmup request.
    _WARMUP_TIMEOUT = 60 * 5
    # Forwarded as ``retry_timeout`` to the ``retry`` wrapper.
    # NOTE(review): presumably the overall retry budget in seconds --
    # confirm against reflector.utils.retry.
    _RETRY_TIMEOUT = 60 * 5

    def __init__(self):
        """Read the timeout, endpoint URLs and auth header from settings."""
        super().__init__()
        self.timeout = settings.LLM_TIMEOUT
        self.llm_url = settings.LLM_URL + "/llm"
        self.llm_warmup_url = settings.LLM_URL + "/warmup"
        self.headers = {
            "Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}",
        }

    async def _warmup(self, logger):
        """POST to the warmup endpoint and fail on an error status.

        Args:
            logger: Accepted for interface compatibility; not used here.

        Raises:
            httpx.HTTPStatusError: If the service answers with a 4xx/5xx
                status.
        """
        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.llm_warmup_url,
                headers=self.headers,
                timeout=self._WARMUP_TIMEOUT,
            )
            response.raise_for_status()

    async def _generate(self, prompt: str, schema: dict | None, **kwargs):
        """Send ``prompt`` to the generation endpoint and return its text.

        Args:
            prompt: Text sent under the ``"prompt"`` key of the JSON body.
            schema: Optional JSON-serialisable schema. When truthy it is
                encoded with ``json.dumps`` and sent under ``"schema"``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The ``"text"`` field of the JSON response. When no schema is
            given, the first ``len(prompt)`` characters are removed first.

        Raises:
            httpx.HTTPStatusError: If the final response has a 4xx/5xx
                status.
            KeyError: If the JSON response has no ``"text"`` field.
        """
        json_payload = {"prompt": prompt}
        if schema:
            # The schema travels as a JSON string, not as a nested object.
            json_payload["schema"] = json.dumps(schema)

        async with httpx.AsyncClient() as client:
            response = await retry(client.post)(
                self.llm_url,
                headers=self.headers,
                json=json_payload,
                timeout=self.timeout,
                retry_timeout=self._RETRY_TIMEOUT,
            )
            response.raise_for_status()
            text = response.json()["text"]

        if not schema:
            # NOTE(review): presumably the service echoes the prompt ahead
            # of the completion; the slice is unconditional, so confirm the
            # echo always matches the prompt exactly.
            text = text[len(prompt) :]
        return text
# Register ModalLLM with the LLM base class under the key "modal".
# NOTE(review): presumably this lets the backend be selected by name --
# confirm in reflector.llm.base.
LLM.register("modal", ModalLLM)
if __name__ == "__main__":
    # Manual smoke test: issue one plain request and one request carrying a
    # JSON schema, printing whatever each call returns.
    import asyncio

    from reflector.logger import logger

    async def main():
        """Call ``generate`` without and then with a schema; print both."""
        backend = ModalLLM()

        print(await backend.generate("Hello, my name is", logger=logger))

        name_schema = {
            "type": "object",
            "properties": {"name": {"type": "string"}},
        }
        print(
            await backend.generate(
                "Hello, my name is", schema=name_schema, logger=logger
            )
        )

    asyncio.run(main())