mirror of https://github.com/Monadical-SAS/reflector.git, synced 2025-12-21 12:49:06 +00:00
Merge branch 'main' into feat-user-auth-fief
@@ -1,10 +1,11 @@
-from reflector.settings import settings
-from reflector.utils.retry import retry
-from reflector.logger import logger as reflector_logger
-from time import monotonic
+import importlib
+import json
+import re
+from time import monotonic
+
+from reflector.logger import logger as reflector_logger
+from reflector.settings import settings
+from reflector.utils.retry import retry
 
 
 class LLM:
@@ -20,7 +21,7 @@ class LLM:
         Return an instance depending on the settings.
         Settings used:
 
-        - `LLM_BACKEND`: key of the backend, defaults to `oobagooda`
+        - `LLM_BACKEND`: key of the backend, defaults to `oobabooga`
         - `LLM_URL`: url of the backend
         """
         if name is None:
@@ -44,10 +45,16 @@ class LLM:
     async def _warmup(self, logger: reflector_logger):
         pass
 
-    async def generate(self, prompt: str, logger: reflector_logger, **kwargs) -> dict:
+    async def generate(
+        self,
+        prompt: str,
+        logger: reflector_logger,
+        schema: dict | None = None,
+        **kwargs,
+    ) -> dict:
         logger.info("LLM generate", prompt=repr(prompt))
         try:
-            result = await retry(self._generate)(prompt=prompt, **kwargs)
+            result = await retry(self._generate)(prompt=prompt, schema=schema, **kwargs)
         except Exception:
             logger.exception("Failed to call llm after retrying")
             raise
@@ -59,7 +66,7 @@ class LLM:
 
         return result
 
-    async def _generate(self, prompt: str, **kwargs) -> str:
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
         raise NotImplementedError
 
     def _parse_json(self, result: str) -> dict:
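Taken together, the base-class changes let callers pass a JSON schema through generate() and receive a structured dict back. A minimal usage sketch, assuming a backend has been configured via LLM_BACKEND/LLM_URL and registered; the prompt and schema here are illustrative only:

import asyncio

from reflector.llm import LLM
from reflector.logger import logger


async def main():
    llm = LLM.get_instance()
    # Hypothetical schema: any JSON schema dict the configured backend can enforce.
    schema = {
        "type": "object",
        "properties": {"title": {"type": "string"}, "summary": {"type": "string"}},
    }
    # With schema=None the call behaves as before; with a schema it is
    # forwarded to the backend's _generate() and the result is returned as a dict.
    result = await llm.generate("Summarize: ...", schema=schema, logger=logger)
    print(result)


asyncio.run(main())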
@@ -1,7 +1,7 @@
+import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
 from reflector.utils.retry import retry
-import httpx
 
 
 class BananaLLM(LLM):
@@ -13,18 +13,22 @@ class BananaLLM(LLM):
             "X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY,
         }
 
-    async def _generate(self, prompt: str, **kwargs):
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs):
+        json_payload = {"prompt": prompt}
+        if schema:
+            json_payload["schema"] = schema
         async with httpx.AsyncClient() as client:
             response = await retry(client.post)(
                 settings.LLM_URL,
                 headers=self.headers,
-                json={"prompt": prompt},
+                json=json_payload,
                 timeout=self.timeout,
                 retry_timeout=300,  # as per their sdk
             )
             response.raise_for_status()
             text = response.json()["text"]
-            text = text[len(prompt) :]  # remove prompt
+            if not schema:
+                text = text[len(prompt) :]
             return text
@@ -1,7 +1,7 @@
+import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
 from reflector.utils.retry import retry
-import httpx
 
 
 class ModalLLM(LLM):
@@ -23,18 +23,22 @@ class ModalLLM(LLM):
             )
             response.raise_for_status()
 
-    async def _generate(self, prompt: str, **kwargs):
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs):
+        json_payload = {"prompt": prompt}
+        if schema:
+            json_payload["schema"] = schema
         async with httpx.AsyncClient() as client:
             response = await retry(client.post)(
                 self.llm_url,
                 headers=self.headers,
-                json={"prompt": prompt},
+                json=json_payload,
                 timeout=self.timeout,
                 retry_timeout=60 * 5,
             )
             response.raise_for_status()
             text = response.json()["text"]
-            text = text[len(prompt) :]  # remove prompt
+            if not schema:
+                text = text[len(prompt) :]
             return text
@@ -48,6 +52,14 @@ if __name__ == "__main__":
         result = await llm.generate("Hello, my name is", logger=logger)
         print(result)
 
+        schema = {
+            "type": "object",
+            "properties": {"name": {"type": "string"}},
+        }
+
+        result = await llm.generate("Hello, my name is", schema=schema, logger=logger)
+        print(result)
+
     import asyncio
 
     asyncio.run(main())
@@ -1,18 +1,21 @@
+import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
-import httpx
 
 
-class OobagoodaLLM(LLM):
-    async def _generate(self, prompt: str, **kwargs):
+class OobaboogaLLM(LLM):
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs):
+        json_payload = {"prompt": prompt}
+        if schema:
+            json_payload["schema"] = schema
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 settings.LLM_URL,
                 headers={"Content-Type": "application/json"},
-                json={"prompt": prompt},
+                json=json_payload,
             )
             response.raise_for_status()
             return response.json()
 
 
-LLM.register("oobagooda", OobagoodaLLM)
+LLM.register("oobabooga", OobaboogaLLM)
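The rename also illustrates the registry pattern used for backends: a subclass implements _generate() with the new schema parameter and registers itself under a key selectable through LLM_BACKEND. A sketch of a hypothetical extra backend; the class name, registry key, and endpoint behaviour are invented for illustration:

import httpx

from reflector.llm.base import LLM
from reflector.settings import settings


class ExampleLLM(LLM):
    async def _generate(self, prompt: str, schema: dict | None, **kwargs):
        # Forward the optional JSON schema alongside the prompt, as the
        # backends in this diff do.
        json_payload = {"prompt": prompt}
        if schema:
            json_payload["schema"] = schema
        async with httpx.AsyncClient() as client:
            response = await client.post(settings.LLM_URL, json=json_payload)
            response.raise_for_status()
            return response.json()


# Makes the backend selectable with LLM_BACKEND=example (hypothetical key).
LLM.register("example", ExampleLLM)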
@@ -1,7 +1,7 @@
+import httpx
 from reflector.llm.base import LLM
 from reflector.logger import logger
 from reflector.settings import settings
-import httpx
 
 
 class OpenAILLM(LLM):
@@ -15,7 +15,7 @@ class OpenAILLM(LLM):
         self.max_tokens = settings.LLM_MAX_TOKENS
         logger.info(f"LLM use openai backend at {self.openai_url}")
 
-    async def _generate(self, prompt: str, **kwargs) -> str:
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
         headers = {
             "Content-Type": "application/json",
             "Authorization": f"Bearer {self.openai_key}",
@@ -1,7 +1,7 @@
-from reflector.processors.base import Processor
-from reflector.processors.types import Transcript, TitleSummary
-from reflector.utils.retry import retry
 from reflector.llm import LLM
+from reflector.processors.base import Processor
+from reflector.processors.types import TitleSummary, Transcript
+from reflector.utils.retry import retry
 
 
 class TranscriptTopicDetectorProcessor(Processor):
@@ -15,9 +15,11 @@ class TranscriptTopicDetectorProcessor(Processor):
     PROMPT = """
     ### Human:
     Create a JSON object as response.The JSON object must have 2 fields:
-    i) title and ii) summary.For the title field,generate a short title
-    for the given text. For the summary field, summarize the given text
-    in three sentences.
+    i) title and ii) summary.
+
+    For the title field, generate a short title for the given text.
+    For the summary field, summarize the given text in a maximum of
+    three sentences.
 
     {input_text}
 
@@ -30,6 +32,13 @@ class TranscriptTopicDetectorProcessor(Processor):
         self.transcript = None
         self.min_transcript_length = min_transcript_length
         self.llm = LLM.get_instance()
+        self.topic_detector_schema = {
+            "type": "object",
+            "properties": {
+                "title": {"type": "string"},
+                "summary": {"type": "string"},
+            },
+        }
 
     async def _warmup(self):
         await self.llm.warmup(logger=self.logger)
@@ -52,7 +61,9 @@ class TranscriptTopicDetectorProcessor(Processor):
         text = self.transcript.text
         self.logger.info(f"Topic detector got {len(text)} length transcript")
         prompt = self.PROMPT.format(input_text=text)
-        result = await retry(self.llm.generate)(prompt=prompt, logger=self.logger)
+        result = await retry(self.llm.generate)(
+            prompt=prompt, schema=self.topic_detector_schema, logger=self.logger
+        )
         summary = TitleSummary(
             title=result["title"],
             summary=result["summary"],
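For orientation, a sketch of the dict the topic detector now expects back from the schema-constrained generate() call; the values below are invented placeholders:

# Illustrative only: the keys follow topic_detector_schema defined above, and
# feed the TitleSummary(title=..., summary=...) construction in this hunk.
result = {
    "title": "Weekly sync",
    "summary": "A short summary of the discussed topics, in at most three sentences.",
}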
@@ -55,8 +55,8 @@ class Settings(BaseSettings):
     TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
 
     # LLM
-    # available backend: openai, banana, modal, oobagooda
-    LLM_BACKEND: str = "oobagooda"
+    # available backend: openai, banana, modal, oobabooga
+    LLM_BACKEND: str = "oobabooga"
 
     # LLM common configuration
     LLM_URL: str | None = None