Mirror of https://github.com/Monadical-SAS/reflector.git (synced 2025-12-20 20:29:06 +00:00)
server: introduce LLM backends
server/reflector/llm/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .base import LLM  # noqa: F401
from . import llm_oobagooda  # noqa: F401
from . import llm_openai  # noqa: F401
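
The three imports exist for their side effects: each backend module calls LLM.register(...) when imported, so pulling in the package is enough to populate the registry. A minimal sketch of what a caller relies on (assuming settings.LLM_BACKEND names one of the registered keys):

    from reflector.llm import LLM

    llm = LLM.instance()  # class looked up by settings.LLM_BACKEND
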
server/reflector/llm/base.py (new file, 58 lines)
@@ -0,0 +1,58 @@
from reflector.logger import logger
from reflector.settings import settings
import asyncio
import json


class LLM:
    _registry = {}

    @classmethod
    def register(cls, name, klass):
        cls._registry[name] = klass

    @classmethod
    def instance(cls):
        """
        Return an instance depending on the settings.

        Settings used:

        - `LLM_BACKEND`: key of the backend, defaults to `oobagooda`
        - `LLM_URL`: URL of the backend
        """
        return cls._registry[settings.LLM_BACKEND]()

    async def generate(
        self, prompt: str, retry_count: int = 5, retry_interval: int = 1, **kwargs
    ) -> dict:
        # Retry the backend call, then JSON-decode the result if it is a string.
        while retry_count > 0:
            try:
                result = await self._generate(prompt=prompt, **kwargs)
                break
            except Exception:
                logger.exception("Failed to call llm")
                retry_count -= 1
                await asyncio.sleep(retry_interval)

        if retry_count == 0:
            raise Exception("Failed to call llm after retrying")

        if isinstance(result, str):
            result = self._parse_json(result)

        return result

    async def _generate(self, prompt: str, **kwargs) -> str:
        raise NotImplementedError

    def _parse_json(self, result: str) -> dict:
        result = result.strip()
        # strip a Markdown code block if one is present
        if result.startswith("```json\n") and result.endswith("```"):
            result = result[8:-3]
        elif result.startswith("```\n") and result.endswith("```"):
            result = result[4:-3]
        print(">>>", result)
        return json.loads(result.strip())
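
The base class defines the contract for backends: subclasses implement `_generate`, while `generate` adds retries and JSON extraction. A minimal sketch of a custom backend under that contract (EchoLLM is a made-up name, purely illustrative):

    import json

    from reflector.llm import LLM


    class EchoLLM(LLM):
        # Hypothetical backend used only to illustrate the registry contract.
        async def _generate(self, prompt: str, **kwargs) -> str:
            # Returning a string makes generate() run _parse_json() on it,
            # so the payload must be valid JSON (optionally fenced in ```).
            return json.dumps({"echo": prompt})


    LLM.register("echo", EchoLLM)

    # With settings.LLM_BACKEND = "echo":
    #   result = await LLM.instance().generate("hello")   # -> {"echo": "hello"}
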
server/reflector/llm/llm_oobagooda.py (new file, 18 lines)
@@ -0,0 +1,18 @@
from reflector.llm.base import LLM
from reflector.settings import settings
import httpx


class OobagoodaLLM(LLM):
    async def _generate(self, prompt: str, **kwargs):
        async with httpx.AsyncClient() as client:
            response = await client.post(
                settings.LLM_URL,
                headers={"Content-Type": "application/json"},
                json={"prompt": prompt},
            )
            response.raise_for_status()
            return response.json()


LLM.register("oobagooda", OobagoodaLLM)
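
A sketch of driving this backend end to end, assuming `LLM_BACKEND` is set to `oobagooda` and `LLM_URL` points at a reachable generation endpoint (the prompt below is illustrative):

    import asyncio

    from reflector.llm import LLM


    async def main():
        llm = LLM.instance()  # OobagoodaLLM when LLM_BACKEND == "oobagooda"
        # This backend returns response.json() (already a dict), so generate()
        # skips _parse_json() and hands the payload back unchanged.
        result = await llm.generate("Summarize the meeting in one sentence.")
        print(result)


    asyncio.run(main())
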
server/reflector/llm/llm_openai.py (new file, 44 lines)
@@ -0,0 +1,44 @@
from reflector.llm.base import LLM
from reflector.logger import logger
from reflector.settings import settings
import json
import httpx


class OpenAILLM(LLM):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.openai_key = settings.LLM_OPENAI_KEY
        self.openai_url = settings.LLM_URL
        self.openai_model = settings.LLM_OPENAI_MODEL
        self.openai_temperature = settings.LLM_OPENAI_TEMPERATURE
        self.timeout = settings.LLM_TIMEOUT
        self.max_tokens = settings.LLM_MAX_TOKENS
        logger.info(f"LLM use openai backend at {self.openai_url}")

    async def _generate(self, prompt: str, **kwargs) -> str:
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.openai_key}",
        }

        logger.debug(f"LLM openai prompt: {prompt}")

        async with httpx.AsyncClient(timeout=self.timeout) as client:
            response = await client.post(
                self.openai_url,
                headers=headers,
                json={
                    "model": self.openai_model,
                    "prompt": prompt,
                    "max_tokens": self.max_tokens,
                    "temperature": self.openai_temperature,
                },
            )
            response.raise_for_status()
            result = response.json()
            logger.info(f"LLM openai result: {result}")
            return result["choices"][0]["text"]


LLM.register("openai", OpenAILLM)
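
The response handling assumes a completions-style body, and because `_generate` returns only the completion text, the base class will try to JSON-decode it, so prompts sent to this backend are expected to request JSON output. A sketch of the shape the code above extracts from (values are made up):

    # Shape assumed by result["choices"][0]["text"]; values are illustrative.
    result = {
        "choices": [
            {"text": '{"title": "Weekly sync", "summary": "..."}'}
        ]
    }
    text = result["choices"][0]["text"]
    # generate() then feeds this string through _parse_json(), so the model
    # is expected to reply with JSON (optionally wrapped in a ``` block).
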
server/reflector/settings.py
@@ -31,6 +31,12 @@ class Settings(BaseSettings):
     LLM_URL: str | None = None
     LLM_HOST: str = "localhost"
     LLM_PORT: int = 7860
+    LLM_OPENAI_KEY: str | None = None
+    LLM_OPENAI_MODEL: str = "gpt-3.5-turbo"
+    LLM_OPENAI_TEMPERATURE: float = 0.7
+    LLM_TIMEOUT: int = 90
+    LLM_MAX_TOKENS: int = 1024
+    LLM_TEMPERATURE: float = 0.7

     # Storage
     STORAGE_BACKEND: str = "aws"
@@ -38,8 +44,5 @@ class Settings(BaseSettings):
     STORAGE_AWS_SECRET_KEY: str = ""
     STORAGE_AWS_BUCKET: str = ""

-    # OpenAI
-    OPENAI_API_KEY: str = ""
-

 settings = Settings()
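
These fields come from pydantic's `BaseSettings`, so each one can normally be supplied as an environment variable of the same name (assuming no `env_prefix` is configured on the class, which this diff does not show). A hedged sketch with illustrative values:

    import os

    # Set before reflector.settings is first imported, since
    # `settings = Settings()` runs at module import time.
    os.environ["LLM_BACKEND"] = "openai"
    os.environ["LLM_OPENAI_KEY"] = "sk-..."   # placeholder
    os.environ["LLM_TIMEOUT"] = "120"

    from reflector.settings import settings

    assert settings.LLM_TIMEOUT == 120
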