server: introduce LLM backends

This commit is contained in:
Mathieu Virbel
2023-08-01 14:23:34 +02:00
parent 224afc6f28
commit 42f1442e56
5 changed files with 129 additions and 3 deletions

View File

@@ -0,0 +1,58 @@
from reflector.logger import logger
from reflector.settings import settings
import asyncio
import json
class LLM:
    """
    Base class and registry for LLM backends.

    Concrete backends subclass `LLM`, implement `_generate`, and register
    themselves under a name with `LLM.register`. `LLM.instance` then builds
    the backend selected by the settings.
    """

    # Maps a backend name to the class implementing it. Shared by all
    # subclasses on purpose: it is the single global registry.
    _registry = {}

    @classmethod
    def register(cls, name, klass):
        """Register `klass` as the backend available under `name`.

        Registering an already-used name replaces the previous entry.
        """
        cls._registry[name] = klass

    @classmethod
    def instance(cls):
        """
        Return an instance depending on the settings.

        Settings used:

        - `LLM_BACKEND`: key of the backend, defaults to `oobagooda`
        - `LLM_URL`: url of the backend (not read here; presumably consumed
          by the concrete backend -- confirm against the implementations)

        Raises:
            KeyError: if no backend is registered under `LLM_BACKEND`.
        """
        backend = settings.LLM_BACKEND
        try:
            klass = cls._registry[backend]
        except KeyError:
            # Still a KeyError, but tell the operator what is available.
            known = ", ".join(sorted(map(str, cls._registry))) or "none"
            raise KeyError(
                f"Unknown LLM backend {backend!r}; registered backends: {known}"
            ) from None
        return klass()

    async def generate(
        self, prompt: str, retry_count: int = 5, retry_interval: int = 1, **kwargs
    ) -> dict:
        """
        Call the backend with `prompt`, retrying on failure.

        Args:
            prompt: text sent to the backend.
            retry_count: maximum number of attempts. A value of zero or less
                means no attempt is made and the call fails immediately.
            retry_interval: seconds to wait between two attempts.
            **kwargs: forwarded unchanged to `_generate`.

        Returns:
            The backend result. A string result is decoded with
            `_parse_json`; any other result is returned as-is.

        Raises:
            Exception: when every attempt failed; the last backend error is
                attached as `__cause__`.
            json.JSONDecodeError: when the backend returned a string that is
                not valid JSON (decoding errors are not retried).
        """
        last_error = None
        for attempts_left in range(retry_count - 1, -1, -1):
            try:
                result = await self._generate(prompt=prompt, **kwargs)
            except Exception as exc:
                logger.exception("Failed to call llm")
                last_error = exc
                # No point in waiting once there is nothing left to retry.
                if attempts_left:
                    await asyncio.sleep(retry_interval)
            else:
                # Decoding happens outside the try block so that a malformed
                # answer surfaces immediately instead of being retried.
                if isinstance(result, str):
                    result = self._parse_json(result)
                return result

        raise Exception("Failed to call llm after retrying") from last_error

    async def _generate(self, prompt: str, **kwargs) -> str:
        """Send `prompt` to the backend; must be overridden by subclasses."""
        raise NotImplementedError

    def _parse_json(self, result: str) -> dict:
        """
        Decode the JSON document contained in `result`.

        Surrounding whitespace is ignored, and a Markdown code fence around
        the document (three backticks, optionally tagged `json`) is removed
        before decoding.

        Returns:
            The decoded JSON value. The annotation says `dict`, but
            `json.loads` returns whatever the document contains.

        Raises:
            json.JSONDecodeError: if the remaining text is not valid JSON.
        """
        text = result.strip()
        fence = "```"
        if (
            len(text) >= 2 * len(fence)
            and text.startswith(fence)
            and text.endswith(fence)
        ):
            text = text[len(fence):-len(fence)]
            # Drop the optional language tag; valid JSON never starts with
            # the letters "json", so this cannot eat real content.
            if text[:4].lower() == "json":
                text = text[4:]
        return json.loads(text.strip())