Serverless GPU support on banana.dev (#106)

* serverless: implement banana backend for both audio and LLM

Related to monadical-sas/reflector-gpu-banana project

* serverless: got llm working on banana !

* tests: fixes

* serverless: fix dockerfile to use fastapi server + httpx
This commit is contained in:
2023-08-04 10:24:11 +02:00
committed by GitHub
parent a5ce66c299
commit d94e2911c3
17 changed files with 602 additions and 53 deletions

View File

@@ -1,3 +1 @@
from .base import LLM # noqa: F401
from . import llm_oobagooda # noqa: F401
from . import llm_openai # noqa: F401

View File

@@ -1,6 +1,7 @@
from reflector.logger import logger
from reflector.settings import settings
import asyncio
from reflector.utils.retry import retry
import importlib
import json
import re
@@ -13,7 +14,7 @@ class LLM:
cls._registry[name] = klass
@classmethod
def instance(cls):
def get_instance(cls, name=None):
"""
Return an instance depending on the settings.
Settings used:
@@ -21,22 +22,19 @@ class LLM:
- `LLM_BACKEND`: key of the backend, defaults to `oobagooda`
- `LLM_URL`: url of the backend
"""
return cls._registry[settings.LLM_BACKEND]()
if name is None:
name = settings.LLM_BACKEND
if name not in cls._registry:
module_name = f"reflector.llm.llm_{name}"
importlib.import_module(module_name)
return cls._registry[name]()
async def generate(
self, prompt: str, retry_count: int = 5, retry_interval: int = 1, **kwargs
) -> dict:
while retry_count > 0:
try:
result = await self._generate(prompt=prompt, **kwargs)
break
except Exception:
logger.exception("Failed to call llm")
retry_count -= 1
await asyncio.sleep(retry_interval)
if retry_count == 0:
raise Exception("Failed to call llm after retrying")
async def generate(self, prompt: str, **kwargs) -> dict:
try:
result = await retry(self._generate)(prompt=prompt, **kwargs)
except Exception:
logger.exception("Failed to call llm after retrying")
raise
if isinstance(result, str):
result = self._parse_json(result)

View File

@@ -0,0 +1,41 @@
from reflector.llm.base import LLM
from reflector.settings import settings
from reflector.utils.retry import retry
import httpx
class BananaLLM(LLM):
def __init__(self):
super().__init__()
self.timeout = settings.LLM_TIMEOUT
self.headers = {
"X-Banana-API-Key": settings.LLM_BANANA_API_KEY,
"X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY,
}
async def _generate(self, prompt: str, **kwargs):
async with httpx.AsyncClient() as client:
response = await retry(client.post)(
settings.LLM_URL,
headers=self.headers,
json={"prompt": prompt},
timeout=self.timeout,
)
response.raise_for_status()
text = response.json()["text"]
text = text[len(prompt) :] # remove prompt
return text
LLM.register("banana", BananaLLM)
if __name__ == "__main__":
async def main():
llm = BananaLLM()
result = await llm.generate("Hello, my name is")
print(result)
import asyncio
asyncio.run(main())

View File

@@ -1,19 +1,38 @@
from reflector.processors.base import Processor
from reflector.processors.audio_transcript import AudioTranscriptProcessor
from reflector.processors.audio_transcript_whisper import (
AudioTranscriptWhisperProcessor,
)
from reflector.processors.types import AudioFile
from reflector.settings import settings
import importlib
class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
BACKENDS = {
"whisper": AudioTranscriptWhisperProcessor,
}
BACKEND_DEFAULT = "whisper"
_registry = {}
def __init__(self, backend=None, **kwargs):
self.processor = self.BACKENDS[backend or self.BACKEND_DEFAULT]()
@classmethod
def register(cls, name, kclass):
cls._registry[name] = kclass
@classmethod
def get_instance(cls, name):
if name not in cls._registry:
module_name = f"reflector.processors.audio_transcript_{name}"
importlib.import_module(module_name)
# gather specific configuration for the processor
# search `TRANSCRIPT_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
config = {}
name_upper = name.upper()
settings_prefix = "TRANSCRIPT_"
config_prefix = f"{settings_prefix}{name_upper}_"
for key, value in settings:
if key.startswith(config_prefix):
config_name = key[len(settings_prefix) :].lower()
config[config_name] = value
return cls._registry[name](**config)
def __init__(self, **kwargs):
self.processor = self.get_instance(settings.TRANSCRIPT_BACKEND)
super().__init__(**kwargs)
def connect(self, processor: Processor):

View File

@@ -0,0 +1,85 @@
"""
Implementation using the GPU service from banana.
API will be a POST request to TRANSCRIPT_URL:
```json
{
"audio_url": "https://...",
"audio_ext": "wav",
"timestamp": 123.456
"language": "en"
}
```
"""
from reflector.processors.audio_transcript import AudioTranscriptProcessor
from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor
from reflector.processors.types import AudioFile, Transcript, Word
from reflector.settings import settings
from reflector.storage import Storage
from reflector.utils.retry import retry
from pathlib import Path
import httpx
class AudioTranscriptBananaProcessor(AudioTranscriptProcessor):
def __init__(self, banana_api_key: str, banana_model_key: str):
super().__init__()
self.transcript_url = settings.TRANSCRIPT_URL
self.timeout = settings.TRANSCRIPT_TIMEOUT
self.storage = Storage.get_instance(
settings.TRANSCRIPT_STORAGE_BACKEND, "TRANSCRIPT_STORAGE_"
)
self.headers = {
"X-Banana-API-Key": banana_api_key,
"X-Banana-Model-Key": banana_model_key,
}
async def _transcript(self, data: AudioFile):
async with httpx.AsyncClient() as client:
print(f"Uploading audio {data.path.name} to S3")
url = await self._upload_file(data.path)
print(f"Try to transcribe audio {data.path.name}")
request_data = {
"audio_url": url,
"audio_ext": data.path.suffix[1:],
"timestamp": float(round(data.timestamp, 2)),
}
response = await retry(client.post)(
self.transcript_url,
json=request_data,
headers=self.headers,
timeout=self.timeout,
)
print(f"Transcript response: {response.status_code} {response.content}")
response.raise_for_status()
result = response.json()
transcript = Transcript(
text=result["text"],
words=[
Word(text=word["text"], start=word["start"], end=word["end"])
for word in result["words"]
],
)
# remove audio file from S3
await self._delete_file(data.path)
return transcript
@retry
async def _upload_file(self, path: Path) -> str:
upload_result = await self.storage.put_file(path.name, open(path, "rb"))
return upload_result.url
@retry
async def _delete_file(self, path: Path):
await self.storage.delete_file(path.name)
return True
AudioTranscriptAutoProcessor.register("banana", AudioTranscriptBananaProcessor)

View File

@@ -1,4 +1,5 @@
from reflector.processors.audio_transcript import AudioTranscriptProcessor
from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor
from reflector.processors.types import AudioFile, Transcript, Word
from faster_whisper import WhisperModel
@@ -40,3 +41,6 @@ class AudioTranscriptWhisperProcessor(AudioTranscriptProcessor):
)
return transcript
AudioTranscriptAutoProcessor.register("whisper", AudioTranscriptWhisperProcessor)

View File

@@ -28,7 +28,7 @@ class TranscriptTopicDetectorProcessor(Processor):
super().__init__(**kwargs)
self.transcript = None
self.min_transcript_length = min_transcript_length
self.llm = LLM.instance()
self.llm = LLM.get_instance()
async def _push(self, data: Transcript):
if self.transcript is None:

View File

@@ -26,8 +26,29 @@ class Settings(BaseSettings):
AUDIO_SAMPLING_WIDTH: int = 2
AUDIO_BUFFER_SIZE: int = 256 * 960
# Audio Transcription
# backends: whisper, banana
TRANSCRIPT_BACKEND: str = "whisper"
TRANSCRIPT_URL: str | None = None
TRANSCRIPT_TIMEOUT: int = 90
# Audio transcription banana.dev configuration
TRANSCRIPT_BANANA_API_KEY: str | None = None
TRANSCRIPT_BANANA_MODEL_KEY: str | None = None
# Audio transcription storage
TRANSCRIPT_STORAGE_BACKEND: str = "aws"
# Storage configuration for AWS
TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket/chunks"
TRANSCRIPT_STORAGE_AWS_REGION: str = "us-east-1"
TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None
TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
# LLM
LLM_BACKEND: str = "oobagooda"
# LLM common configuration
LLM_URL: str | None = None
LLM_HOST: str = "localhost"
LLM_PORT: int = 7860
@@ -38,11 +59,9 @@ class Settings(BaseSettings):
LLM_MAX_TOKENS: int = 1024
LLM_TEMPERATURE: float = 0.7
# Storage
STORAGE_BACKEND: str = "aws"
STORAGE_AWS_ACCESS_KEY: str = ""
STORAGE_AWS_SECRET_KEY: str = ""
STORAGE_AWS_BUCKET: str = ""
# LLM Banana configuration
LLM_BANANA_API_KEY: str | None = None
LLM_BANANA_MODEL_KEY: str | None = None
# Sentry
SENTRY_DSN: str | None = None

View File

@@ -0,0 +1 @@
from .base import Storage # noqa

View File

@@ -0,0 +1,47 @@
from pydantic import BaseModel
from reflector.settings import settings
import importlib
class FileResult(BaseModel):
filename: str
url: str
class Storage:
_registry = {}
CONFIG_SETTINGS = []
@classmethod
def register(cls, name, kclass):
cls._registry[name] = kclass
@classmethod
def get_instance(cls, name, settings_prefix=""):
if name not in cls._registry:
module_name = f"reflector.storage.storage_{name}"
importlib.import_module(module_name)
# gather specific configuration for the processor
# search `TRANSCRIPT_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
config = {}
name_upper = name.upper()
config_prefix = f"{settings_prefix}{name_upper}_"
for key, value in settings:
if key.startswith(config_prefix):
config_name = key[len(settings_prefix) :].lower()
config[config_name] = value
return cls._registry[name](**config)
async def put_file(self, filename: str, data: bytes) -> FileResult:
return await self._put_file(filename, data)
async def _put_file(self, filename: str, data: bytes) -> FileResult:
raise NotImplementedError
async def delete_file(self, filename: str):
return await self._delete_file(filename)
async def _delete_file(self, filename: str):
raise NotImplementedError

View File

@@ -0,0 +1,67 @@
import aioboto3
from reflector.storage.base import Storage, FileResult
from reflector.logger import logger
class AwsStorage(Storage):
def __init__(
self,
aws_access_key_id: str,
aws_secret_access_key: str,
aws_bucket_name: str,
aws_region: str,
):
if not aws_access_key_id:
raise ValueError("Storage `aws_storage` require `aws_access_key_id`")
if not aws_secret_access_key:
raise ValueError("Storage `aws_storage` require `aws_secret_access_key`")
if not aws_bucket_name:
raise ValueError("Storage `aws_storage` require `aws_bucket_name`")
if not aws_region:
raise ValueError("Storage `aws_storage` require `aws_region`")
super().__init__()
self.aws_bucket_name = aws_bucket_name
self.aws_folder = ""
if "/" in aws_bucket_name:
self.aws_bucket_name, self.aws_folder = aws_bucket_name.split("/", 1)
self.session = aioboto3.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=aws_region,
)
self.base_url = f"https://{aws_bucket_name}.s3.amazonaws.com/"
async def _put_file(self, filename: str, data: bytes) -> FileResult:
bucket = self.aws_bucket_name
folder = self.aws_folder
logger.info(f"Uploading {filename} to S3 {bucket}/{folder}")
s3filename = f"{folder}/{filename}" if folder else filename
async with self.session.client("s3") as client:
await client.put_object(
Bucket=bucket,
Key=s3filename,
Body=data,
)
presigned_url = await client.generate_presigned_url(
"get_object",
Params={"Bucket": bucket, "Key": s3filename},
ExpiresIn=3600,
)
return FileResult(
filename=filename,
url=presigned_url,
)
async def _delete_file(self, filename: str):
bucket = self.aws_bucket_name
folder = self.aws_folder
logger.info(f"Deleting {filename} from S3 {bucket}/{folder}")
s3filename = f"{folder}/{filename}" if folder else filename
async with self.session.client("s3") as client:
await client.delete_object(Bucket=bucket, Key=s3filename)
Storage.register("aws", AwsStorage)

View File

@@ -12,7 +12,7 @@ from reflector.processors import (
import asyncio
async def process_audio_file(filename, event_callback):
async def process_audio_file(filename, event_callback, only_transcript=False):
async def on_transcript(data):
await event_callback("transcript", data)
@@ -22,15 +22,21 @@ async def process_audio_file(filename, event_callback):
async def on_summary(data):
await event_callback("summary", data)
# transcription output
pipeline = Pipeline(
# build pipeline for audio processing
processors = [
AudioChunkerProcessor(),
AudioMergeProcessor(),
AudioTranscriptAutoProcessor.as_threaded(),
TranscriptLinerProcessor(callback=on_transcript),
TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
TranscriptFinalSummaryProcessor.as_threaded(callback=on_summary),
)
]
if not only_transcript:
processors += [
TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
TranscriptFinalSummaryProcessor.as_threaded(callback=on_summary),
]
# transcription output
pipeline = Pipeline(*processors)
pipeline.describe()
# start processing audio
@@ -52,6 +58,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("source", help="Source file (mp3, wav, mp4...)")
parser.add_argument("--only-transcript", "-t", action="store_true")
args = parser.parse_args()
async def event_callback(event, data):
@@ -62,4 +69,8 @@ if __name__ == "__main__":
elif event == "summary":
print(f"Summary: {data}")
asyncio.run(process_audio_file(args.source, event_callback))
asyncio.run(
process_audio_file(
args.source, event_callback, only_transcript=args.only_transcript
)
)

View File

@@ -0,0 +1,29 @@
from reflector.logger import logger
import asyncio
def retry(fn):
async def decorated(*args, **kwargs):
retry_max = kwargs.pop("retry_max", 5)
retry_delay = kwargs.pop("retry_delay", 2)
retry_ignore_exc_types = kwargs.pop("retry_ignore_exc_types", ())
result = None
attempt = 0
last_exception = None
for attempt in range(retry_max):
try:
result = await fn(*args, **kwargs)
if result:
return result
except retry_ignore_exc_types as e:
last_exception = e
logger.debug(
f"Retrying {fn} - in {retry_delay} seconds "
f"- attempt {attempt + 1}/{retry_max}"
)
await asyncio.sleep(retry_delay)
if last_exception is not None:
raise type(last_exception) from last_exception
return result
return decorated