server: introduce LLM backends

Mathieu Virbel
2023-08-01 14:23:34 +02:00
parent 224afc6f28
commit 42f1442e56
5 changed files with 129 additions and 3 deletions

reflector/llm/__init__.py

@@ -0,0 +1,3 @@
from .base import LLM # noqa: F401
from . import llm_oobagooda # noqa: F401
from . import llm_openai # noqa: F401

reflector/llm/base.py

@@ -0,0 +1,58 @@
from reflector.logger import logger
from reflector.settings import settings
import asyncio
import json


class LLM:
    _registry = {}

    @classmethod
    def register(cls, name, klass):
        cls._registry[name] = klass

    @classmethod
    def instance(cls):
        """
        Return an instance depending on the settings.

        Settings used:

        - `LLM_BACKEND`: key of the backend, defaults to `oobagooda`
        - `LLM_URL`: url of the backend
        """
        return cls._registry[settings.LLM_BACKEND]()

    async def generate(
        self, prompt: str, retry_count: int = 5, retry_interval: int = 1, **kwargs
    ) -> dict:
        while retry_count > 0:
            try:
                result = await self._generate(prompt=prompt, **kwargs)
                break
            except Exception:
                logger.exception("Failed to call llm")
                retry_count -= 1
                await asyncio.sleep(retry_interval)
        if retry_count == 0:
            raise Exception("Failed to call llm after retrying")
        if isinstance(result, str):
            result = self._parse_json(result)
        return result

    async def _generate(self, prompt: str, **kwargs) -> str:
        raise NotImplementedError

    def _parse_json(self, result: str) -> dict:
        result = result.strip()
        # try detecting code block if exist
        if result.startswith("```json\n") and result.endswith("```"):
            result = result[8:-3]
        elif result.startswith("```\n") and result.endswith("```"):
            result = result[4:-3]
        print(">>>", result)
        return json.loads(result.strip())
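For context (not part of the diff): a minimal caller-side sketch of this API. Importing reflector.llm pulls in both backend modules, whose LLM.register(...) calls fill the registry, so LLM.instance() can resolve settings.LLM_BACKEND; the prompt string below is purely illustrative.

import asyncio

from reflector.llm import LLM


async def main():
    llm = LLM.instance()  # backend class looked up via settings.LLM_BACKEND
    # generate() retries _generate() up to retry_count times, then parses a
    # string result as JSON (stripping any surrounding code fences)
    result = await llm.generate("Return a JSON object describing yourself.")
    print(result)


asyncio.run(main())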

reflector/llm/llm_oobagooda.py

@@ -0,0 +1,18 @@
from reflector.llm.base import LLM
from reflector.settings import settings
import httpx


class OobagoodaLLM(LLM):
    async def _generate(self, prompt: str, **kwargs):
        async with httpx.AsyncClient() as client:
            response = await client.post(
                settings.LLM_URL,
                headers={"Content-Type": "application/json"},
                json={"prompt": prompt},
            )
            response.raise_for_status()
            return response.json()


LLM.register("oobagooda", OobagoodaLLM)

reflector/llm/llm_openai.py

@@ -0,0 +1,44 @@
from reflector.llm.base import LLM
from reflector.logger import logger
from reflector.settings import settings
import json
import httpx


class OpenAILLM(LLM):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.openai_key = settings.LLM_OPENAI_KEY
        self.openai_url = settings.LLM_URL
        self.openai_model = settings.LLM_OPENAI_MODEL
        self.openai_temperature = settings.LLM_OPENAI_TEMPERATURE
        self.timeout = settings.LLM_TIMEOUT
        self.max_tokens = settings.LLM_MAX_TOKENS
        logger.info(f"LLM use openai backend at {self.openai_url}")

    async def _generate(self, prompt: str, **kwargs) -> str:
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.openai_key}",
        }
        logger.debug(f"LLM openai prompt: {prompt}")
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            response = await client.post(
                self.openai_url,
                headers=headers,
                json={
                    "model": self.openai_model,
                    "prompt": prompt,
                    "max_tokens": self.max_tokens,
                    "temperature": self.openai_temperature,
                },
            )
            response.raise_for_status()
            result = response.json()
            logger.info(f"LLM openai result: {result}")
            return result["choices"][0]["text"]


LLM.register("openai", OpenAILLM)

reflector/settings.py

@@ -31,6 +31,12 @@ class Settings(BaseSettings):
     LLM_URL: str | None = None
     LLM_HOST: str = "localhost"
     LLM_PORT: int = 7860
+    LLM_OPENAI_KEY: str | None = None
+    LLM_OPENAI_MODEL: str = "gpt-3.5-turbo"
+    LLM_OPENAI_TEMPERATURE: float = 0.7
+    LLM_TIMEOUT: int = 90
+    LLM_MAX_TOKENS: int = 1024
+    LLM_TEMPERATURE: float = 0.7
 
     # Storage
     STORAGE_BACKEND: str = "aws"
@@ -38,8 +44,5 @@ class Settings(BaseSettings):
     STORAGE_AWS_SECRET_KEY: str = ""
     STORAGE_AWS_BUCKET: str = ""
 
-    # OpenAI
-    OPENAI_API_KEY: str = ""
-
 settings = Settings()
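A sketch of how the new settings might be filled in to select the OpenAI backend; the URL, key, and the LLM_BACKEND field (referenced in base.py but defined outside this hunk) are assumptions, not values from this commit. Since Settings is a pydantic BaseSettings, the same field names can also come from environment variables.

from reflector.settings import Settings

# hypothetical configuration; LLM_URL and LLM_OPENAI_KEY are placeholders
settings = Settings(
    LLM_BACKEND="openai",
    LLM_URL="https://api.openai.com/v1/completions",
    LLM_OPENAI_KEY="sk-placeholder",
    LLM_OPENAI_MODEL="gpt-3.5-turbo",
    LLM_OPENAI_TEMPERATURE=0.7,
    LLM_MAX_TOKENS=1024,
)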