make schema optional for all LLMs

Gokul Mohanarangan
2023-08-16 22:37:20 +05:30
parent 976c0ab9a8
commit 5f79e04642
8 changed files with 79 additions and 37 deletions
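The change in a nutshell: the JSON schema that used to be hard-coded into the GPU backend now travels with each request, so callers can ask for free-form text or schema-constrained JSON. A minimal sketch of the new call pattern (illustrative only; it mirrors the `__main__` example in the Modal backend diff below, assumes `LLM.get_instance()` as shown in the processor diff, and uses a plain standard-library logger):

import asyncio
import logging

from reflector.llm import LLM

logger = logging.getLogger(__name__)


async def main():
    llm = LLM.get_instance()

    # No schema: plain free-form completion.
    result = await llm.generate("Hello, my name is", logger=logger)
    print(result)

    # With a schema: the backend constrains the output to this JSON shape.
    kwargs = {
        "schema": {
            "type": "object",
            "properties": {"name": {"type": "string"}},
        }
    }
    result = await llm.generate("Hello, my name is", kwargs=kwargs, logger=logger)
    print(result)


asyncio.run(main())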

View File

@@ -5,8 +5,9 @@ Reflector GPU backend - LLM
"""
import os
from modal import Image, method, Stub, asgi_app, Secret
from modal import asgi_app, Image, method, Secret, Stub
from pydantic.typing import Optional
# LLM
LLM_MODEL: str = "lmsys/vicuna-13b-v1.5"
@@ -100,13 +101,6 @@ class LLM:
         self.model = model
         self.tokenizer = tokenizer
         self.gen_cfg = gen_cfg
-        self.json_schema = {
-            "type": "object",
-            "properties": {
-                "title": {"type": "string"},
-                "summary": {"type": "string"},
-            },
-        }
     def __exit__(self, *args):
         print("Exit llm")
@@ -117,19 +111,30 @@ class LLM:
return {"status": "ok"}
@method()
def generate(self, prompt: str):
def generate(self, prompt: str, schema: str = None):
print(f"Generate {prompt=}")
if schema:
import ast
import jsonformer
import json
jsonformer_llm = jsonformer.Jsonformer(model=self.model,
tokenizer=self.tokenizer,
json_schema=self.json_schema,
json_schema=ast.literal_eval(schema),
prompt=prompt,
max_string_token_length=self.gen_cfg.max_new_tokens)
response = jsonformer_llm()
print(f"Generated {response=}")
return {"text": json.dumps(response)}
return {"text": response}
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
self.model.device
)
output = self.model.generate(input_ids, generation_config=self.gen_cfg)
# decode output
response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True)
print(f"Generated {response=}")
return {"text": response}
# -------------------------------------------------------------------
@@ -165,12 +170,13 @@ def web():
     class LLMRequest(BaseModel):
         prompt: str
+        schema: Optional[str] = None
     @app.post("/llm", dependencies=[Depends(apikey_auth)])
     async def llm(
         req: LLMRequest,
     ):
-        func = llmstub.generate.spawn(prompt=req.prompt)
+        func = llmstub.generate.spawn(prompt=req.prompt, schema=req.schema)
         result = func.get()
         return result
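For context on the schema path above: when a `schema` string arrives, the backend `ast.literal_eval`s it into a dict and hands it to Jsonformer, which fills in field values one at a time so the output always matches the schema; without a schema it falls back to plain `model.generate`. A self-contained sketch of the Jsonformer usage, assuming a small Hugging Face model purely for illustration (the real backend loads `LLM_MODEL` with its own generation config):

import json

from jsonformer import Jsonformer
from transformers import AutoModelForCausalLM, AutoTokenizer

# Tiny model only to keep the example runnable; the backend uses LLM_MODEL.
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

schema = {
    "type": "object",
    "properties": {
        "title": {"type": "string"},
        "summary": {"type": "string"},
    },
}

jsonformer_llm = Jsonformer(
    model=model,
    tokenizer=tokenizer,
    json_schema=schema,
    prompt="Summarize: the team shipped the new release today.",
    max_string_token_length=64,
)
result = jsonformer_llm()            # always a dict shaped like `schema`
print(json.dumps(result, indent=2))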

View File

@@ -20,7 +20,7 @@ class LLM:
         Return an instance depending on the settings.
         Settings used:
-        - `LLM_BACKEND`: key of the backend, defaults to `oobagooda`
+        - `LLM_BACKEND`: key of the backend, defaults to `oobabooga`
         - `LLM_URL`: url of the backend
         """
         if name is None:
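The docstring fix above also points at how backends are wired up. A hedged sketch of the registry pattern implied by `LLM.register("oobabooga", OobaboogaLLM)` and `LLM.get_instance()` elsewhere in this commit; the real base.py is not shown here and may differ in details:

from reflector.settings import settings


class LLM:
    _backends: dict[str, type["LLM"]] = {}

    @classmethod
    def register(cls, name: str, backend_cls: type["LLM"]) -> None:
        # Each backend module registers itself under a key, e.g. "oobabooga".
        cls._backends[name] = backend_cls

    @classmethod
    def get_instance(cls, name: str | None = None) -> "LLM":
        # Fall back to the configured backend key (LLM_BACKEND, default "oobabooga").
        if name is None:
            name = settings.LLM_BACKEND
        return cls._backends[name]()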

View File

@@ -1,7 +1,9 @@
+import json
+import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
 from reflector.utils.retry import retry
-import httpx
 class BananaLLM(LLM):
@@ -14,17 +16,21 @@ class BananaLLM(LLM):
     }
     async def _generate(self, prompt: str, **kwargs):
+        json_payload = {"prompt": prompt}
+        if "schema" in kwargs:
+            json_payload["schema"] = json.dumps(kwargs["schema"])
         async with httpx.AsyncClient() as client:
             response = await retry(client.post)(
                 settings.LLM_URL,
                 headers=self.headers,
-                json={"prompt": prompt},
+                json=json_payload,
                 timeout=self.timeout,
                 retry_timeout=300,  # as per their sdk
             )
             response.raise_for_status()
             text = response.json()["text"]
-            text = text[len(prompt) :]  # remove prompt
+            if "schema" not in json_payload:
+                text = text[len(prompt) :]
             return text
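The same payload logic is repeated in each HTTP client: the schema is serialized with `json.dumps` into the request body, and the prompt-echo stripping is skipped for schema requests because the constrained response is pure JSON rather than prompt plus completion. Illustrative payloads (the prompt and schema values are invented):

import json

prompt = "Summarize: the team shipped the new release today."
schema = {
    "type": "object",
    "properties": {
        "title": {"type": "string"},
        "summary": {"type": "string"},
    },
}

# Free-form request: the backend echoes the prompt, so the client slices it off.
plain_payload = {"prompt": prompt}

# Schema-constrained request: the schema travels as a JSON string.
constrained_payload = {"prompt": prompt, "schema": json.dumps(schema)}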

View File

@@ -1,7 +1,9 @@
+import json
+import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
 from reflector.utils.retry import retry
-import httpx
 class ModalLLM(LLM):
@@ -24,17 +26,21 @@ class ModalLLM(LLM):
         response.raise_for_status()
     async def _generate(self, prompt: str, **kwargs):
+        json_payload = {"prompt": prompt}
+        if "schema" in kwargs:
+            json_payload["schema"] = json.dumps(kwargs["schema"])
         async with httpx.AsyncClient() as client:
             response = await retry(client.post)(
                 self.llm_url,
                 headers=self.headers,
-                json={"prompt": prompt},
+                json=json_payload,
                 timeout=self.timeout,
                 retry_timeout=60 * 5,
             )
             response.raise_for_status()
             text = response.json()["text"]
-            text = text[len(prompt) :]  # remove prompt
+            if "schema" not in json_payload:
+                text = text[len(prompt) :]
             return text
@@ -48,6 +54,15 @@ if __name__ == "__main__":
         result = await llm.generate("Hello, my name is", logger=logger)
         print(result)
+        kwargs = {
+            "schema": {
+                "type": "object",
+                "properties": {"name": {"type": "string"}},
+            }
+        }
+        result = await llm.generate("Hello, my name is", kwargs=kwargs, logger=logger)
+        print(result)
     import asyncio
     asyncio.run(main())
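On the response side, the two modes hand the Modal client different shapes: plain generation returns `{"text": prompt + completion}`, while schema requests now return the Jsonformer result directly instead of a `json.dumps`-ed string, which is why `_generate` only strips the prompt when no schema was sent. A sketch with invented values:

# Plain generation: "text" echoes the prompt plus the completion.
plain_response = {"text": "Hello, my name is Alice and I like hiking."}
prompt = "Hello, my name is"
completion = plain_response["text"][len(prompt):]  # " Alice and I like hiking."

# Schema-constrained generation: "text" already carries the structured result.
constrained_response = {"text": {"name": "Alice"}}
structured = constrained_response["text"]          # no prompt stripping needed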

View File

@@ -1,18 +1,23 @@
+import json
+import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
-import httpx
-class OobagoodaLLM(LLM):
+class OobaboogaLLM(LLM):
     async def _generate(self, prompt: str, **kwargs):
+        json_payload = {"prompt": prompt}
+        if "schema" in kwargs:
+            json_payload["schema"] = json.dumps(kwargs["schema"])
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 settings.LLM_URL,
                 headers={"Content-Type": "application/json"},
-                json={"prompt": prompt},
+                json=json_payload,
             )
             response.raise_for_status()
             return response.json()
-LLM.register("oobagooda", OobagoodaLLM)
+LLM.register("oobabooga", OobaboogaLLM)

View File

@@ -1,7 +1,7 @@
-from reflector.processors.base import Processor
-from reflector.processors.types import Transcript, TitleSummary
-from reflector.utils.retry import retry
 from reflector.llm import LLM
+from reflector.processors.base import Processor
+from reflector.processors.types import TitleSummary, Transcript
+from reflector.utils.retry import retry
 class TranscriptTopicDetectorProcessor(Processor):
@@ -31,6 +31,14 @@ class TranscriptTopicDetectorProcessor(Processor):
         self.transcript = None
         self.min_transcript_length = min_transcript_length
         self.llm = LLM.get_instance()
+        self.topic_detector_schema = {
+            "type": "object",
+            "properties": {
+                "title": {"type": "string"},
+                "summary": {"type": "string"},
+            },
+        }
+        self.kwargs = {"schema": self.topic_detector_schema}
     async def _warmup(self):
         await self.llm.warmup(logger=self.logger)
@@ -53,7 +61,9 @@ class TranscriptTopicDetectorProcessor(Processor):
         text = self.transcript.text
         self.logger.info(f"Topic detector got {len(text)} length transcript")
         prompt = self.PROMPT.format(input_text=text)
-        result = await retry(self.llm.generate)(prompt=prompt, logger=self.logger)
+        result = await retry(self.llm.generate)(
+            prompt=prompt, kwargs=self.kwargs, logger=self.logger
+        )
         summary = TitleSummary(
             title=result["title"],
             summary=result["summary"],
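Put together, the topic detector now owns the schema it needs, passes it on every generate call, and reads the guaranteed keys out of the structured result to build `TitleSummary`. A hedged sketch of that round trip (the result values are invented):

topic_detector_schema = {
    "type": "object",
    "properties": {
        "title": {"type": "string"},
        "summary": {"type": "string"},
    },
}
kwargs = {"schema": topic_detector_schema}

# A conforming result from llm.generate(prompt, kwargs=kwargs, ...) always has
# both keys, so the processor can feed them straight into TitleSummary(...):
result = {
    "title": "Release planning",
    "summary": "The team agreed to ship version 2 on Friday.",
}
title, summary = result["title"], result["summary"]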

View File

@@ -41,7 +41,7 @@ model = WhisperModel("tiny", device="cpu", compute_type="float32", num_workers=1
 # LLM
 LLM_URL = settings.LLM_URL
 if not LLM_URL:
-    assert settings.LLM_BACKEND == "oobagooda"
+    assert settings.LLM_BACKEND == "oobabooga"
     LLM_URL = f"http://{settings.LLM_HOST}:{settings.LLM_PORT}/api/v1/generate"
 logger.info(f"Using LLM [{settings.LLM_BACKEND}]: {LLM_URL}")

View File

@@ -52,8 +52,8 @@ class Settings(BaseSettings):
     TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
     # LLM
-    # available backend: openai, banana, modal, oobagooda
-    LLM_BACKEND: str = "oobagooda"
+    # available backend: openai, banana, modal, oobabooga
+    LLM_BACKEND: str = "oobabooga"
     # LLM common configuration
     LLM_URL: str | None = None
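Since `Settings` is a pydantic `BaseSettings`, the backend can be switched without code changes. A hedged example, assuming the fields are read from identically named environment variables (pydantic's default, no env prefix) and using an invented URL:

import os

# Select the backend before reflector.settings is imported, since the
# settings object is created at import time.
os.environ["LLM_BACKEND"] = "modal"            # or "openai", "banana", "oobabooga"
os.environ["LLM_URL"] = "https://example-llm.modal.run/llm"

from reflector.settings import settings

print(settings.LLM_BACKEND, settings.LLM_URL)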