mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-02-05 02:16:46 +00:00
Serverless GPU support on banana.dev (#106)
* serverless: implement banana backend for both audio and LLM Related to monadical-sas/reflector-gpu-banana project * serverless: got llm working on banana ! * tests: fixes * serverless: fix dockerfile to use fastapi server + httpx
This commit is contained in:
@@ -1,19 +1,38 @@
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.audio_transcript import AudioTranscriptProcessor
|
||||
from reflector.processors.audio_transcript_whisper import (
|
||||
AudioTranscriptWhisperProcessor,
|
||||
)
|
||||
from reflector.processors.types import AudioFile
|
||||
from reflector.settings import settings
|
||||
import importlib
|
||||
|
||||
|
||||
class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
|
||||
BACKENDS = {
|
||||
"whisper": AudioTranscriptWhisperProcessor,
|
||||
}
|
||||
BACKEND_DEFAULT = "whisper"
|
||||
_registry = {}
|
||||
|
||||
    def __init__(self, backend=None, **kwargs):
        # Instantiate the backend processor from the static BACKENDS map,
        # falling back to BACKEND_DEFAULT ("whisper") when none is given.
        # NOTE(review): does not call super().__init__(**kwargs) and ignores
        # kwargs — this looks like the pre-diff constructor superseded by the
        # settings-driven one; confirm against the repository history.
        self.processor = self.BACKENDS[backend or self.BACKEND_DEFAULT]()
|
||||
    @classmethod
    def register(cls, name, kclass):
        """Register backend class *kclass* under *name* for later lookup."""
        cls._registry[name] = kclass
|
||||
|
||||
    @classmethod
    def get_instance(cls, name):
        """Return a new instance of the backend registered under *name*.

        Lazily imports ``reflector.processors.audio_transcript_<name>`` so
        that the module-level ``register()`` call in that module can populate
        the registry, then builds constructor kwargs from settings.
        """
        if name not in cls._registry:
            # Importing the module triggers its module-level
            # AudioTranscriptAutoProcessor.register(...) side effect.
            module_name = f"reflector.processors.audio_transcript_{name}"
            importlib.import_module(module_name)

        # gather specific configuration for the processor
        # search `TRANSCRIPT_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
        config = {}
        name_upper = name.upper()
        settings_prefix = "TRANSCRIPT_"
        config_prefix = f"{settings_prefix}{name_upper}_"
        # assumes iterating `settings` yields (key, value) pairs (pydantic
        # BaseSettings behaviour) — TODO confirm
        for key, value in settings:
            if key.startswith(config_prefix):
                # Strip only the "TRANSCRIPT_" prefix, keeping the backend
                # name: e.g. TRANSCRIPT_BANANA_API_KEY -> banana_api_key,
                # matching the backend constructor's parameter names.
                config_name = key[len(settings_prefix) :].lower()
                config[config_name] = value

        # Raises KeyError if the import above did not register `name`.
        return cls._registry[name](**config)
|
||||
|
||||
    def __init__(self, **kwargs):
        # Delegate to the backend selected via settings.TRANSCRIPT_BACKEND;
        # per-backend options are resolved inside get_instance().
        self.processor = self.get_instance(settings.TRANSCRIPT_BACKEND)
        super().__init__(**kwargs)
|
||||
|
||||
def connect(self, processor: Processor):
|
||||
|
||||
85
server/reflector/processors/audio_transcript_banana.py
Normal file
85
server/reflector/processors/audio_transcript_banana.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""
|
||||
Implementation using the GPU service from banana.
|
||||
|
||||
API will be a POST request to TRANSCRIPT_URL:
|
||||
|
||||
```json
{
    "audio_url": "https://...",
    "audio_ext": "wav",
    "timestamp": 123.456,
    "language": "en"
}
```
|
||||
|
||||
"""
|
||||
|
||||
from reflector.processors.audio_transcript import AudioTranscriptProcessor
|
||||
from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor
|
||||
from reflector.processors.types import AudioFile, Transcript, Word
|
||||
from reflector.settings import settings
|
||||
from reflector.storage import Storage
|
||||
from reflector.utils.retry import retry
|
||||
from pathlib import Path
|
||||
import httpx
|
||||
|
||||
|
||||
class AudioTranscriptBananaProcessor(AudioTranscriptProcessor):
    """Transcribe audio using the banana.dev GPU service.

    The audio file is first uploaded to the configured storage backend,
    then its URL is POSTed to ``settings.TRANSCRIPT_URL``; the JSON
    response is converted into a ``Transcript`` with word-level timings.
    """

    def __init__(self, banana_api_key: str, banana_model_key: str):
        super().__init__()
        self.transcript_url = settings.TRANSCRIPT_URL
        self.timeout = settings.TRANSCRIPT_TIMEOUT
        self.storage = Storage.get_instance(
            settings.TRANSCRIPT_STORAGE_BACKEND, "TRANSCRIPT_STORAGE_"
        )
        # banana.dev authentication, sent with every transcription request
        self.headers = {
            "X-Banana-API-Key": banana_api_key,
            "X-Banana-Model-Key": banana_model_key,
        }

    async def _transcript(self, data: AudioFile):
        """Upload *data*, request its transcript, then remove the upload.

        Returns a ``Transcript``. Raises ``httpx.HTTPStatusError`` when
        the service answers with an error status.
        """
        async with httpx.AsyncClient() as client:
            print(f"Uploading audio {data.path.name} to S3")
            url = await self._upload_file(data.path)

            print(f"Try to transcribe audio {data.path.name}")
            # NOTE(review): the module docstring mentions a "language"
            # field, but it is not sent here — confirm the service default.
            request_data = {
                "audio_url": url,
                "audio_ext": data.path.suffix[1:],  # drop the leading dot
                "timestamp": float(round(data.timestamp, 2)),
            }
            response = await retry(client.post)(
                self.transcript_url,
                json=request_data,
                headers=self.headers,
                timeout=self.timeout,
            )

            print(f"Transcript response: {response.status_code} {response.content}")
            response.raise_for_status()
            result = response.json()
            transcript = Transcript(
                text=result["text"],
                words=[
                    Word(text=word["text"], start=word["start"], end=word["end"])
                    for word in result["words"]
                ],
            )

            # remove audio file from S3
            await self._delete_file(data.path)

            return transcript

    @retry
    async def _upload_file(self, path: Path) -> str:
        """Upload *path* to storage and return its accessible URL."""
        # Use a context manager so the file handle is closed even when the
        # upload raises — the original open(path, "rb") leaked the handle.
        with open(path, "rb") as audio_file:
            upload_result = await self.storage.put_file(path.name, audio_file)
        return upload_result.url

    @retry
    async def _delete_file(self, path: Path):
        """Delete the uploaded object named after *path*; returns True."""
        await self.storage.delete_file(path.name)
        return True


AudioTranscriptAutoProcessor.register("banana", AudioTranscriptBananaProcessor)
||||
@@ -1,4 +1,5 @@
|
||||
from reflector.processors.audio_transcript import AudioTranscriptProcessor
|
||||
from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor
|
||||
from reflector.processors.types import AudioFile, Transcript, Word
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
@@ -40,3 +41,6 @@ class AudioTranscriptWhisperProcessor(AudioTranscriptProcessor):
|
||||
)
|
||||
|
||||
return transcript
|
||||
|
||||
|
||||
AudioTranscriptAutoProcessor.register("whisper", AudioTranscriptWhisperProcessor)
|
||||
|
||||
@@ -28,7 +28,7 @@ class TranscriptTopicDetectorProcessor(Processor):
|
||||
super().__init__(**kwargs)
|
||||
self.transcript = None
|
||||
self.min_transcript_length = min_transcript_length
|
||||
self.llm = LLM.instance()
|
||||
self.llm = LLM.get_instance()
|
||||
|
||||
async def _push(self, data: Transcript):
|
||||
if self.transcript is None:
|
||||
|
||||
Reference in New Issue
Block a user