From 5f79e04642196c9a7b970f6ef391692b2c87b07f Mon Sep 17 00:00:00 2001
From: Gokul Mohanarangan
Date: Wed, 16 Aug 2023 22:37:20 +0530
Subject: [PATCH] make schema optional for all LLMs

---
 server/gpu/modal/reflector_llm.py             | 44 +++++++++++--------
 server/reflector/llm/base.py                  |  2 +-
 server/reflector/llm/llm_banana.py            | 12 +++--
 server/reflector/llm/llm_modal.py             | 21 +++++++--
 server/reflector/llm/llm_oobagooda.py         | 13 ++++--
 .../processors/transcript_topic_detector.py   | 18 ++++++--
 server/reflector/server.py                    |  2 +-
 server/reflector/settings.py                  |  4 +-
 8 files changed, 79 insertions(+), 37 deletions(-)

diff --git a/server/gpu/modal/reflector_llm.py b/server/gpu/modal/reflector_llm.py
index 315ff785..d83d5036 100644
--- a/server/gpu/modal/reflector_llm.py
+++ b/server/gpu/modal/reflector_llm.py
@@ -5,8 +5,9 @@ Reflector GPU backend - LLM
 """
 
 import os
 
-from modal import Image, method, Stub, asgi_app, Secret
+from modal import asgi_app, Image, method, Secret, Stub
+from pydantic.typing import Optional
 
 # LLM
 LLM_MODEL: str = "lmsys/vicuna-13b-v1.5"
@@ -100,13 +101,6 @@ class LLM:
         self.model = model
         self.tokenizer = tokenizer
         self.gen_cfg = gen_cfg
-        self.json_schema = {
-            "type": "object",
-            "properties": {
-                "title": {"type": "string"},
-                "summary": {"type": "string"},
-            },
-        }
 
     def __exit__(self, *args):
         print("Exit llm")
@@ -117,19 +111,30 @@ class LLM:
         return {"status": "ok"}
 
     @method()
-    def generate(self, prompt: str):
+    def generate(self, prompt: str, schema: str = None):
         print(f"Generate {prompt=}")
-        import jsonformer
-        import json
+        if schema:
+            import ast
+            import jsonformer
 
-        jsonformer_llm = jsonformer.Jsonformer(model=self.model,
-                                               tokenizer=self.tokenizer,
-                                               json_schema=self.json_schema,
-                                               prompt=prompt,
-                                               max_string_token_length=self.gen_cfg.max_new_tokens)
-        response = jsonformer_llm()
+            jsonformer_llm = jsonformer.Jsonformer(model=self.model,
+                                                   tokenizer=self.tokenizer,
+                                                   json_schema=ast.literal_eval(schema),
+                                                   prompt=prompt,
+                                                   max_string_token_length=self.gen_cfg.max_new_tokens)
+            response = jsonformer_llm()
+            print(f"Generated {response=}")
+            return {"text": response}
+
+        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
+            self.model.device
+        )
+        output = self.model.generate(input_ids, generation_config=self.gen_cfg)
+
+        # decode output
+        response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True)
         print(f"Generated {response=}")
-        return {"text": json.dumps(response)}
+        return {"text": response}
 
 
 # -------------------------------------------------------------------
@@ -165,12 +170,13 @@ def web():
 
     class LLMRequest(BaseModel):
         prompt: str
+        schema: Optional[str] = None
 
     @app.post("/llm", dependencies=[Depends(apikey_auth)])
     async def llm(
        req: LLMRequest,
     ):
-        func = llmstub.generate.spawn(prompt=req.prompt)
+        func = llmstub.generate.spawn(prompt=req.prompt, schema=req.schema)
         result = func.get()
         return result
 
diff --git a/server/reflector/llm/base.py b/server/reflector/llm/base.py
index e528a3e6..5e86b553 100644
--- a/server/reflector/llm/base.py
+++ b/server/reflector/llm/base.py
@@ -20,7 +20,7 @@ class LLM:
         Return an instance depending on the settings.
 
         Settings used:
-        - `LLM_BACKEND`: key of the backend, defaults to `oobagooda`
+        - `LLM_BACKEND`: key of the backend, defaults to `oobabooga`
         - `LLM_URL`: url of the backend
         """
         if name is None:
diff --git a/server/reflector/llm/llm_banana.py b/server/reflector/llm/llm_banana.py
index d6a0fa07..bdaf091f 100644
--- a/server/reflector/llm/llm_banana.py
+++ b/server/reflector/llm/llm_banana.py
@@ -1,7 +1,9 @@
+import json
+
+import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
 from reflector.utils.retry import retry
-import httpx
 
 
 class BananaLLM(LLM):
@@ -14,17 +16,21 @@ class BananaLLM(LLM):
         }
 
     async def _generate(self, prompt: str, **kwargs):
+        json_payload = {"prompt": prompt}
+        if "schema" in kwargs:
+            json_payload["schema"] = json.dumps(kwargs["schema"])
         async with httpx.AsyncClient() as client:
             response = await retry(client.post)(
                 settings.LLM_URL,
                 headers=self.headers,
-                json={"prompt": prompt},
+                json=json_payload,
                 timeout=self.timeout,
                 retry_timeout=300,  # as per their sdk
             )
             response.raise_for_status()
             text = response.json()["text"]
-            text = text[len(prompt) :]  # remove prompt
+            if "schema" not in json_payload:
+                text = text[len(prompt) :]
             return text
 
 
diff --git a/server/reflector/llm/llm_modal.py b/server/reflector/llm/llm_modal.py
index 7f23aa0d..692dd095 100644
--- a/server/reflector/llm/llm_modal.py
+++ b/server/reflector/llm/llm_modal.py
@@ -1,7 +1,9 @@
+import json
+
+import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
 from reflector.utils.retry import retry
-import httpx
 
 
 class ModalLLM(LLM):
@@ -24,17 +26,21 @@ class ModalLLM(LLM):
         response.raise_for_status()
 
     async def _generate(self, prompt: str, **kwargs):
+        json_payload = {"prompt": prompt}
+        if "schema" in kwargs:
+            json_payload["schema"] = json.dumps(kwargs["schema"])
         async with httpx.AsyncClient() as client:
             response = await retry(client.post)(
                 self.llm_url,
                 headers=self.headers,
-                json={"prompt": prompt},
+                json=json_payload,
                 timeout=self.timeout,
                 retry_timeout=60 * 5,
             )
             response.raise_for_status()
             text = response.json()["text"]
-            text = text[len(prompt) :]  # remove prompt
+            if "schema" not in json_payload:
+                text = text[len(prompt) :]
             return text
 
 
@@ -48,6 +54,15 @@ if __name__ == "__main__":
         result = await llm.generate("Hello, my name is", logger=logger)
         print(result)
 
+        kwargs = {
+            "schema": {
+                "type": "object",
+                "properties": {"name": {"type": "string"}},
+            }
+        }
+        result = await llm.generate("Hello, my name is", kwargs=kwargs, logger=logger)
+        print(result)
+
     import asyncio
 
     asyncio.run(main())
diff --git a/server/reflector/llm/llm_oobagooda.py b/server/reflector/llm/llm_oobagooda.py
index be7d8133..85306135 100644
--- a/server/reflector/llm/llm_oobagooda.py
+++ b/server/reflector/llm/llm_oobagooda.py
@@ -1,18 +1,23 @@
+import json
+
+import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
-import httpx
 
 
-class OobagoodaLLM(LLM):
+class OobaboogaLLM(LLM):
     async def _generate(self, prompt: str, **kwargs):
+        json_payload = {"prompt": prompt}
+        if "schema" in kwargs:
+            json_payload["schema"] = json.dumps(kwargs["schema"])
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 settings.LLM_URL,
                 headers={"Content-Type": "application/json"},
-                json={"prompt": prompt},
+                json=json_payload,
             )
             response.raise_for_status()
             return response.json()
 
 
-LLM.register("oobagooda", OobagoodaLLM)
+LLM.register("oobabooga", OobaboogaLLM)
diff --git a/server/reflector/processors/transcript_topic_detector.py b/server/reflector/processors/transcript_topic_detector.py
index 6e926771..430e3992 100644
--- a/server/reflector/processors/transcript_topic_detector.py
+++ b/server/reflector/processors/transcript_topic_detector.py
@@ -1,7 +1,7 @@
-from reflector.processors.base import Processor
-from reflector.processors.types import Transcript, TitleSummary
-from reflector.utils.retry import retry
 from reflector.llm import LLM
+from reflector.processors.base import Processor
+from reflector.processors.types import TitleSummary, Transcript
+from reflector.utils.retry import retry
 
 
 class TranscriptTopicDetectorProcessor(Processor):
@@ -31,6 +31,14 @@ class TranscriptTopicDetectorProcessor(Processor):
         self.transcript = None
         self.min_transcript_length = min_transcript_length
         self.llm = LLM.get_instance()
+        self.topic_detector_schema = {
+            "type": "object",
+            "properties": {
+                "title": {"type": "string"},
+                "summary": {"type": "string"},
+            },
+        }
+        self.kwargs = {"schema": self.topic_detector_schema}
 
     async def _warmup(self):
         await self.llm.warmup(logger=self.logger)
@@ -53,7 +61,9 @@ class TranscriptTopicDetectorProcessor(Processor):
         text = self.transcript.text
         self.logger.info(f"Topic detector got {len(text)} length transcript")
         prompt = self.PROMPT.format(input_text=text)
-        result = await retry(self.llm.generate)(prompt=prompt, logger=self.logger)
+        result = await retry(self.llm.generate)(
+            prompt=prompt, kwargs=self.kwargs, logger=self.logger
+        )
         summary = TitleSummary(
             title=result["title"],
             summary=result["summary"],
diff --git a/server/reflector/server.py b/server/reflector/server.py
index 8e28b583..3b09efe4 100644
--- a/server/reflector/server.py
+++ b/server/reflector/server.py
@@ -41,7 +41,7 @@ model = WhisperModel("tiny", device="cpu", compute_type="float32", num_workers=1
 # LLM
 LLM_URL = settings.LLM_URL
 if not LLM_URL:
-    assert settings.LLM_BACKEND == "oobagooda"
+    assert settings.LLM_BACKEND == "oobabooga"
     LLM_URL = f"http://{settings.LLM_HOST}:{settings.LLM_PORT}/api/v1/generate"
 logger.info(f"Using LLM [{settings.LLM_BACKEND}]: {LLM_URL}")
 
diff --git a/server/reflector/settings.py b/server/reflector/settings.py
index e776875b..81f817da 100644
--- a/server/reflector/settings.py
+++ b/server/reflector/settings.py
@@ -52,8 +52,8 @@ class Settings(BaseSettings):
     TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
 
     # LLM
-    # available backend: openai, banana, modal, oobagooda
-    LLM_BACKEND: str = "oobagooda"
+    # available backend: openai, banana, modal, oobabooga
+    LLM_BACKEND: str = "oobabooga"
 
     # LLM common configuration
     LLM_URL: str | None = None
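
Usage sketch (illustrative, not part of the diff above): with the schema now optional, the same generate() call can return either free-form text or jsonformer-constrained JSON. The snippet mirrors the llm_modal.py __main__ block and the transcript_topic_detector.py call site in this patch; the reflector.logger import path is an assumption, and the kwargs-style call assumes the base LLM.generate() forwards it to _generate() the way the patch's own call sites do.

import asyncio

from reflector.llm import LLM
from reflector.logger import logger  # assumed location of the project logger


async def main():
    # Backend (banana, modal, oobabooga, ...) is picked from settings, as in base.py
    llm = LLM.get_instance()

    # No schema: the backend returns plain generated text
    result = await llm.generate("Hello, my name is", logger=logger)
    print(result)

    # With a schema: the GPU backend routes through jsonformer and the reply is
    # constrained to the requested keys, e.g. {"title": ..., "summary": ...} as
    # used by TranscriptTopicDetectorProcessor
    kwargs = {
        "schema": {
            "type": "object",
            "properties": {
                "title": {"type": "string"},
                "summary": {"type": "string"},
            },
        }
    }
    result = await llm.generate("Describe this meeting:", kwargs=kwargs, logger=logger)
    print(result)


asyncio.run(main())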