Serverless GPU support on banana.dev (#106)
* serverless: implement banana backend for both audio and LLM. Related to the monadical-sas/reflector-gpu-banana project
* serverless: got llm working on banana!
* tests: fixes
* serverless: fix dockerfile to use fastapi server + httpx
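Taken together, the changes below make LLM backends lazy-loaded and selectable by name instead of being imported eagerly in the package __init__. As an illustration only (not part of this commit), a caller would now obtain the banana backend roughly like this, assuming LLM_URL and the banana credentials are configured in reflector.settings; the prompt is a hypothetical placeholder:

import asyncio

from reflector.llm.base import LLM


async def main():
    # get_instance() falls back to settings.LLM_BACKEND when no name is given.
    # Passing "banana" imports reflector.llm.llm_banana on demand, which
    # registers BananaLLM, and an instance of it is returned.
    llm = LLM.get_instance("banana")
    result = await llm.generate("Summarize the meeting notes")
    print(result)


asyncio.run(main())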
@@ -1,3 +1 @@
 from .base import LLM  # noqa: F401
-from . import llm_oobagooda  # noqa: F401
-from . import llm_openai  # noqa: F401
@@ -1,6 +1,7 @@
 from reflector.logger import logger
 from reflector.settings import settings
-import asyncio
+from reflector.utils.retry import retry
+import importlib
 import json
 import re
 
@@ -13,7 +14,7 @@ class LLM:
         cls._registry[name] = klass
 
     @classmethod
-    def instance(cls):
+    def get_instance(cls, name=None):
         """
         Return an instance depending on the settings.
         Settings used:
@@ -21,22 +22,19 @@
         - `LLM_BACKEND`: key of the backend, defaults to `oobagooda`
         - `LLM_URL`: url of the backend
         """
-        return cls._registry[settings.LLM_BACKEND]()
+        if name is None:
+            name = settings.LLM_BACKEND
+        if name not in cls._registry:
+            module_name = f"reflector.llm.llm_{name}"
+            importlib.import_module(module_name)
+        return cls._registry[name]()
 
-    async def generate(
-        self, prompt: str, retry_count: int = 5, retry_interval: int = 1, **kwargs
-    ) -> dict:
-        while retry_count > 0:
-            try:
-                result = await self._generate(prompt=prompt, **kwargs)
-                break
-            except Exception:
-                logger.exception("Failed to call llm")
-                retry_count -= 1
-                await asyncio.sleep(retry_interval)
-
-        if retry_count == 0:
-            raise Exception("Failed to call llm after retrying")
+    async def generate(self, prompt: str, **kwargs) -> dict:
+        try:
+            result = await retry(self._generate)(prompt=prompt, **kwargs)
+        except Exception:
+            logger.exception("Failed to call llm after retrying")
+            raise
 
         if isinstance(result, str):
             result = self._parse_json(result)
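The retry helper used above comes from reflector.utils.retry, whose implementation is not part of this diff. A minimal sketch of a wrapper with the same behaviour as the loop it replaces (a few attempts, a short pause between them, and the exception propagated once attempts run out) could look like this; names and defaults are assumptions, not the project's actual code:

import asyncio
import functools


def retry(func, retry_count: int = 5, retry_interval: float = 1.0):
    # Hypothetical stand-in for reflector.utils.retry.retry: wrap an async
    # callable so transient failures are retried before the error propagates.
    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        last_exc = None
        for _ in range(retry_count):
            try:
                return await func(*args, **kwargs)
            except Exception as exc:  # broad catch, matching the removed loop
                last_exc = exc
                await asyncio.sleep(retry_interval)
        raise last_exc

    return wrapper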
server/reflector/llm/llm_banana.py (new file, 41 lines)
@@ -0,0 +1,41 @@
+from reflector.llm.base import LLM
+from reflector.settings import settings
+from reflector.utils.retry import retry
+import httpx
+
+
+class BananaLLM(LLM):
+    def __init__(self):
+        super().__init__()
+        self.timeout = settings.LLM_TIMEOUT
+        self.headers = {
+            "X-Banana-API-Key": settings.LLM_BANANA_API_KEY,
+            "X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY,
+        }
+
+    async def _generate(self, prompt: str, **kwargs):
+        async with httpx.AsyncClient() as client:
+            response = await retry(client.post)(
+                settings.LLM_URL,
+                headers=self.headers,
+                json={"prompt": prompt},
+                timeout=self.timeout,
+            )
+            response.raise_for_status()
+            text = response.json()["text"]
+            text = text[len(prompt) :]  # remove prompt
+            return text
+
+
+LLM.register("banana", BananaLLM)
+
+if __name__ == "__main__":
+
+    async def main():
+        llm = BananaLLM()
+        result = await llm.generate("Hello, my name is")
+        print(result)
+
+    import asyncio
+
+    asyncio.run(main())
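The commit message also mentions switching the GPU-side Dockerfile to a FastAPI server; that code lives in the separate monadical-sas/reflector-gpu-banana project and is not shown here. The response shape BananaLLM expects is JSON with a text field that begins with the prompt, since the client slices the prompt back off. A hypothetical handler producing that shape, with the model call stubbed out, might look like:

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()


class GenerateRequest(BaseModel):
    prompt: str


def run_model(prompt: str) -> str:
    # Stub standing in for the actual LLM inference on the GPU worker.
    return " ...model completion..."


@app.post("/")
async def generate(request: GenerateRequest):
    completion = run_model(request.prompt)
    # BananaLLM strips len(prompt) characters from "text", so echo the prompt first.
    return {"text": request.prompt + completion}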