mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-21 12:49:06 +00:00
Serverless GPU support on banana.dev (#106)
* serverless: implement banana backend for both audio and LLM (related to the monadical-sas/reflector-gpu-banana project) * serverless: got llm working on banana! * tests: fixes * serverless: fix dockerfile to use fastapi server + httpx
This commit is contained in:
@@ -1,3 +1 @@
|
||||
from .base import LLM # noqa: F401
|
||||
from . import llm_oobagooda # noqa: F401
|
||||
from . import llm_openai # noqa: F401
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from reflector.logger import logger
|
||||
from reflector.settings import settings
|
||||
import asyncio
|
||||
from reflector.utils.retry import retry
|
||||
import importlib
|
||||
import json
|
||||
import re
|
||||
|
||||
@@ -13,7 +14,7 @@ class LLM:
|
||||
cls._registry[name] = klass
|
||||
|
||||
@classmethod
|
||||
def instance(cls):
|
||||
def get_instance(cls, name=None):
|
||||
"""
|
||||
Return an instance depending on the settings.
|
||||
Settings used:
|
||||
@@ -21,22 +22,19 @@ class LLM:
|
||||
- `LLM_BACKEND`: key of the backend, defaults to `oobagooda`
|
||||
- `LLM_URL`: url of the backend
|
||||
"""
|
||||
return cls._registry[settings.LLM_BACKEND]()
|
||||
if name is None:
|
||||
name = settings.LLM_BACKEND
|
||||
if name not in cls._registry:
|
||||
module_name = f"reflector.llm.llm_{name}"
|
||||
importlib.import_module(module_name)
|
||||
return cls._registry[name]()
|
||||
|
||||
async def generate(
|
||||
self, prompt: str, retry_count: int = 5, retry_interval: int = 1, **kwargs
|
||||
) -> dict:
|
||||
while retry_count > 0:
|
||||
try:
|
||||
result = await self._generate(prompt=prompt, **kwargs)
|
||||
break
|
||||
except Exception:
|
||||
logger.exception("Failed to call llm")
|
||||
retry_count -= 1
|
||||
await asyncio.sleep(retry_interval)
|
||||
|
||||
if retry_count == 0:
|
||||
raise Exception("Failed to call llm after retrying")
|
||||
async def generate(self, prompt: str, **kwargs) -> dict:
|
||||
try:
|
||||
result = await retry(self._generate)(prompt=prompt, **kwargs)
|
||||
except Exception:
|
||||
logger.exception("Failed to call llm after retrying")
|
||||
raise
|
||||
|
||||
if isinstance(result, str):
|
||||
result = self._parse_json(result)
|
||||
|
||||
41
server/reflector/llm/llm_banana.py
Normal file
41
server/reflector/llm/llm_banana.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from reflector.llm.base import LLM
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.retry import retry
|
||||
import httpx
|
||||
|
||||
|
||||
class BananaLLM(LLM):
    """LLM backend that forwards prompts to a banana.dev HTTP endpoint."""

    def __init__(self):
        super().__init__()
        # Per-request timeout and banana.dev authentication headers
        self.timeout = settings.LLM_TIMEOUT
        self.headers = {
            "X-Banana-API-Key": settings.LLM_BANANA_API_KEY,
            "X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY,
        }

    async def _generate(self, prompt: str, **kwargs):
        # POST the prompt to the configured endpoint, retrying on failure
        async with httpx.AsyncClient() as client:
            resp = await retry(client.post)(
                settings.LLM_URL,
                headers=self.headers,
                json={"prompt": prompt},
                timeout=self.timeout,
            )
            resp.raise_for_status()
            completion = resp.json()["text"]
            # The service echoes the prompt; strip it before returning
            return completion[len(prompt) :]


LLM.register("banana", BananaLLM)
|
||||
|
||||
if __name__ == "__main__":
    # Ad-hoc smoke test: send one prompt to the configured banana endpoint
    import asyncio

    async def main():
        llm = BananaLLM()
        print(await llm.generate("Hello, my name is"))

    asyncio.run(main())
|
||||
@@ -1,19 +1,38 @@
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.audio_transcript import AudioTranscriptProcessor
|
||||
from reflector.processors.audio_transcript_whisper import (
|
||||
AudioTranscriptWhisperProcessor,
|
||||
)
|
||||
from reflector.processors.types import AudioFile
|
||||
from reflector.settings import settings
|
||||
import importlib
|
||||
|
||||
|
||||
class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
|
||||
BACKENDS = {
|
||||
"whisper": AudioTranscriptWhisperProcessor,
|
||||
}
|
||||
BACKEND_DEFAULT = "whisper"
|
||||
_registry = {}
|
||||
|
||||
def __init__(self, backend=None, **kwargs):
|
||||
self.processor = self.BACKENDS[backend or self.BACKEND_DEFAULT]()
|
||||
@classmethod
|
||||
def register(cls, name, kclass):
|
||||
cls._registry[name] = kclass
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, name):
|
||||
if name not in cls._registry:
|
||||
module_name = f"reflector.processors.audio_transcript_{name}"
|
||||
importlib.import_module(module_name)
|
||||
|
||||
# gather specific configuration for the processor
|
||||
# search `TRANSCRIPT_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
|
||||
config = {}
|
||||
name_upper = name.upper()
|
||||
settings_prefix = "TRANSCRIPT_"
|
||||
config_prefix = f"{settings_prefix}{name_upper}_"
|
||||
for key, value in settings:
|
||||
if key.startswith(config_prefix):
|
||||
config_name = key[len(settings_prefix) :].lower()
|
||||
config[config_name] = value
|
||||
|
||||
return cls._registry[name](**config)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.processor = self.get_instance(settings.TRANSCRIPT_BACKEND)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def connect(self, processor: Processor):
|
||||
|
||||
85
server/reflector/processors/audio_transcript_banana.py
Normal file
85
server/reflector/processors/audio_transcript_banana.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""
|
||||
Implementation using the GPU service from banana.
|
||||
|
||||
API will be a POST request to TRANSCRIPT_URL:
|
||||
|
||||
```json
|
||||
{
|
||||
"audio_url": "https://...",
|
||||
"audio_ext": "wav",
|
||||
"timestamp": 123.456
|
||||
"language": "en"
|
||||
}
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
from reflector.processors.audio_transcript import AudioTranscriptProcessor
|
||||
from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor
|
||||
from reflector.processors.types import AudioFile, Transcript, Word
|
||||
from reflector.settings import settings
|
||||
from reflector.storage import Storage
|
||||
from reflector.utils.retry import retry
|
||||
from pathlib import Path
|
||||
import httpx
|
||||
|
||||
|
||||
class AudioTranscriptBananaProcessor(AudioTranscriptProcessor):
    """Transcribe audio chunks via a banana.dev GPU endpoint.

    The audio file is first uploaded to the configured transcript storage
    backend so the GPU service can fetch it by URL; after a successful
    transcription the uploaded file is deleted again.
    """

    def __init__(self, banana_api_key: str, banana_model_key: str):
        super().__init__()
        self.transcript_url = settings.TRANSCRIPT_URL
        self.timeout = settings.TRANSCRIPT_TIMEOUT
        self.storage = Storage.get_instance(
            settings.TRANSCRIPT_STORAGE_BACKEND, "TRANSCRIPT_STORAGE_"
        )
        # banana.dev authentication headers sent with every request
        self.headers = {
            "X-Banana-API-Key": banana_api_key,
            "X-Banana-Model-Key": banana_model_key,
        }

    async def _transcript(self, data: AudioFile):
        async with httpx.AsyncClient() as client:
            print(f"Uploading audio {data.path.name} to S3")
            url = await self._upload_file(data.path)

            print(f"Try to transcribe audio {data.path.name}")
            # NOTE(review): the module docstring documents a "language" field
            # that is not sent here — confirm the endpoint defaults it
            request_data = {
                "audio_url": url,
                "audio_ext": data.path.suffix[1:],  # extension without the dot
                "timestamp": float(round(data.timestamp, 2)),
            }
            response = await retry(client.post)(
                self.transcript_url,
                json=request_data,
                headers=self.headers,
                timeout=self.timeout,
            )

            print(f"Transcript response: {response.status_code} {response.content}")
            response.raise_for_status()
            result = response.json()
            transcript = Transcript(
                text=result["text"],
                words=[
                    Word(text=word["text"], start=word["start"], end=word["end"])
                    for word in result["words"]
                ],
            )

            # remove audio file from S3
            await self._delete_file(data.path)

            return transcript

    @retry
    async def _upload_file(self, path: Path) -> str:
        # read_bytes() closes the file handle (the previous bare open()
        # leaked it) and matches Storage.put_file's declared `bytes` payload
        upload_result = await self.storage.put_file(path.name, path.read_bytes())
        return upload_result.url

    @retry
    async def _delete_file(self, path: Path):
        await self.storage.delete_file(path.name)
        # retry() treats falsy results as failures, so signal success explicitly
        return True


AudioTranscriptAutoProcessor.register("banana", AudioTranscriptBananaProcessor)
|
||||
@@ -1,4 +1,5 @@
|
||||
from reflector.processors.audio_transcript import AudioTranscriptProcessor
|
||||
from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor
|
||||
from reflector.processors.types import AudioFile, Transcript, Word
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
@@ -40,3 +41,6 @@ class AudioTranscriptWhisperProcessor(AudioTranscriptProcessor):
|
||||
)
|
||||
|
||||
return transcript
|
||||
|
||||
|
||||
AudioTranscriptAutoProcessor.register("whisper", AudioTranscriptWhisperProcessor)
|
||||
|
||||
@@ -28,7 +28,7 @@ class TranscriptTopicDetectorProcessor(Processor):
|
||||
super().__init__(**kwargs)
|
||||
self.transcript = None
|
||||
self.min_transcript_length = min_transcript_length
|
||||
self.llm = LLM.instance()
|
||||
self.llm = LLM.get_instance()
|
||||
|
||||
async def _push(self, data: Transcript):
|
||||
if self.transcript is None:
|
||||
|
||||
@@ -26,8 +26,29 @@ class Settings(BaseSettings):
|
||||
AUDIO_SAMPLING_WIDTH: int = 2
|
||||
AUDIO_BUFFER_SIZE: int = 256 * 960
|
||||
|
||||
# Audio Transcription
|
||||
# backends: whisper, banana
|
||||
TRANSCRIPT_BACKEND: str = "whisper"
|
||||
TRANSCRIPT_URL: str | None = None
|
||||
TRANSCRIPT_TIMEOUT: int = 90
|
||||
|
||||
# Audio transcription banana.dev configuration
|
||||
TRANSCRIPT_BANANA_API_KEY: str | None = None
|
||||
TRANSCRIPT_BANANA_MODEL_KEY: str | None = None
|
||||
|
||||
# Audio transcription storage
|
||||
TRANSCRIPT_STORAGE_BACKEND: str = "aws"
|
||||
|
||||
# Storage configuration for AWS
|
||||
TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket/chunks"
|
||||
TRANSCRIPT_STORAGE_AWS_REGION: str = "us-east-1"
|
||||
TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None
|
||||
TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
|
||||
|
||||
# LLM
|
||||
LLM_BACKEND: str = "oobagooda"
|
||||
|
||||
# LLM common configuration
|
||||
LLM_URL: str | None = None
|
||||
LLM_HOST: str = "localhost"
|
||||
LLM_PORT: int = 7860
|
||||
@@ -38,11 +59,9 @@ class Settings(BaseSettings):
|
||||
LLM_MAX_TOKENS: int = 1024
|
||||
LLM_TEMPERATURE: float = 0.7
|
||||
|
||||
# Storage
|
||||
STORAGE_BACKEND: str = "aws"
|
||||
STORAGE_AWS_ACCESS_KEY: str = ""
|
||||
STORAGE_AWS_SECRET_KEY: str = ""
|
||||
STORAGE_AWS_BUCKET: str = ""
|
||||
# LLM Banana configuration
|
||||
LLM_BANANA_API_KEY: str | None = None
|
||||
LLM_BANANA_MODEL_KEY: str | None = None
|
||||
|
||||
# Sentry
|
||||
SENTRY_DSN: str | None = None
|
||||
|
||||
1
server/reflector/storage/__init__.py
Normal file
1
server/reflector/storage/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .base import Storage # noqa
|
||||
47
server/reflector/storage/base.py
Normal file
47
server/reflector/storage/base.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from pydantic import BaseModel
|
||||
from reflector.settings import settings
|
||||
import importlib
|
||||
|
||||
|
||||
class FileResult(BaseModel):
    """Outcome of a storage upload: the stored filename and a URL to fetch it."""

    # filename: name the file was stored under
    # url: download URL (e.g. a presigned S3 URL)
    filename: str
    url: str
|
||||
|
||||
|
||||
class Storage:
    """Registry-based abstract storage backend (upload/delete by filename).

    Concrete backends subclass this, implement `_put_file`/`_delete_file`,
    and call `Storage.register(name, cls)` at import time.
    """

    _registry = {}
    CONFIG_SETTINGS = []

    @classmethod
    def register(cls, name, kclass):
        # Expose `kclass` under `name` for later `get_instance()` lookups
        cls._registry[name] = kclass

    @classmethod
    def get_instance(cls, name, settings_prefix=""):
        # Lazily import the backend module so it can register itself
        if name not in cls._registry:
            importlib.import_module(f"reflector.storage.storage_{name}")

        # Gather backend-specific configuration from the settings object:
        # `<PREFIX><NAME>_XXX` is passed to the constructor as `<name>_xxx`
        config_prefix = f"{settings_prefix}{name.upper()}_"
        config = {
            key[len(settings_prefix) :].lower(): value
            for key, value in settings
            if key.startswith(config_prefix)
        }

        return cls._registry[name](**config)

    async def put_file(self, filename: str, data: bytes) -> FileResult:
        return await self._put_file(filename, data)

    async def _put_file(self, filename: str, data: bytes) -> FileResult:
        raise NotImplementedError

    async def delete_file(self, filename: str):
        return await self._delete_file(filename)

    async def _delete_file(self, filename: str):
        raise NotImplementedError
|
||||
67
server/reflector/storage/storage_aws.py
Normal file
67
server/reflector/storage/storage_aws.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import aioboto3
|
||||
from reflector.storage.base import Storage, FileResult
|
||||
from reflector.logger import logger
|
||||
|
||||
|
||||
class AwsStorage(Storage):
    """Storage backend for AWS S3 (via aioboto3).

    `aws_bucket_name` may include a folder suffix ("bucket/folder"); files
    are then stored under that key prefix.
    """

    def __init__(
        self,
        aws_access_key_id: str,
        aws_secret_access_key: str,
        aws_bucket_name: str,
        aws_region: str,
    ):
        # Fail fast on missing configuration rather than at first request
        if not aws_access_key_id:
            raise ValueError("Storage `aws_storage` require `aws_access_key_id`")
        if not aws_secret_access_key:
            raise ValueError("Storage `aws_storage` require `aws_secret_access_key`")
        if not aws_bucket_name:
            raise ValueError("Storage `aws_storage` require `aws_bucket_name`")
        if not aws_region:
            raise ValueError("Storage `aws_storage` require `aws_region`")

        super().__init__()
        self.aws_bucket_name = aws_bucket_name
        self.aws_folder = ""
        if "/" in aws_bucket_name:
            self.aws_bucket_name, self.aws_folder = aws_bucket_name.split("/", 1)
        self.session = aioboto3.Session(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=aws_region,
        )
        # NOTE(review): built from the raw bucket[/folder] value and unused
        # below (presigned URLs are returned instead) — confirm before relying on it
        self.base_url = f"https://{aws_bucket_name}.s3.amazonaws.com/"

    def _s3_key(self, filename: str) -> str:
        # Prefix the key with the optional folder component
        return f"{self.aws_folder}/{filename}" if self.aws_folder else filename

    async def _put_file(self, filename: str, data: bytes) -> FileResult:
        bucket = self.aws_bucket_name
        folder = self.aws_folder
        logger.info(f"Uploading {filename} to S3 {bucket}/{folder}")
        # was a literal "(unknown)" placeholder; the key must use the filename
        s3filename = self._s3_key(filename)
        async with self.session.client("s3") as client:
            await client.put_object(
                Bucket=bucket,
                Key=s3filename,
                Body=data,
            )

            # Presign so the consumer can download without AWS credentials
            presigned_url = await client.generate_presigned_url(
                "get_object",
                Params={"Bucket": bucket, "Key": s3filename},
                ExpiresIn=3600,
            )

        return FileResult(
            filename=filename,
            url=presigned_url,
        )

    async def _delete_file(self, filename: str):
        bucket = self.aws_bucket_name
        folder = self.aws_folder
        logger.info(f"Deleting {filename} from S3 {bucket}/{folder}")
        async with self.session.client("s3") as client:
            await client.delete_object(Bucket=bucket, Key=self._s3_key(filename))


Storage.register("aws", AwsStorage)
|
||||
@@ -12,7 +12,7 @@ from reflector.processors import (
|
||||
import asyncio
|
||||
|
||||
|
||||
async def process_audio_file(filename, event_callback):
|
||||
async def process_audio_file(filename, event_callback, only_transcript=False):
|
||||
async def on_transcript(data):
|
||||
await event_callback("transcript", data)
|
||||
|
||||
@@ -22,15 +22,21 @@ async def process_audio_file(filename, event_callback):
|
||||
async def on_summary(data):
|
||||
await event_callback("summary", data)
|
||||
|
||||
# transcription output
|
||||
pipeline = Pipeline(
|
||||
# build pipeline for audio processing
|
||||
processors = [
|
||||
AudioChunkerProcessor(),
|
||||
AudioMergeProcessor(),
|
||||
AudioTranscriptAutoProcessor.as_threaded(),
|
||||
TranscriptLinerProcessor(callback=on_transcript),
|
||||
TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
|
||||
TranscriptFinalSummaryProcessor.as_threaded(callback=on_summary),
|
||||
)
|
||||
]
|
||||
if not only_transcript:
|
||||
processors += [
|
||||
TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
|
||||
TranscriptFinalSummaryProcessor.as_threaded(callback=on_summary),
|
||||
]
|
||||
|
||||
# transcription output
|
||||
pipeline = Pipeline(*processors)
|
||||
pipeline.describe()
|
||||
|
||||
# start processing audio
|
||||
@@ -52,6 +58,7 @@ if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("source", help="Source file (mp3, wav, mp4...)")
|
||||
parser.add_argument("--only-transcript", "-t", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
async def event_callback(event, data):
|
||||
@@ -62,4 +69,8 @@ if __name__ == "__main__":
|
||||
elif event == "summary":
|
||||
print(f"Summary: {data}")
|
||||
|
||||
asyncio.run(process_audio_file(args.source, event_callback))
|
||||
asyncio.run(
|
||||
process_audio_file(
|
||||
args.source, event_callback, only_transcript=args.only_transcript
|
||||
)
|
||||
)
|
||||
|
||||
29
server/reflector/utils/retry.py
Normal file
29
server/reflector/utils/retry.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from reflector.logger import logger
|
||||
import asyncio
|
||||
|
||||
|
||||
def retry(fn):
    """Wrap an async callable so it is retried on failure.

    The wrapped function accepts three extra keyword arguments, consumed
    here and not forwarded to `fn`:

    - `retry_max` (int, default 5): maximum number of attempts
    - `retry_delay` (float, default 2): seconds to sleep between attempts
      that raised
    - `retry_ignore_exc_types` (tuple, default ()): exception types caught
      and retried; any other exception propagates immediately

    A falsy result counts as a failure and is retried; if every attempt is
    falsy the last result is returned. If the final failure was a caught
    exception, that exception is re-raised.
    """

    async def decorated(*args, **kwargs):
        retry_max = kwargs.pop("retry_max", 5)
        retry_delay = kwargs.pop("retry_delay", 2)
        retry_ignore_exc_types = kwargs.pop("retry_ignore_exc_types", ())
        result = None
        last_exception = None
        for attempt in range(retry_max):
            try:
                result = await fn(*args, **kwargs)
                if result:
                    return result
            except retry_ignore_exc_types as e:
                last_exception = e
                logger.debug(
                    f"Retrying {fn} - in {retry_delay} seconds "
                    f"- attempt {attempt + 1}/{retry_max}"
                )
                # don't sleep after the final attempt
                if attempt + 1 < retry_max:
                    await asyncio.sleep(retry_delay)
        if last_exception is not None:
            # re-raise the original exception; the previous
            # `raise type(last_exception)` discarded its message/args
            raise last_exception
        return result

    return decorated
|
||||
Reference in New Issue
Block a user