Mirror of https://github.com/Monadical-SAS/reflector.git
server: introduce LLM backends
server/reflector/llm/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .base import LLM  # noqa: F401
from . import llm_oobagooda  # noqa: F401
from . import llm_openai  # noqa: F401
server/reflector/llm/base.py (new file, 58 lines)
@@ -0,0 +1,58 @@
from reflector.logger import logger
from reflector.settings import settings
import asyncio
import json


class LLM:
    _registry = {}

    @classmethod
    def register(cls, name, klass):
        cls._registry[name] = klass

    @classmethod
    def instance(cls):
        """
        Return an instance depending on the settings.
        Settings used:

        - `LLM_BACKEND`: key of the backend, defaults to `oobagooda`
        - `LLM_URL`: url of the backend
        """
        return cls._registry[settings.LLM_BACKEND]()

    async def generate(
        self, prompt: str, retry_count: int = 5, retry_interval: int = 1, **kwargs
    ) -> dict:
        while retry_count > 0:
            try:
                result = await self._generate(prompt=prompt, **kwargs)
                break
            except Exception:
                logger.exception("Failed to call llm")
                retry_count -= 1
                await asyncio.sleep(retry_interval)

        if retry_count == 0:
            raise Exception("Failed to call llm after retrying")

        if isinstance(result, str):
            result = self._parse_json(result)

        return result

    async def _generate(self, prompt: str, **kwargs) -> str:
        raise NotImplementedError

    def _parse_json(self, result: str) -> dict:
        result = result.strip()
        # try detecting a code block, if it exists
        if result.startswith("```json\n") and result.endswith("```"):
            result = result[8:-3]
        elif result.startswith("```\n") and result.endswith("```"):
            result = result[4:-3]
        print(">>>", result)
        return json.loads(result.strip())
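The registry above resolves backends by name at runtime: a subclass implements `_generate`, registers itself under a key, and `LLM.instance()` looks that key up via `settings.LLM_BACKEND`. A minimal usage sketch of that flow (the `EchoLLM` backend and the inline settings override are hypothetical, for illustration only, and not part of this commit):

import asyncio
import json

from reflector.llm import LLM
from reflector.settings import settings


class EchoLLM(LLM):
    # Hypothetical backend: returns a fenced JSON block so that generate()
    # also exercises the _parse_json stripping shown above.
    async def _generate(self, prompt: str, **kwargs) -> str:
        return "```json\n" + json.dumps({"echo": prompt}) + "\n```"


LLM.register("echo", EchoLLM)


async def main():
    settings.LLM_BACKEND = "echo"  # normally configured via the environment
    llm = LLM.instance()
    result = await llm.generate("hello")  # retries up to 5 times on failure
    assert result == {"echo": "hello"}


asyncio.run(main())

If every attempt raises, `generate()` stops after `retry_count` tries and raises a plain `Exception` instead of returning.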
server/reflector/llm/llm_oobagooda.py (new file, 18 lines)
@@ -0,0 +1,18 @@
from reflector.llm.base import LLM
from reflector.settings import settings
import httpx


class OobagoodaLLM(LLM):
    async def _generate(self, prompt: str, **kwargs):
        async with httpx.AsyncClient() as client:
            response = await client.post(
                settings.LLM_URL,
                headers={"Content-Type": "application/json"},
                json={"prompt": prompt},
            )
            response.raise_for_status()
            return response.json()


LLM.register("oobagooda", OobagoodaLLM)
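Because `_generate` here returns `response.json()` directly, `LLM.generate()` receives a dict and skips `_parse_json` for this backend. A short configuration sketch (the URL is an assumption about a locally running text-generation-webui style server, not something this commit defines):

import asyncio

from reflector.llm import LLM
from reflector.settings import settings


async def main():
    settings.LLM_BACKEND = "oobagooda"
    settings.LLM_URL = "http://localhost:7860/api/v1/generate"  # assumed local endpoint
    result = await LLM.instance().generate("Summarize the meeting")
    print(result)  # whatever JSON body the server returned


asyncio.run(main())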
server/reflector/llm/llm_openai.py (new file, 44 lines)
@@ -0,0 +1,44 @@
from reflector.llm.base import LLM
from reflector.logger import logger
from reflector.settings import settings
import json
import httpx


class OpenAILLM(LLM):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.openai_key = settings.LLM_OPENAI_KEY
        self.openai_url = settings.LLM_URL
        self.openai_model = settings.LLM_OPENAI_MODEL
        self.openai_temperature = settings.LLM_OPENAI_TEMPERATURE
        self.timeout = settings.LLM_TIMEOUT
        self.max_tokens = settings.LLM_MAX_TOKENS
        logger.info(f"LLM use openai backend at {self.openai_url}")

    async def _generate(self, prompt: str, **kwargs) -> str:
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.openai_key}",
        }

        logger.debug(f"LLM openai prompt: {prompt}")

        async with httpx.AsyncClient(timeout=self.timeout) as client:
            response = await client.post(
                self.openai_url,
                headers=headers,
                json={
                    "model": self.openai_model,
                    "prompt": prompt,
                    "max_tokens": self.max_tokens,
                    "temperature": self.openai_temperature,
                },
            )
            response.raise_for_status()
            result = response.json()
            logger.info(f"LLM openai result: {result}")
            return result["choices"][0]["text"]


LLM.register("openai", OpenAILLM)
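The last lines of `_generate` assume a legacy text-completions response, where the generated text sits under `choices[0].text`. A trimmed illustration of that extraction (the response body is invented, not captured output):

# Invented response body in the shape _generate expects back from LLM_URL.
response_body = {
    "choices": [
        {"text": '{"title": "Weekly sync", "topics": ["roadmap", "hiring"]}'}
    ],
}

text = response_body["choices"][0]["text"]
# Because this is a str, LLM.generate() runs _parse_json on it and returns
# {"title": "Weekly sync", "topics": ["roadmap", "hiring"]} as a dict.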
@@ -31,6 +31,12 @@ class Settings(BaseSettings):
    LLM_URL: str | None = None
    LLM_HOST: str = "localhost"
    LLM_PORT: int = 7860
    LLM_OPENAI_KEY: str | None = None
    LLM_OPENAI_MODEL: str = "gpt-3.5-turbo"
    LLM_OPENAI_TEMPERATURE: float = 0.7
    LLM_TIMEOUT: int = 90
    LLM_MAX_TOKENS: int = 1024
    LLM_TEMPERATURE: float = 0.7

    # Storage
    STORAGE_BACKEND: str = "aws"
@@ -38,8 +44,5 @@ class Settings(BaseSettings):
    STORAGE_AWS_SECRET_KEY: str = ""
    STORAGE_AWS_BUCKET: str = ""

    # OpenAI
    OPENAI_API_KEY: str = ""


settings = Settings()
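Since `Settings` subclasses pydantic's `BaseSettings`, each field above can be overridden from the environment. A hedged sketch of selecting the OpenAI backend that way (values are examples only; it assumes no `env_prefix` is configured and that an `LLM_BACKEND` field exists on `Settings`, since `LLM.instance()` reads it):

import os

from reflector.settings import Settings

# Example values, not part of this commit.
os.environ["LLM_BACKEND"] = "openai"
os.environ["LLM_URL"] = "https://api.openai.com/v1/completions"
os.environ["LLM_OPENAI_KEY"] = "sk-example"  # placeholder key
os.environ["LLM_MAX_TOKENS"] = "512"

settings = Settings()  # BaseSettings picks the LLM_* fields up from the environment
assert settings.LLM_MAX_TOKENS == 512  # pydantic coerces the string to int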