make schema optional for all LLMs

Gokul Mohanarangan
2023-08-16 22:37:20 +05:30
parent 976c0ab9a8
commit 5f79e04642
8 changed files with 79 additions and 37 deletions

View File

@@ -5,8 +5,9 @@ Reflector GPU backend - LLM
 """
 import os
-from modal import Image, method, Stub, asgi_app, Secret
+from modal import asgi_app, Image, method, Secret, Stub
+from pydantic.typing import Optional

 # LLM
 LLM_MODEL: str = "lmsys/vicuna-13b-v1.5"
@@ -100,13 +101,6 @@ class LLM:
         self.model = model
         self.tokenizer = tokenizer
         self.gen_cfg = gen_cfg
-        self.json_schema = {
-            "type": "object",
-            "properties": {
-                "title": {"type": "string"},
-                "summary": {"type": "string"},
-            },
-        }

     def __exit__(self, *args):
         print("Exit llm")
@@ -117,19 +111,30 @@ class LLM:
         return {"status": "ok"}

     @method()
-    def generate(self, prompt: str):
+    def generate(self, prompt: str, schema: str = None):
         print(f"Generate {prompt=}")
-        import jsonformer
-        import json
+        if schema:
+            import ast
+            import jsonformer

-        jsonformer_llm = jsonformer.Jsonformer(model=self.model,
-                                               tokenizer=self.tokenizer,
-                                               json_schema=self.json_schema,
-                                               prompt=prompt,
-                                               max_string_token_length=self.gen_cfg.max_new_tokens)
-        response = jsonformer_llm()
+            jsonformer_llm = jsonformer.Jsonformer(model=self.model,
+                                                   tokenizer=self.tokenizer,
+                                                   json_schema=ast.literal_eval(schema),
+                                                   prompt=prompt,
+                                                   max_string_token_length=self.gen_cfg.max_new_tokens)
+            response = jsonformer_llm()
+            print(f"Generated {response=}")
+            return {"text": response}
+
+        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
+            self.model.device
+        )
+        output = self.model.generate(input_ids, generation_config=self.gen_cfg)
+        # decode output
+        response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True)
         print(f"Generated {response=}")
-        return {"text": json.dumps(response)}
+        return {"text": response}

 # -------------------------------------------------------------------
@@ -165,12 +170,13 @@ def web():
     class LLMRequest(BaseModel):
         prompt: str
+        schema: Optional[str] = None

     @app.post("/llm", dependencies=[Depends(apikey_auth)])
     async def llm(
         req: LLMRequest,
     ):
-        func = llmstub.generate.spawn(prompt=req.prompt)
+        func = llmstub.generate.spawn(prompt=req.prompt, schema=req.schema)
         result = func.get()
         return result
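
For reference, a minimal sketch (not part of the commit) of the serialization round trip between the HTTP clients below and this backend: the clients JSON-encode the schema dict into the optional schema string field, and generate() rebuilds it with ast.literal_eval() before handing it to jsonformer. The round trip works for the string-only schemas used in this commit because their JSON form is also a valid Python literal.

import ast
import json

schema = {
    "type": "object",
    "properties": {
        "title": {"type": "string"},
        "summary": {"type": "string"},
    },
}

wire_schema = json.dumps(schema)         # what BananaLLM/ModalLLM put in the payload
rebuilt = ast.literal_eval(wire_schema)  # what generate() does before calling jsonformer
assert rebuilt == schema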

View File

@@ -20,7 +20,7 @@ class LLM:
         Return an instance depending on the settings.

         Settings used:
-        - `LLM_BACKEND`: key of the backend, defaults to `oobagooda`
+        - `LLM_BACKEND`: key of the backend, defaults to `oobabooga`
         - `LLM_URL`: url of the backend
         """
         if name is None:
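
The registry behind register()/get_instance() is not part of this diff; the following is a hypothetical sketch of the pattern implied by the docstring above and by the LLM.register("oobabooga", OobaboogaLLM) call further down, assuming a simple class-level dict keyed by backend name.

from reflector.settings import settings

class LLM:
    _backends: dict[str, type] = {}  # hypothetical: backend name -> backend class

    @classmethod
    def register(cls, name: str, backend_cls: type) -> None:
        cls._backends[name] = backend_cls

    @classmethod
    def get_instance(cls, name: str | None = None) -> "LLM":
        # LLM_BACKEND defaults to "oobabooga" per the docstring above
        if name is None:
            name = settings.LLM_BACKEND
        return cls._backends[name]()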

View File

@@ -1,7 +1,9 @@
+import json
+import httpx
+
 from reflector.llm.base import LLM
 from reflector.settings import settings
 from reflector.utils.retry import retry
-import httpx


 class BananaLLM(LLM):
@@ -14,17 +16,21 @@ class BananaLLM(LLM):
     }

     async def _generate(self, prompt: str, **kwargs):
+        json_payload = {"prompt": prompt}
+        if "schema" in kwargs:
+            json_payload["schema"] = json.dumps(kwargs["schema"])
         async with httpx.AsyncClient() as client:
             response = await retry(client.post)(
                 settings.LLM_URL,
                 headers=self.headers,
-                json={"prompt": prompt},
+                json=json_payload,
                 timeout=self.timeout,
                 retry_timeout=300,  # as per their sdk
             )
             response.raise_for_status()
             text = response.json()["text"]
-            text = text[len(prompt) :]  # remove prompt
+            if "schema" not in json_payload:
+                text = text[len(prompt) :]
             return text

View File

@@ -1,7 +1,9 @@
+import json
+import httpx
+
 from reflector.llm.base import LLM
 from reflector.settings import settings
 from reflector.utils.retry import retry
-import httpx


 class ModalLLM(LLM):
@@ -24,17 +26,21 @@ class ModalLLM(LLM):
         response.raise_for_status()

     async def _generate(self, prompt: str, **kwargs):
+        json_payload = {"prompt": prompt}
+        if "schema" in kwargs:
+            json_payload["schema"] = json.dumps(kwargs["schema"])
         async with httpx.AsyncClient() as client:
             response = await retry(client.post)(
                 self.llm_url,
                 headers=self.headers,
-                json={"prompt": prompt},
+                json=json_payload,
                 timeout=self.timeout,
                 retry_timeout=60 * 5,
             )
             response.raise_for_status()
             text = response.json()["text"]
-            text = text[len(prompt) :]  # remove prompt
+            if "schema" not in json_payload:
+                text = text[len(prompt) :]
             return text
@@ -48,6 +54,15 @@ if __name__ == "__main__":
         result = await llm.generate("Hello, my name is", logger=logger)
         print(result)

+        kwargs = {
+            "schema": {
+                "type": "object",
+                "properties": {"name": {"type": "string"}},
+            }
+        }
+        result = await llm.generate("Hello, my name is", kwargs=kwargs, logger=logger)
+        print(result)
+
     import asyncio

     asyncio.run(main())

View File

@@ -1,18 +1,23 @@
+import json
+import httpx
+
 from reflector.llm.base import LLM
 from reflector.settings import settings
-import httpx


-class OobagoodaLLM(LLM):
+class OobaboogaLLM(LLM):
     async def _generate(self, prompt: str, **kwargs):
+        json_payload = {"prompt": prompt}
+        if "schema" in kwargs:
+            json_payload["schema"] = json.dumps(kwargs["schema"])
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 settings.LLM_URL,
                 headers={"Content-Type": "application/json"},
-                json={"prompt": prompt},
+                json=json_payload,
             )
             response.raise_for_status()
             return response.json()


-LLM.register("oobagooda", OobagoodaLLM)
+LLM.register("oobabooga", OobaboogaLLM)

View File

@@ -1,7 +1,7 @@
from reflector.processors.base import Processor
from reflector.processors.types import Transcript, TitleSummary
from reflector.utils.retry import retry
from reflector.llm import LLM from reflector.llm import LLM
from reflector.processors.base import Processor
from reflector.processors.types import TitleSummary, Transcript
from reflector.utils.retry import retry
class TranscriptTopicDetectorProcessor(Processor): class TranscriptTopicDetectorProcessor(Processor):
@@ -31,6 +31,14 @@ class TranscriptTopicDetectorProcessor(Processor):
self.transcript = None self.transcript = None
self.min_transcript_length = min_transcript_length self.min_transcript_length = min_transcript_length
self.llm = LLM.get_instance() self.llm = LLM.get_instance()
self.topic_detector_schema = {
"type": "object",
"properties": {
"title": {"type": "string"},
"summary": {"type": "string"},
},
}
self.kwargs = {"schema": self.topic_detector_schema}
async def _warmup(self): async def _warmup(self):
await self.llm.warmup(logger=self.logger) await self.llm.warmup(logger=self.logger)
@@ -53,7 +61,9 @@ class TranscriptTopicDetectorProcessor(Processor):
text = self.transcript.text text = self.transcript.text
self.logger.info(f"Topic detector got {len(text)} length transcript") self.logger.info(f"Topic detector got {len(text)} length transcript")
prompt = self.PROMPT.format(input_text=text) prompt = self.PROMPT.format(input_text=text)
result = await retry(self.llm.generate)(prompt=prompt, logger=self.logger) result = await retry(self.llm.generate)(
prompt=prompt, kwargs=self.kwargs, logger=self.logger
)
summary = TitleSummary( summary = TitleSummary(
title=result["title"], title=result["title"],
summary=result["summary"], summary=result["summary"],
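
base.LLM.generate is not touched by this commit; judging only from the call sites above, it accepts an optional kwargs dict and forwards it to the backend-specific _generate(prompt, **kwargs), which is how the schema ends up in kwargs["schema"] inside the Banana/Modal/Oobabooga clients. A hypothetical sketch of that forwarding:

# Hypothetical sketch only -- the real base.LLM.generate is not shown in this diff.
async def generate(self, prompt: str, kwargs: dict | None = None, logger=None):
    # forward the optional schema (and anything else) to the concrete backend
    return await self._generate(prompt, **(kwargs or {}))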

View File

@@ -41,7 +41,7 @@ model = WhisperModel("tiny", device="cpu", compute_type="float32", num_workers=1

 # LLM
 LLM_URL = settings.LLM_URL
 if not LLM_URL:
-    assert settings.LLM_BACKEND == "oobagooda"
+    assert settings.LLM_BACKEND == "oobabooga"
     LLM_URL = f"http://{settings.LLM_HOST}:{settings.LLM_PORT}/api/v1/generate"
 logger.info(f"Using LLM [{settings.LLM_BACKEND}]: {LLM_URL}")

View File

@@ -52,8 +52,8 @@ class Settings(BaseSettings):
     TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None

     # LLM
-    # available backend: openai, banana, modal, oobagooda
-    LLM_BACKEND: str = "oobagooda"
+    # available backend: openai, banana, modal, oobabooga
+    LLM_BACKEND: str = "oobabooga"
     # LLM common configuration
     LLM_URL: str | None = None
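
Since Settings extends pydantic's BaseSettings, the backend can be switched per environment without code changes. A minimal sketch, assuming reflector.settings instantiates Settings() at import time and using a hypothetical Modal URL:

import os

# Hypothetical values for illustration only.
os.environ["LLM_BACKEND"] = "modal"
os.environ["LLM_URL"] = "https://example--reflector-llm.modal.run"

from reflector.settings import settings  # imported after the environment is set

assert settings.LLM_BACKEND == "modal"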