make schema an optional argument

Gokul Mohanarangan
2023-08-17 09:23:14 +05:30
parent 5f79e04642
commit eb13a7bd64
9 changed files with 48 additions and 37 deletions
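
In short: schema is promoted from an entry buried in **kwargs to an explicit optional parameter on LLM.generate, which now forwards it to each backend's _generate. A minimal sketch of the new calling convention, with a stubbed backend so it runs without the reflector package (FakeLLM, its payload echo, and the dict typing of schema are illustrative assumptions, not part of the commit):

import asyncio
import json
from typing import Union


class FakeLLM:
    # Stand-in for reflector.llm.base.LLM, mirroring the new generate() signature.
    async def generate(
        self, prompt: str, schema: Union[dict, None] = None, **kwargs
    ) -> str:
        # schema is now a named optional argument instead of kwargs["schema"].
        json_payload = {"prompt": prompt}
        if schema:
            json_payload["schema"] = json.dumps(schema)
        return f"payload sent to backend: {json_payload}"


async def main():
    llm = FakeLLM()
    # Plain completion: schema defaults to None and stays out of the payload.
    print(await llm.generate("Hello, my name is"))
    # Structured generation: pass the JSON schema directly as a keyword argument.
    schema = {"type": "object", "properties": {"name": {"type": "string"}}}
    print(await llm.generate("Hello, my name is", schema=schema))


asyncio.run(main())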

View File

@@ -1,10 +1,12 @@
from reflector.settings import settings
from reflector.utils.retry import retry
from reflector.logger import logger as reflector_logger
from time import monotonic
import importlib
import json
import re
from time import monotonic
from typing import Union
from reflector.logger import logger as reflector_logger
from reflector.settings import settings
from reflector.utils.retry import retry
class LLM:
@@ -44,10 +46,12 @@ class LLM:
    async def _warmup(self, logger: reflector_logger):
        pass
    async def generate(self, prompt: str, logger: reflector_logger, **kwargs) -> dict:
    async def generate(
        self, prompt: str, logger: reflector_logger, schema: str = None, **kwargs
    ) -> dict:
        logger.info("LLM generate", prompt=repr(prompt))
        try:
            result = await retry(self._generate)(prompt=prompt, **kwargs)
            result = await retry(self._generate)(prompt=prompt, schema=schema, **kwargs)
        except Exception:
            logger.exception("Failed to call llm after retrying")
            raise
@@ -59,7 +63,7 @@ class LLM:
        return result
    async def _generate(self, prompt: str, **kwargs) -> str:
    async def _generate(self, prompt: str, schema: Union[str | None], **kwargs) -> str:
        raise NotImplementedError
    def _parse_json(self, result: str) -> dict:

View File

@@ -1,4 +1,5 @@
import json
from typing import Union
import httpx
from reflector.llm.base import LLM
@@ -15,10 +16,10 @@ class BananaLLM(LLM):
"X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY,
}
    async def _generate(self, prompt: str, **kwargs):
    async def _generate(self, prompt: str, schema: Union[str | None], **kwargs):
        json_payload = {"prompt": prompt}
        if "schema" in kwargs:
            json_payload["schema"] = json.dumps(kwargs["schema"])
        if schema:
            json_payload["schema"] = json.dumps(schema)
        async with httpx.AsyncClient() as client:
            response = await retry(client.post)(
                settings.LLM_URL,
@@ -29,7 +30,7 @@ class BananaLLM(LLM):
            )
            response.raise_for_status()
            text = response.json()["text"]
            if "schema" not in json_payload:
            if not schema:
                text = text[len(prompt) :]
            return text

View File

@@ -1,4 +1,5 @@
import json
from typing import Union
import httpx
from reflector.llm.base import LLM
@@ -25,10 +26,10 @@ class ModalLLM(LLM):
            )
            response.raise_for_status()
    async def _generate(self, prompt: str, **kwargs):
    async def _generate(self, prompt: str, schema: Union[str | None], **kwargs):
        json_payload = {"prompt": prompt}
        if "schema" in kwargs:
            json_payload["schema"] = json.dumps(kwargs["schema"])
        if schema:
            json_payload["schema"] = json.dumps(schema)
        async with httpx.AsyncClient() as client:
            response = await retry(client.post)(
                self.llm_url,
@@ -39,7 +40,7 @@ class ModalLLM(LLM):
            )
            response.raise_for_status()
            text = response.json()["text"]
            if "schema" not in json_payload:
            if not schema:
                text = text[len(prompt) :]
            return text
@@ -54,13 +55,12 @@ if __name__ == "__main__":
        result = await llm.generate("Hello, my name is", logger=logger)
        print(result)
        kwargs = {
            "schema": {
                "type": "object",
                "properties": {"name": {"type": "string"}},
            }
        schema = {
            "type": "object",
            "properties": {"name": {"type": "string"}},
        }
        result = await llm.generate("Hello, my name is", kwargs=kwargs, logger=logger)
        result = await llm.generate("Hello, my name is", schema=schema, logger=logger)
        print(result)
    import asyncio

View File

@@ -1,4 +1,5 @@
import json
from typing import Union
import httpx
from reflector.llm.base import LLM
@@ -6,10 +7,10 @@ from reflector.settings import settings
class OobaboogaLLM(LLM):
    async def _generate(self, prompt: str, **kwargs):
    async def _generate(self, prompt: str, schema: Union[str | None], **kwargs):
        json_payload = {"prompt": prompt}
        if "schema" in kwargs:
            json_payload["schema"] = json.dumps(kwargs["schema"])
        if schema:
            json_payload["schema"] = json.dumps(schema)
        async with httpx.AsyncClient() as client:
            response = await client.post(
                settings.LLM_URL,

View File

@@ -1,7 +1,9 @@
from typing import Union
import httpx
from reflector.llm.base import LLM
from reflector.logger import logger
from reflector.settings import settings
import httpx
class OpenAILLM(LLM):
@@ -15,7 +17,7 @@ class OpenAILLM(LLM):
        self.max_tokens = settings.LLM_MAX_TOKENS
        logger.info(f"LLM use openai backend at {self.openai_url}")
    async def _generate(self, prompt: str, **kwargs) -> str:
    async def _generate(self, prompt: str, schema: Union[str | None], **kwargs) -> str:
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.openai_key}",

View File

@@ -14,7 +14,6 @@ class TranscriptTopicDetectorProcessor(Processor):
PROMPT = """
### Human:
Generate information based on the given schema:
For the title field, generate a short title for the given text.
For the summary field, summarize the given text in a maximum of
@@ -62,7 +61,7 @@ class TranscriptTopicDetectorProcessor(Processor):
        self.logger.info(f"Topic detector got {len(text)} length transcript")
        prompt = self.PROMPT.format(input_text=text)
        result = await retry(self.llm.generate)(
            prompt=prompt, kwargs=self.kwargs, logger=self.logger
            prompt=prompt, schema=self.topic_detector_schema, logger=self.logger
        )
        summary = TitleSummary(
            title=result["title"],
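
Every in-tree backend's _generate now receives schema as an explicit keyword instead of fishing it out of kwargs. A backend maintained outside this repository would follow the same pattern; a hedged sketch, where EchoLLM and its trivial behaviour are hypothetical:

from typing import Union

from reflector.llm.base import LLM


class EchoLLM(LLM):
    # Hypothetical backend, shown only to illustrate the new _generate contract.
    async def _generate(
        self, prompt: str, schema: Union[str, None], **kwargs
    ) -> str:
        # LLM.generate always passes schema (None when the caller omits it);
        # a backend without structured-output support can simply ignore it.
        return prompt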