diff --git a/server/reflector/llm/__init__.py b/server/reflector/llm/__init__.py
new file mode 100644
index 00000000..fddf3919
--- /dev/null
+++ b/server/reflector/llm/__init__.py
@@ -0,0 +1,3 @@
+from .base import LLM  # noqa: F401
+from . import llm_oobagooda  # noqa: F401
+from . import llm_openai  # noqa: F401
diff --git a/server/reflector/llm/base.py b/server/reflector/llm/base.py
new file mode 100644
index 00000000..55c0de5f
--- /dev/null
+++ b/server/reflector/llm/base.py
@@ -0,0 +1,58 @@
+from reflector.logger import logger
+from reflector.settings import settings
+import asyncio
+import json
+
+
+class LLM:
+    _registry = {}
+
+    @classmethod
+    def register(cls, name, klass):
+        cls._registry[name] = klass
+
+    @classmethod
+    def instance(cls):
+        """
+        Return an instance depending on the settings.
+        Settings used:
+
+        - `LLM_BACKEND`: key of the backend, defaults to `oobagooda`
+        - `LLM_URL`: URL of the backend
+        """
+        return cls._registry[settings.LLM_BACKEND]()
+
+    async def generate(
+        self, prompt: str, retry_count: int = 5, retry_interval: int = 1, **kwargs
+    ) -> dict:
+        while retry_count > 0:
+            try:
+                result = await self._generate(prompt=prompt, **kwargs)
+                break
+            except Exception:
+                logger.exception("Failed to call llm")
+                retry_count -= 1
+                await asyncio.sleep(retry_interval)
+
+        if retry_count == 0:
+            raise Exception("Failed to call llm after retrying")
+
+        if isinstance(result, str):
+            result = self._parse_json(result)
+
+        return result
+
+    async def _generate(self, prompt: str, **kwargs) -> str:
+        raise NotImplementedError
+
+    def _parse_json(self, result: str) -> dict:
+        result = result.strip()
+        # strip a markdown code fence if one wraps the JSON
+        if result.startswith("```json\n") and result.endswith("```"):
+            result = result[8:-3]
+        elif result.startswith("```\n") and result.endswith("```"):
+            result = result[4:-3]
+        logger.debug(f"LLM parse json: {result}")
+        return json.loads(result.strip())
+
+
diff --git a/server/reflector/llm/llm_oobagooda.py b/server/reflector/llm/llm_oobagooda.py
new file mode 100644
index 00000000..be7d8133
--- /dev/null
+++ b/server/reflector/llm/llm_oobagooda.py
@@ -0,0 +1,18 @@
+from reflector.llm.base import LLM
+from reflector.settings import settings
+import httpx
+
+
+class OobagoodaLLM(LLM):
+    async def _generate(self, prompt: str, **kwargs):
+        async with httpx.AsyncClient() as client:
+            response = await client.post(
+                settings.LLM_URL,
+                headers={"Content-Type": "application/json"},
+                json={"prompt": prompt},
+            )
+            response.raise_for_status()
+            return response.json()
+
+
+LLM.register("oobagooda", OobagoodaLLM)
diff --git a/server/reflector/llm/llm_openai.py b/server/reflector/llm/llm_openai.py
new file mode 100644
index 00000000..d4c565d6
--- /dev/null
+++ b/server/reflector/llm/llm_openai.py
@@ -0,0 +1,44 @@
+from reflector.llm.base import LLM
+from reflector.logger import logger
+from reflector.settings import settings
+import json
+import httpx
+
+
+class OpenAILLM(LLM):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.openai_key = settings.LLM_OPENAI_KEY
+        self.openai_url = settings.LLM_URL
+        self.openai_model = settings.LLM_OPENAI_MODEL
+        self.openai_temperature = settings.LLM_OPENAI_TEMPERATURE
+        self.timeout = settings.LLM_TIMEOUT
+        self.max_tokens = settings.LLM_MAX_TOKENS
+        logger.info(f"LLM using openai backend at {self.openai_url}")
+
+    async def _generate(self, prompt: str, **kwargs) -> str:
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {self.openai_key}",
+        }
+
+        logger.debug(f"LLM openai prompt: {prompt}")
+
+        async with httpx.AsyncClient(timeout=self.timeout) as client:
+            response = await client.post(
+                self.openai_url,
+                headers=headers,
+                json={
+                    "model": self.openai_model,
+                    "prompt": prompt,
+                    "max_tokens": self.max_tokens,
+                    "temperature": self.openai_temperature,
+                },
+            )
+            response.raise_for_status()
+            result = response.json()
+            logger.info(f"LLM openai result: {result}")
+            return result["choices"][0]["text"]
+
+
+LLM.register("openai", OpenAILLM)
diff --git a/server/reflector/settings.py b/server/reflector/settings.py
index 6bad2697..0b6f6df5 100644
--- a/server/reflector/settings.py
+++ b/server/reflector/settings.py
@@ -31,6 +31,12 @@ class Settings(BaseSettings):
     LLM_URL: str | None = None
     LLM_HOST: str = "localhost"
     LLM_PORT: int = 7860
+    LLM_OPENAI_KEY: str | None = None
+    LLM_OPENAI_MODEL: str = "gpt-3.5-turbo"
+    LLM_OPENAI_TEMPERATURE: float = 0.7
+    LLM_TIMEOUT: int = 90
+    LLM_MAX_TOKENS: int = 1024
+    LLM_TEMPERATURE: float = 0.7

     # Storage
     STORAGE_BACKEND: str = "aws"
@@ -38,8 +44,5 @@ class Settings(BaseSettings):
     STORAGE_AWS_SECRET_KEY: str = ""
     STORAGE_AWS_BUCKET: str = ""

-    # OpenAI
-    OPENAI_API_KEY: str = ""
-
 settings = Settings()
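For reviewers, a minimal usage sketch of the new backend registry (assuming the pydantic settings are populated from the environment, e.g. LLM_BACKEND=openai plus the matching LLM_* values; the prompt and printed result here are illustrative, not part of the patch):

import asyncio
from reflector.llm import LLM

async def main():
    # instance() looks up the class registered under settings.LLM_BACKEND
    # ("oobagooda" or "openai") and instantiates it
    llm = LLM.instance()
    # generate() retries the backend call up to retry_count times, then
    # JSON-decodes the raw completion (stripping a ``` fence if present)
    result = await llm.generate('Reply with JSON: {"summary": "..."}', retry_count=3)
    print(result)

asyncio.run(main())

Because each backend module registers itself at import time via LLM.register(), importing reflector.llm is enough to make both backends selectable through settings.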