mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
68 lines
1.9 KiB
Python
68 lines
1.9 KiB
Python
import json
|
|
|
|
import httpx
|
|
from reflector.llm.base import LLM
|
|
from reflector.settings import settings
|
|
from reflector.utils.retry import retry
|
|
|
|
|
|
class ModalLLM(LLM):
    """LLM backend that calls an HTTP endpoint hosted on Modal.

    Configuration is read from ``settings``:

    - ``LLM_URL``: base URL; ``/llm`` and ``/warmup`` are appended to it.
    - ``LLM_MODAL_API_KEY``: sent as a bearer token on every request.
    - ``LLM_TIMEOUT``: per-request httpx timeout for generation calls.
    """

    # httpx timeout (seconds) for the warmup request.
    WARMUP_TIMEOUT = 60 * 5
    # Passed to `retry(...)` as ``retry_timeout`` for generation calls;
    # presumably the overall retry budget in seconds — confirm in
    # reflector.utils.retry.
    RETRY_TIMEOUT = 60 * 5

    def __init__(self):
        super().__init__()
        self.timeout = settings.LLM_TIMEOUT
        self.llm_url = settings.LLM_URL + "/llm"
        self.llm_warmup_url = settings.LLM_URL + "/warmup"
        self.headers = {
            "Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}",
        }

    async def _warmup(self, logger):
        """POST to the warmup endpoint.

        Args:
            logger: unused by this implementation.

        Raises:
            httpx.HTTPStatusError: if the endpoint answers with a 4xx/5xx status.
        """
        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.llm_warmup_url,
                headers=self.headers,
                timeout=self.WARMUP_TIMEOUT,
            )
            response.raise_for_status()

    async def _generate(self, prompt: str, schema: dict | None, **kwargs):
        """Send ``prompt`` (and optionally a JSON schema) to the generation endpoint.

        Args:
            prompt: text sent as the ``"prompt"`` field of the request body.
            schema: optional JSON-serialisable schema. When truthy, it is sent
                JSON-encoded as the ``"schema"`` field.
            **kwargs: accepted but unused by this implementation.

        Returns:
            The ``"text"`` field of the JSON response. Without a schema, a
            leading copy of ``prompt`` is removed if the response includes it.

        Raises:
            httpx.HTTPStatusError: if the final response has a 4xx/5xx status.
        """
        json_payload = {"prompt": prompt}
        if schema:
            # The endpoint receives the schema as a JSON string, not a nested object.
            json_payload["schema"] = json.dumps(schema)

        async with httpx.AsyncClient() as client:
            response = await retry(client.post)(
                self.llm_url,
                headers=self.headers,
                json=json_payload,
                timeout=self.timeout,
                retry_timeout=self.RETRY_TIMEOUT,
            )
            response.raise_for_status()

        text = response.json()["text"]
        if not schema:
            # The free-text response appears to echo the prompt before the
            # completion. Strip it only when it is actually there; blindly
            # slicing off len(prompt) characters would truncate a response
            # that does not echo it.
            text = text.removeprefix(prompt)
        return text
|
|
|
|
|
|
# Register ModalLLM with the LLM base class under the key "modal" (presumably
# how the backend is selected by name — confirm in reflector.llm.base).
LLM.register("modal", ModalLLM)
|
|
|
|
if __name__ == "__main__":
    import asyncio

    from reflector.logger import logger

    async def main():
        """Manual smoke test: one plain completion, then one with a JSON schema."""
        backend = ModalLLM()
        greeting = "Hello, my name is"
        name_schema = {
            "type": "object",
            "properties": {"name": {"type": "string"}},
        }

        plain = await backend.generate(greeting, logger=logger)
        print(plain)

        structured = await backend.generate(greeting, schema=name_schema, logger=logger)
        print(structured)

    asyncio.run(main())
|