Serverless GPU support on banana.dev (#106)

* serverless: implement banana backend for both audio and LLM

Related to the monadical-sas/reflector-gpu-banana project

* serverless: got LLM working on banana!

* tests: fixes

* serverless: fix Dockerfile to use FastAPI server + httpx
Committed by GitHub on 2023-08-04 10:24:11 +02:00
parent a5ce66c299 · commit d94e2911c3
17 changed files with 602 additions and 53 deletions

reflector/llm/__init__.py

@@ -1,3 +1 @@
 from .base import LLM  # noqa: F401
-from . import llm_oobagooda  # noqa: F401
-from . import llm_openai  # noqa: F401
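These two eager imports are dropped because `get_instance` in `base.py` (next file) now imports `reflector.llm.llm_<name>` on demand, so a backend module is only loaded, and registered, when it is actually requested.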

reflector/llm/base.py

@@ -1,6 +1,7 @@
 from reflector.logger import logger
 from reflector.settings import settings
-import asyncio
+from reflector.utils.retry import retry
+import importlib
 import json
 import re
@@ -13,7 +14,7 @@ class LLM:
         cls._registry[name] = klass

     @classmethod
-    def instance(cls):
+    def get_instance(cls, name=None):
         """
         Return an instance depending on the settings.
         Settings used:
@@ -21,22 +22,19 @@ class LLM:
         - `LLM_BACKEND`: key of the backend, defaults to `oobagooda`
         - `LLM_URL`: url of the backend
         """
-        return cls._registry[settings.LLM_BACKEND]()
+        if name is None:
+            name = settings.LLM_BACKEND
+        if name not in cls._registry:
+            module_name = f"reflector.llm.llm_{name}"
+            importlib.import_module(module_name)
+        return cls._registry[name]()

-    async def generate(
-        self, prompt: str, retry_count: int = 5, retry_interval: int = 1, **kwargs
-    ) -> dict:
-        while retry_count > 0:
-            try:
-                result = await self._generate(prompt=prompt, **kwargs)
-                break
-            except Exception:
-                logger.exception("Failed to call llm")
-                retry_count -= 1
-                await asyncio.sleep(retry_interval)
-        if retry_count == 0:
-            raise Exception("Failed to call llm after retrying")
+    async def generate(self, prompt: str, **kwargs) -> dict:
+        try:
+            result = await retry(self._generate)(prompt=prompt, **kwargs)
+        except Exception:
+            logger.exception("Failed to call llm after retrying")
+            raise

         if isinstance(result, str):
             result = self._parse_json(result)
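The lazy lookup above pairs with the registry that backend modules populate at import time: `llm_banana.py` below calls `LLM.register("banana", BananaLLM)` at module level, so importing `reflector.llm.llm_banana` is enough to make the backend available. A minimal usage sketch, assuming only the module and setting names visible in this diff:

```python
# Usage sketch -- module and setting names are taken from the diff above.
from reflector.llm.base import LLM

# Resolves settings.LLM_BACKEND (defaults to "oobagooda").
llm = LLM.get_instance()

# First call imports reflector.llm.llm_banana; its module-level
# LLM.register("banana", BananaLLM) populates the registry.
banana = LLM.get_instance("banana")
```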

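The `retry` helper imported from `reflector.utils.retry` is not shown in this diff; the call sites only reveal the shape `await retry(fn)(*args, **kwargs)`. A minimal sketch consistent with that shape, with hypothetical attempt/interval defaults:

```python
# Hypothetical sketch of reflector/utils/retry.py -- the real helper is not
# part of this diff; only the call shape `await retry(fn)(...)` is known.
import asyncio
import functools


def retry(fn, attempts: int = 5, interval: float = 1.0):
    """Wrap an async callable so any exception triggers a retry."""

    @functools.wraps(fn)
    async def wrapper(*args, **kwargs):
        last_exc = None
        for _ in range(attempts):
            try:
                return await fn(*args, **kwargs)
            except Exception as exc:
                last_exc = exc
                await asyncio.sleep(interval)
        raise last_exc

    return wrapper
```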
reflector/llm/llm_banana.py (new file)

@@ -0,0 +1,41 @@
+from reflector.llm.base import LLM
+from reflector.settings import settings
+from reflector.utils.retry import retry
+import httpx
+
+
+class BananaLLM(LLM):
+    def __init__(self):
+        super().__init__()
+        self.timeout = settings.LLM_TIMEOUT
+        self.headers = {
+            "X-Banana-API-Key": settings.LLM_BANANA_API_KEY,
+            "X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY,
+        }
+
+    async def _generate(self, prompt: str, **kwargs):
+        async with httpx.AsyncClient() as client:
+            response = await retry(client.post)(
+                settings.LLM_URL,
+                headers=self.headers,
+                json={"prompt": prompt},
+                timeout=self.timeout,
+            )
+            response.raise_for_status()
+            text = response.json()["text"]
+            text = text[len(prompt) :]  # remove prompt
+            return text
+
+
+LLM.register("banana", BananaLLM)
+
+if __name__ == "__main__":
+
+    async def main():
+        llm = BananaLLM()
+        result = await llm.generate("Hello, my name is")
+        print(result)
+
+    import asyncio
+
+    asyncio.run(main())
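The client above defines the wire contract: POST `{"prompt": ...}` to `LLM_URL` with the Banana API/model key headers, and expect a JSON response whose `"text"` field echoes the prompt followed by the completion (which is why the client slices off `len(prompt)`). The matching server lives in the monadical-sas/reflector-gpu-banana project and is not part of this diff; a minimal FastAPI sketch of that contract, with the model call stubbed out and all names hypothetical:

```python
# Hypothetical sketch of the server-side contract BananaLLM expects; the real
# implementation lives in monadical-sas/reflector-gpu-banana, not this repo.
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()


class GenerateRequest(BaseModel):
    prompt: str


def run_model(prompt: str) -> str:
    # Stand-in for the actual LLM inference call.
    return " (model output)"


@app.post("/")
async def generate(req: GenerateRequest) -> dict:
    # The response must echo the prompt, matching the client-side
    # `text[len(prompt):]` slice that strips it back off.
    return {"text": req.prompt + run_model(req.prompt)}
```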