Merge branch 'main' into feat-user-auth-fief

Committed by GitHub on 2023-08-18 10:20:44 +02:00
12 changed files with 123 additions and 56 deletions

View File

@@ -1,10 +1,11 @@
-from reflector.settings import settings
-from reflector.utils.retry import retry
-from reflector.logger import logger as reflector_logger
-from time import monotonic
+import importlib
+import json
+import re
+from time import monotonic
+from reflector.logger import logger as reflector_logger
+from reflector.settings import settings
+from reflector.utils.retry import retry

 class LLM:
@@ -20,7 +21,7 @@ class LLM:
         Return an instance depending on the settings.

         Settings used:
-        - `LLM_BACKEND`: key of the backend, defaults to `oobagooda`
+        - `LLM_BACKEND`: key of the backend, defaults to `oobabooga`
         - `LLM_URL`: url of the backend
         """
         if name is None:
@@ -44,10 +45,16 @@ class LLM:
     async def _warmup(self, logger: reflector_logger):
         pass

-    async def generate(self, prompt: str, logger: reflector_logger, **kwargs) -> dict:
+    async def generate(
+        self,
+        prompt: str,
+        logger: reflector_logger,
+        schema: dict | None = None,
+        **kwargs,
+    ) -> dict:
         logger.info("LLM generate", prompt=repr(prompt))
         try:
-            result = await retry(self._generate)(prompt=prompt, **kwargs)
+            result = await retry(self._generate)(prompt=prompt, schema=schema, **kwargs)
         except Exception:
             logger.exception("Failed to call llm after retrying")
             raise
@@ -59,7 +66,7 @@ class LLM:
         return result

-    async def _generate(self, prompt: str, **kwargs) -> str:
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
         raise NotImplementedError

     def _parse_json(self, result: str) -> dict:
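For context, the caller-side shape of the new API; a minimal sketch, not part of this diff, reusing the schema dict from the modal __main__ demo further down:

# Minimal usage sketch (not in this diff); the schema dict mirrors the
# modal __main__ demo below.
from reflector.llm import LLM
from reflector.logger import logger

async def main():
    llm = LLM.get_instance()
    schema = {"type": "object", "properties": {"name": {"type": "string"}}}
    # generate() forwards schema to the backend's _generate via retry()
    result = await llm.generate("Hello, my name is", schema=schema, logger=logger)
    print(result)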

View File

@@ -1,7 +1,7 @@
+import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
 from reflector.utils.retry import retry
-import httpx

 class BananaLLM(LLM):
@@ -13,18 +13,22 @@ class BananaLLM(LLM):
         "X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY,
     }

-    async def _generate(self, prompt: str, **kwargs):
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs):
+        json_payload = {"prompt": prompt}
+        if schema:
+            json_payload["schema"] = schema
         async with httpx.AsyncClient() as client:
             response = await retry(client.post)(
                 settings.LLM_URL,
                 headers=self.headers,
-                json={"prompt": prompt},
+                json=json_payload,
                 timeout=self.timeout,
                 retry_timeout=300,  # as per their sdk
             )
             response.raise_for_status()
             text = response.json()["text"]
-            text = text[len(prompt) :]  # remove prompt
+            if not schema:
+                text = text[len(prompt) :]
             return text
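The payload construction above recurs verbatim in the modal and oobabooga backends below; as a standalone sketch of what gets posted:

# Standalone sketch of the shared payload logic in this diff.
def build_payload(prompt: str, schema: dict | None = None) -> dict:
    payload = {"prompt": prompt}
    if schema:
        payload["schema"] = schema
    return payload

assert build_payload("hi") == {"prompt": "hi"}
assert build_payload("hi", schema={"type": "object"}) == {
    "prompt": "hi",
    "schema": {"type": "object"},
}

Note that when a schema is set, the prompt prefix is no longer stripped from the response text, presumably because the backend then returns only the schema-constrained JSON rather than the prompt followed by a completion.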

View File

@@ -1,7 +1,7 @@
+import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
 from reflector.utils.retry import retry
-import httpx

 class ModalLLM(LLM):
@@ -23,18 +23,22 @@ class ModalLLM(LLM):
             )
             response.raise_for_status()

-    async def _generate(self, prompt: str, **kwargs):
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs):
+        json_payload = {"prompt": prompt}
+        if schema:
+            json_payload["schema"] = schema
         async with httpx.AsyncClient() as client:
             response = await retry(client.post)(
                 self.llm_url,
                 headers=self.headers,
-                json={"prompt": prompt},
+                json=json_payload,
                 timeout=self.timeout,
                 retry_timeout=60 * 5,
             )
             response.raise_for_status()
             text = response.json()["text"]
-            text = text[len(prompt) :]  # remove prompt
+            if not schema:
+                text = text[len(prompt) :]
             return text
@@ -48,6 +52,14 @@ if __name__ == "__main__":
         result = await llm.generate("Hello, my name is", logger=logger)
         print(result)
+
+        schema = {
+            "type": "object",
+            "properties": {"name": {"type": "string"}},
+        }
+        result = await llm.generate("Hello, my name is", schema=schema, logger=logger)
+        print(result)

     import asyncio

     asyncio.run(main())

View File

@@ -1,18 +1,21 @@
+import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
-import httpx

-class OobagoodaLLM(LLM):
-    async def _generate(self, prompt: str, **kwargs):
+class OobaboogaLLM(LLM):
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs):
+        json_payload = {"prompt": prompt}
+        if schema:
+            json_payload["schema"] = schema
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 settings.LLM_URL,
                 headers={"Content-Type": "application/json"},
-                json={"prompt": prompt},
+                json=json_payload,
             )
             response.raise_for_status()
             return response.json()

-LLM.register("oobagooda", OobagoodaLLM)
+LLM.register("oobabooga", OobaboogaLLM)

View File

@@ -1,7 +1,7 @@
+import httpx
 from reflector.llm.base import LLM
 from reflector.logger import logger
 from reflector.settings import settings
-import httpx

 class OpenAILLM(LLM):
@@ -15,7 +15,7 @@ class OpenAILLM(LLM):
         self.max_tokens = settings.LLM_MAX_TOKENS
         logger.info(f"LLM use openai backend at {self.openai_url}")

-    async def _generate(self, prompt: str, **kwargs) -> str:
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
         headers = {
             "Content-Type": "application/json",
             "Authorization": f"Bearer {self.openai_key}",

View File

@@ -1,7 +1,7 @@
-from reflector.processors.base import Processor
-from reflector.processors.types import Transcript, TitleSummary
-from reflector.utils.retry import retry
 from reflector.llm import LLM
+from reflector.processors.base import Processor
+from reflector.processors.types import TitleSummary, Transcript
+from reflector.utils.retry import retry

 class TranscriptTopicDetectorProcessor(Processor):
@@ -15,9 +15,11 @@ class TranscriptTopicDetectorProcessor(Processor):
     PROMPT = """
     ### Human:
     Create a JSON object as response.The JSON object must have 2 fields:
-    i) title and ii) summary.For the title field,generate a short title
-    for the given text. For the summary field, summarize the given text
-    in three sentences.
+    i) title and ii) summary.
+    For the title field, generate a short title for the given text.
+    For the summary field, summarize the given text in a maximum of
+    three sentences.

     {input_text}
@@ -30,6 +32,13 @@ class TranscriptTopicDetectorProcessor(Processor):
         self.transcript = None
         self.min_transcript_length = min_transcript_length
         self.llm = LLM.get_instance()
+        self.topic_detector_schema = {
+            "type": "object",
+            "properties": {
+                "title": {"type": "string"},
+                "summary": {"type": "string"},
+            },
+        }

     async def _warmup(self):
         await self.llm.warmup(logger=self.logger)
@@ -52,7 +61,9 @@ class TranscriptTopicDetectorProcessor(Processor):
         text = self.transcript.text
         self.logger.info(f"Topic detector got {len(text)} length transcript")
         prompt = self.PROMPT.format(input_text=text)
-        result = await retry(self.llm.generate)(prompt=prompt, logger=self.logger)
+        result = await retry(self.llm.generate)(
+            prompt=prompt, schema=self.topic_detector_schema, logger=self.logger
+        )
         summary = TitleSummary(
             title=result["title"],
             summary=result["summary"],
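Illustratively, a result that satisfies topic_detector_schema can be checked with the jsonschema package; neither the package nor these values appear in the diff:

# Illustration only: validating a generate() result against the schema.
from jsonschema import validate

topic_detector_schema = {
    "type": "object",
    "properties": {
        "title": {"type": "string"},
        "summary": {"type": "string"},
    },
}
result = {"title": "Team standup", "summary": "A short recap of the meeting."}
validate(instance=result, schema=topic_detector_schema)  # raises on mismatch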

View File

@@ -55,8 +55,8 @@ class Settings(BaseSettings):
     TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None

     # LLM
-    # available backend: openai, banana, modal, oobagooda
-    LLM_BACKEND: str = "oobagooda"
+    # available backend: openai, banana, modal, oobabooga
+    LLM_BACKEND: str = "oobabooga"

     # LLM common configuration
     LLM_URL: str | None = None
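With the rename, backend selection via environment variables looks like this; a minimal sketch in which the URL is a placeholder, not a value from this diff:

# Sketch of selecting the renamed backend; the URL is a placeholder.
import os

# Set env vars before importing reflector so the pydantic Settings
# instance picks them up.
os.environ["LLM_BACKEND"] = "oobabooga"
os.environ["LLM_URL"] = "http://localhost:5000/api/v1/generate"

from reflector.llm import LLM  # noqa: E402

llm = LLM.get_instance()  # resolves to OobaboogaLLM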