From 0cdd7037fbf4bb93dcf2d77cf9567d041d2ce71f Mon Sep 17 00:00:00 2001 From: Gokul Mohanarangan Date: Wed, 16 Aug 2023 14:03:25 +0530 Subject: [PATCH 01/14] wrap JSONFormer around LLM --- server/gpu/modal/reflector_llm.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/server/gpu/modal/reflector_llm.py b/server/gpu/modal/reflector_llm.py index bf6f4cf5..315ff785 100644 --- a/server/gpu/modal/reflector_llm.py +++ b/server/gpu/modal/reflector_llm.py @@ -10,7 +10,7 @@ from modal import Image, method, Stub, asgi_app, Secret # LLM LLM_MODEL: str = "lmsys/vicuna-13b-v1.5" -LLM_LOW_CPU_MEM_USAGE: bool = False +LLM_LOW_CPU_MEM_USAGE: bool = True LLM_TORCH_DTYPE: str = "bfloat16" LLM_MAX_NEW_TOKENS: int = 300 @@ -49,6 +49,8 @@ llm_image = ( "torch", "sentencepiece", "protobuf", + "jsonformer==0.12.0", + "accelerate==0.21.0", "einops==0.6.1", "hf-transfer~=0.1", "huggingface_hub==0.16.4", @@ -81,6 +83,7 @@ class LLM: # generation configuration print("Instance llm generation config") + # JSONFormer doesn't yet support generation configs, but keeping for future usage model.config.max_new_tokens = LLM_MAX_NEW_TOKENS gen_cfg = GenerationConfig.from_model_config(model.config) gen_cfg.max_new_tokens = LLM_MAX_NEW_TOKENS @@ -97,6 +100,13 @@ class LLM: self.model = model self.tokenizer = tokenizer self.gen_cfg = gen_cfg + self.json_schema = { + "type": "object", + "properties": { + "title": {"type": "string"}, + "summary": {"type": "string"}, + }, + } def __exit__(self, *args): print("Exit llm") @@ -109,16 +119,17 @@ class LLM: @method() def generate(self, prompt: str): print(f"Generate {prompt=}") - # tokenize prompt - input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to( - self.model.device - ) - output = self.model.generate(input_ids, generation_config=self.gen_cfg) + import jsonformer + import json - # decode output - response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True) + jsonformer_llm = jsonformer.Jsonformer(model=self.model, + tokenizer=self.tokenizer, + json_schema=self.json_schema, + prompt=prompt, + max_string_token_length=self.gen_cfg.max_new_tokens) + response = jsonformer_llm() print(f"Generated {response=}") - return {"text": response} + return {"text": json.dumps(response)} # ------------------------------------------------------------------- From 976c0ab9a8c926e89f6430b0f4fcedb1c181f7d7 Mon Sep 17 00:00:00 2001 From: Gokul Mohanarangan Date: Wed, 16 Aug 2023 14:07:29 +0530 Subject: [PATCH 02/14] update prompt --- server/reflector/processors/transcript_topic_detector.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/server/reflector/processors/transcript_topic_detector.py b/server/reflector/processors/transcript_topic_detector.py index b626e8a2..6e926771 100644 --- a/server/reflector/processors/transcript_topic_detector.py +++ b/server/reflector/processors/transcript_topic_detector.py @@ -14,10 +14,11 @@ class TranscriptTopicDetectorProcessor(Processor): PROMPT = """ ### Human: - Create a JSON object as response.The JSON object must have 2 fields: - i) title and ii) summary.For the title field,generate a short title - for the given text. For the summary field, summarize the given text - in three sentences. + Generate information based on the given schema: + + For the title field, generate a short title for the given text. + For the summary field, summarize the given text in a maximum of + three sentences. 
{input_text} From 5f79e04642196c9a7b970f6ef391692b2c87b07f Mon Sep 17 00:00:00 2001 From: Gokul Mohanarangan Date: Wed, 16 Aug 2023 22:37:20 +0530 Subject: [PATCH 03/14] make schema optional for all LLMs --- server/gpu/modal/reflector_llm.py | 44 +++++++++++-------- server/reflector/llm/base.py | 2 +- server/reflector/llm/llm_banana.py | 12 +++-- server/reflector/llm/llm_modal.py | 21 +++++++-- server/reflector/llm/llm_oobagooda.py | 13 ++++-- .../processors/transcript_topic_detector.py | 18 ++++++-- server/reflector/server.py | 2 +- server/reflector/settings.py | 4 +- 8 files changed, 79 insertions(+), 37 deletions(-) diff --git a/server/gpu/modal/reflector_llm.py b/server/gpu/modal/reflector_llm.py index 315ff785..d83d5036 100644 --- a/server/gpu/modal/reflector_llm.py +++ b/server/gpu/modal/reflector_llm.py @@ -5,8 +5,9 @@ Reflector GPU backend - LLM """ import os -from modal import Image, method, Stub, asgi_app, Secret +from modal import asgi_app, Image, method, Secret, Stub +from pydantic.typing import Optional # LLM LLM_MODEL: str = "lmsys/vicuna-13b-v1.5" @@ -100,13 +101,6 @@ class LLM: self.model = model self.tokenizer = tokenizer self.gen_cfg = gen_cfg - self.json_schema = { - "type": "object", - "properties": { - "title": {"type": "string"}, - "summary": {"type": "string"}, - }, - } def __exit__(self, *args): print("Exit llm") @@ -117,19 +111,30 @@ class LLM: return {"status": "ok"} @method() - def generate(self, prompt: str): + def generate(self, prompt: str, schema: str = None): print(f"Generate {prompt=}") - import jsonformer - import json + if schema: + import ast + import jsonformer - jsonformer_llm = jsonformer.Jsonformer(model=self.model, - tokenizer=self.tokenizer, - json_schema=self.json_schema, - prompt=prompt, - max_string_token_length=self.gen_cfg.max_new_tokens) - response = jsonformer_llm() + jsonformer_llm = jsonformer.Jsonformer(model=self.model, + tokenizer=self.tokenizer, + json_schema=ast.literal_eval(schema), + prompt=prompt, + max_string_token_length=self.gen_cfg.max_new_tokens) + response = jsonformer_llm() + print(f"Generated {response=}") + return {"text": response} + + input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to( + self.model.device + ) + output = self.model.generate(input_ids, generation_config=self.gen_cfg) + + # decode output + response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True) print(f"Generated {response=}") - return {"text": json.dumps(response)} + return {"text": response} # ------------------------------------------------------------------- @@ -165,12 +170,13 @@ def web(): class LLMRequest(BaseModel): prompt: str + schema: Optional[str] = None @app.post("/llm", dependencies=[Depends(apikey_auth)]) async def llm( req: LLMRequest, ): - func = llmstub.generate.spawn(prompt=req.prompt) + func = llmstub.generate.spawn(prompt=req.prompt, schema=req.schema) result = func.get() return result diff --git a/server/reflector/llm/base.py b/server/reflector/llm/base.py index e528a3e6..5e86b553 100644 --- a/server/reflector/llm/base.py +++ b/server/reflector/llm/base.py @@ -20,7 +20,7 @@ class LLM: Return an instance depending on the settings. 
Settings used: - - `LLM_BACKEND`: key of the backend, defaults to `oobagooda` + - `LLM_BACKEND`: key of the backend, defaults to `oobabooga` - `LLM_URL`: url of the backend """ if name is None: diff --git a/server/reflector/llm/llm_banana.py b/server/reflector/llm/llm_banana.py index d6a0fa07..bdaf091f 100644 --- a/server/reflector/llm/llm_banana.py +++ b/server/reflector/llm/llm_banana.py @@ -1,7 +1,9 @@ +import json + +import httpx from reflector.llm.base import LLM from reflector.settings import settings from reflector.utils.retry import retry -import httpx class BananaLLM(LLM): @@ -14,17 +16,21 @@ class BananaLLM(LLM): } async def _generate(self, prompt: str, **kwargs): + json_payload = {"prompt": prompt} + if "schema" in kwargs: + json_payload["schema"] = json.dumps(kwargs["schema"]) async with httpx.AsyncClient() as client: response = await retry(client.post)( settings.LLM_URL, headers=self.headers, - json={"prompt": prompt}, + json=json_payload, timeout=self.timeout, retry_timeout=300, # as per their sdk ) response.raise_for_status() text = response.json()["text"] - text = text[len(prompt) :] # remove prompt + if "schema" not in json_payload: + text = text[len(prompt) :] return text diff --git a/server/reflector/llm/llm_modal.py b/server/reflector/llm/llm_modal.py index 7f23aa0d..692dd095 100644 --- a/server/reflector/llm/llm_modal.py +++ b/server/reflector/llm/llm_modal.py @@ -1,7 +1,9 @@ +import json + +import httpx from reflector.llm.base import LLM from reflector.settings import settings from reflector.utils.retry import retry -import httpx class ModalLLM(LLM): @@ -24,17 +26,21 @@ class ModalLLM(LLM): response.raise_for_status() async def _generate(self, prompt: str, **kwargs): + json_payload = {"prompt": prompt} + if "schema" in kwargs: + json_payload["schema"] = json.dumps(kwargs["schema"]) async with httpx.AsyncClient() as client: response = await retry(client.post)( self.llm_url, headers=self.headers, - json={"prompt": prompt}, + json=json_payload, timeout=self.timeout, retry_timeout=60 * 5, ) response.raise_for_status() text = response.json()["text"] - text = text[len(prompt) :] # remove prompt + if "schema" not in json_payload: + text = text[len(prompt) :] return text @@ -48,6 +54,15 @@ if __name__ == "__main__": result = await llm.generate("Hello, my name is", logger=logger) print(result) + kwargs = { + "schema": { + "type": "object", + "properties": {"name": {"type": "string"}}, + } + } + result = await llm.generate("Hello, my name is", kwargs=kwargs, logger=logger) + print(result) + import asyncio asyncio.run(main()) diff --git a/server/reflector/llm/llm_oobagooda.py b/server/reflector/llm/llm_oobagooda.py index be7d8133..85306135 100644 --- a/server/reflector/llm/llm_oobagooda.py +++ b/server/reflector/llm/llm_oobagooda.py @@ -1,18 +1,23 @@ +import json + +import httpx from reflector.llm.base import LLM from reflector.settings import settings -import httpx -class OobagoodaLLM(LLM): +class OobaboogaLLM(LLM): async def _generate(self, prompt: str, **kwargs): + json_payload = {"prompt": prompt} + if "schema" in kwargs: + json_payload["schema"] = json.dumps(kwargs["schema"]) async with httpx.AsyncClient() as client: response = await client.post( settings.LLM_URL, headers={"Content-Type": "application/json"}, - json={"prompt": prompt}, + json=json_payload, ) response.raise_for_status() return response.json() -LLM.register("oobagooda", OobagoodaLLM) +LLM.register("oobabooga", OobaboogaLLM) diff --git a/server/reflector/processors/transcript_topic_detector.py 
b/server/reflector/processors/transcript_topic_detector.py index 6e926771..430e3992 100644 --- a/server/reflector/processors/transcript_topic_detector.py +++ b/server/reflector/processors/transcript_topic_detector.py @@ -1,7 +1,7 @@ -from reflector.processors.base import Processor -from reflector.processors.types import Transcript, TitleSummary -from reflector.utils.retry import retry from reflector.llm import LLM +from reflector.processors.base import Processor +from reflector.processors.types import TitleSummary, Transcript +from reflector.utils.retry import retry class TranscriptTopicDetectorProcessor(Processor): @@ -31,6 +31,14 @@ class TranscriptTopicDetectorProcessor(Processor): self.transcript = None self.min_transcript_length = min_transcript_length self.llm = LLM.get_instance() + self.topic_detector_schema = { + "type": "object", + "properties": { + "title": {"type": "string"}, + "summary": {"type": "string"}, + }, + } + self.kwargs = {"schema": self.topic_detector_schema} async def _warmup(self): await self.llm.warmup(logger=self.logger) @@ -53,7 +61,9 @@ class TranscriptTopicDetectorProcessor(Processor): text = self.transcript.text self.logger.info(f"Topic detector got {len(text)} length transcript") prompt = self.PROMPT.format(input_text=text) - result = await retry(self.llm.generate)(prompt=prompt, logger=self.logger) + result = await retry(self.llm.generate)( + prompt=prompt, kwargs=self.kwargs, logger=self.logger + ) summary = TitleSummary( title=result["title"], summary=result["summary"], diff --git a/server/reflector/server.py b/server/reflector/server.py index 8e28b583..3b09efe4 100644 --- a/server/reflector/server.py +++ b/server/reflector/server.py @@ -41,7 +41,7 @@ model = WhisperModel("tiny", device="cpu", compute_type="float32", num_workers=1 # LLM LLM_URL = settings.LLM_URL if not LLM_URL: - assert settings.LLM_BACKEND == "oobagooda" + assert settings.LLM_BACKEND == "oobabooga" LLM_URL = f"http://{settings.LLM_HOST}:{settings.LLM_PORT}/api/v1/generate" logger.info(f"Using LLM [{settings.LLM_BACKEND}]: {LLM_URL}") diff --git a/server/reflector/settings.py b/server/reflector/settings.py index e776875b..81f817da 100644 --- a/server/reflector/settings.py +++ b/server/reflector/settings.py @@ -52,8 +52,8 @@ class Settings(BaseSettings): TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None # LLM - # available backend: openai, banana, modal, oobagooda - LLM_BACKEND: str = "oobagooda" + # available backend: openai, banana, modal, oobabooga + LLM_BACKEND: str = "oobabooga" # LLM common configuration LLM_URL: str | None = None From eb13a7bd64e83e5cadbd2151a34af3d733aaebd0 Mon Sep 17 00:00:00 2001 From: Gokul Mohanarangan Date: Thu, 17 Aug 2023 09:23:14 +0530 Subject: [PATCH 04/14] make schema optional argument --- server/gpu/modal/reflector_llm.py | 2 +- server/reflector/llm/base.py | 18 ++++++++++------- server/reflector/llm/llm_banana.py | 9 +++++---- server/reflector/llm/llm_modal.py | 20 +++++++++---------- server/reflector/llm/llm_oobagooda.py | 7 ++++--- server/reflector/llm/llm_openai.py | 6 ++++-- .../processors/transcript_topic_detector.py | 3 +-- server/tests/test_processors_pipeline.py | 5 ++++- server/tests/test_transcripts_rtc_ws.py | 15 +++++++------- 9 files changed, 48 insertions(+), 37 deletions(-) diff --git a/server/gpu/modal/reflector_llm.py b/server/gpu/modal/reflector_llm.py index d83d5036..2f96e330 100644 --- a/server/gpu/modal/reflector_llm.py +++ b/server/gpu/modal/reflector_llm.py @@ -5,9 +5,9 @@ Reflector GPU backend - LLM """ import os +from 
typing import Optional from modal import asgi_app, Image, method, Secret, Stub -from pydantic.typing import Optional # LLM LLM_MODEL: str = "lmsys/vicuna-13b-v1.5" diff --git a/server/reflector/llm/base.py b/server/reflector/llm/base.py index 5e86b553..fddf185d 100644 --- a/server/reflector/llm/base.py +++ b/server/reflector/llm/base.py @@ -1,10 +1,12 @@ -from reflector.settings import settings -from reflector.utils.retry import retry -from reflector.logger import logger as reflector_logger -from time import monotonic import importlib import json import re +from time import monotonic +from typing import Union + +from reflector.logger import logger as reflector_logger +from reflector.settings import settings +from reflector.utils.retry import retry class LLM: @@ -44,10 +46,12 @@ class LLM: async def _warmup(self, logger: reflector_logger): pass - async def generate(self, prompt: str, logger: reflector_logger, **kwargs) -> dict: + async def generate( + self, prompt: str, logger: reflector_logger, schema: str = None, **kwargs + ) -> dict: logger.info("LLM generate", prompt=repr(prompt)) try: - result = await retry(self._generate)(prompt=prompt, **kwargs) + result = await retry(self._generate)(prompt=prompt, schema=schema, **kwargs) except Exception: logger.exception("Failed to call llm after retrying") raise @@ -59,7 +63,7 @@ class LLM: return result - async def _generate(self, prompt: str, **kwargs) -> str: + async def _generate(self, prompt: str, schema: Union[str | None], **kwargs) -> str: raise NotImplementedError def _parse_json(self, result: str) -> dict: diff --git a/server/reflector/llm/llm_banana.py b/server/reflector/llm/llm_banana.py index bdaf091f..473769cc 100644 --- a/server/reflector/llm/llm_banana.py +++ b/server/reflector/llm/llm_banana.py @@ -1,4 +1,5 @@ import json +from typing import Union import httpx from reflector.llm.base import LLM @@ -15,10 +16,10 @@ class BananaLLM(LLM): "X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY, } - async def _generate(self, prompt: str, **kwargs): + async def _generate(self, prompt: str, schema: Union[str | None], **kwargs): json_payload = {"prompt": prompt} - if "schema" in kwargs: - json_payload["schema"] = json.dumps(kwargs["schema"]) + if schema: + json_payload["schema"] = json.dumps(schema) async with httpx.AsyncClient() as client: response = await retry(client.post)( settings.LLM_URL, @@ -29,7 +30,7 @@ class BananaLLM(LLM): ) response.raise_for_status() text = response.json()["text"] - if "schema" not in json_payload: + if not schema: text = text[len(prompt) :] return text diff --git a/server/reflector/llm/llm_modal.py b/server/reflector/llm/llm_modal.py index 692dd095..c1fb856b 100644 --- a/server/reflector/llm/llm_modal.py +++ b/server/reflector/llm/llm_modal.py @@ -1,4 +1,5 @@ import json +from typing import Union import httpx from reflector.llm.base import LLM @@ -25,10 +26,10 @@ class ModalLLM(LLM): ) response.raise_for_status() - async def _generate(self, prompt: str, **kwargs): + async def _generate(self, prompt: str, schema: Union[str | None], **kwargs): json_payload = {"prompt": prompt} - if "schema" in kwargs: - json_payload["schema"] = json.dumps(kwargs["schema"]) + if schema: + json_payload["schema"] = json.dumps(schema) async with httpx.AsyncClient() as client: response = await retry(client.post)( self.llm_url, @@ -39,7 +40,7 @@ class ModalLLM(LLM): ) response.raise_for_status() text = response.json()["text"] - if "schema" not in json_payload: + if not schema: text = text[len(prompt) :] return text @@ -54,13 +55,12 @@ 
if __name__ == "__main__": result = await llm.generate("Hello, my name is", logger=logger) print(result) - kwargs = { - "schema": { - "type": "object", - "properties": {"name": {"type": "string"}}, - } + schema = { + "type": "object", + "properties": {"name": {"type": "string"}}, } - result = await llm.generate("Hello, my name is", kwargs=kwargs, logger=logger) + + result = await llm.generate("Hello, my name is", schema=schema, logger=logger) print(result) import asyncio diff --git a/server/reflector/llm/llm_oobagooda.py b/server/reflector/llm/llm_oobagooda.py index 85306135..0ceb442d 100644 --- a/server/reflector/llm/llm_oobagooda.py +++ b/server/reflector/llm/llm_oobagooda.py @@ -1,4 +1,5 @@ import json +from typing import Union import httpx from reflector.llm.base import LLM @@ -6,10 +7,10 @@ from reflector.settings import settings class OobaboogaLLM(LLM): - async def _generate(self, prompt: str, **kwargs): + async def _generate(self, prompt: str, schema: Union[str | None], **kwargs): json_payload = {"prompt": prompt} - if "schema" in kwargs: - json_payload["schema"] = json.dumps(kwargs["schema"]) + if schema: + json_payload["schema"] = json.dumps(schema) async with httpx.AsyncClient() as client: response = await client.post( settings.LLM_URL, diff --git a/server/reflector/llm/llm_openai.py b/server/reflector/llm/llm_openai.py index dd438704..9a74e03c 100644 --- a/server/reflector/llm/llm_openai.py +++ b/server/reflector/llm/llm_openai.py @@ -1,7 +1,9 @@ +from typing import Union + +import httpx from reflector.llm.base import LLM from reflector.logger import logger from reflector.settings import settings -import httpx class OpenAILLM(LLM): @@ -15,7 +17,7 @@ class OpenAILLM(LLM): self.max_tokens = settings.LLM_MAX_TOKENS logger.info(f"LLM use openai backend at {self.openai_url}") - async def _generate(self, prompt: str, **kwargs) -> str: + async def _generate(self, prompt: str, schema: Union[str | None], **kwargs) -> str: headers = { "Content-Type": "application/json", "Authorization": f"Bearer {self.openai_key}", diff --git a/server/reflector/processors/transcript_topic_detector.py b/server/reflector/processors/transcript_topic_detector.py index 430e3992..9ae21a72 100644 --- a/server/reflector/processors/transcript_topic_detector.py +++ b/server/reflector/processors/transcript_topic_detector.py @@ -14,7 +14,6 @@ class TranscriptTopicDetectorProcessor(Processor): PROMPT = """ ### Human: - Generate information based on the given schema: For the title field, generate a short title for the given text. 
For the summary field, summarize the given text in a maximum of @@ -62,7 +61,7 @@ class TranscriptTopicDetectorProcessor(Processor): self.logger.info(f"Topic detector got {len(text)} length transcript") prompt = self.PROMPT.format(input_text=text) result = await retry(self.llm.generate)( - prompt=prompt, kwargs=self.kwargs, logger=self.logger + prompt=prompt, schema=self.topic_detector_schema, logger=self.logger ) summary = TitleSummary( title=result["title"], diff --git a/server/tests/test_processors_pipeline.py b/server/tests/test_processors_pipeline.py index 95c296de..56cac96e 100644 --- a/server/tests/test_processors_pipeline.py +++ b/server/tests/test_processors_pipeline.py @@ -9,13 +9,16 @@ async def test_basic_process(event_loop): from reflector.settings import settings from reflector.llm.base import LLM from pathlib import Path + from typing import Union # use an LLM test backend settings.LLM_BACKEND = "test" settings.TRANSCRIPT_BACKEND = "whisper" class LLMTest(LLM): - async def _generate(self, prompt: str, **kwargs) -> str: + async def _generate( + self, prompt: str, schema: Union[str | None], **kwargs + ) -> str: return { "title": "TITLE", "summary": "SUMMARY", diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py index 70ee209b..23c7813f 100644 --- a/server/tests/test_transcripts_rtc_ws.py +++ b/server/tests/test_transcripts_rtc_ws.py @@ -3,17 +3,18 @@ # FIXME test websocket connection after RTC is finished still send the full events # FIXME try with locked session, RTC should not work -import pytest +import asyncio import json +import threading +from pathlib import Path +from typing import Union from unittest.mock import patch -from httpx import AsyncClient +import pytest +from httpx import AsyncClient +from httpx_ws import aconnect_ws from reflector.app import app from uvicorn import Config, Server -import threading -import asyncio -from pathlib import Path -from httpx_ws import aconnect_ws class ThreadedUvicorn: @@ -61,7 +62,7 @@ async def dummy_llm(): from reflector.llm.base import LLM class TestLLM(LLM): - async def _generate(self, prompt: str, **kwargs): + async def _generate(self, prompt: str, schema: Union[str | None], **kwargs): return json.dumps({"title": "LLM TITLE", "summary": "LLM SUMMARY"}) with patch("reflector.llm.base.LLM.get_instance") as mock_llm: From 2e48f89fdc0a8e9df82fe36864e3fe53935695e5 Mon Sep 17 00:00:00 2001 From: Gokul Mohanarangan Date: Thu, 17 Aug 2023 09:33:59 +0530 Subject: [PATCH 05/14] add comments and log --- server/gpu/modal/reflector_llm.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/server/gpu/modal/reflector_llm.py b/server/gpu/modal/reflector_llm.py index 2f96e330..21306763 100644 --- a/server/gpu/modal/reflector_llm.py +++ b/server/gpu/modal/reflector_llm.py @@ -113,7 +113,9 @@ class LLM: @method() def generate(self, prompt: str, schema: str = None): print(f"Generate {prompt=}") + # If a schema is given, conform to schema if schema: + print(f"Schema {schema=}") import ast import jsonformer @@ -123,16 +125,17 @@ class LLM: prompt=prompt, max_string_token_length=self.gen_cfg.max_new_tokens) response = jsonformer_llm() - print(f"Generated {response=}") - return {"text": response} + else: + # If no schema, perform prompt only generation - input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to( + # tokenize prompt + input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to( self.model.device - ) - output = self.model.generate(input_ids, 
generation_config=self.gen_cfg) + ) + output = self.model.generate(input_ids, generation_config=self.gen_cfg) - # decode output - response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True) + # decode output + response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True) print(f"Generated {response=}") return {"text": response} From a24c3afe5bc94ed4f25b48254478006f11fbd7c8 Mon Sep 17 00:00:00 2001 From: Gokul Mohanarangan Date: Thu, 17 Aug 2023 09:35:49 +0530 Subject: [PATCH 06/14] cleanup --- server/reflector/processors/transcript_topic_detector.py | 1 - 1 file changed, 1 deletion(-) diff --git a/server/reflector/processors/transcript_topic_detector.py b/server/reflector/processors/transcript_topic_detector.py index 9ae21a72..3e984741 100644 --- a/server/reflector/processors/transcript_topic_detector.py +++ b/server/reflector/processors/transcript_topic_detector.py @@ -37,7 +37,6 @@ class TranscriptTopicDetectorProcessor(Processor): "summary": {"type": "string"}, }, } - self.kwargs = {"schema": self.topic_detector_schema} async def _warmup(self): await self.llm.warmup(logger=self.logger) From 235ee73f462e6e1a2def5d21168b5b41bfe4a029 Mon Sep 17 00:00:00 2001 From: Gokul Mohanarangan Date: Thu, 17 Aug 2023 09:59:16 +0530 Subject: [PATCH 07/14] update prompt --- server/reflector/processors/transcript_topic_detector.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/server/reflector/processors/transcript_topic_detector.py b/server/reflector/processors/transcript_topic_detector.py index 3e984741..3d8e3965 100644 --- a/server/reflector/processors/transcript_topic_detector.py +++ b/server/reflector/processors/transcript_topic_detector.py @@ -14,6 +14,8 @@ class TranscriptTopicDetectorProcessor(Processor): PROMPT = """ ### Human: + Create a JSON object as response.The JSON object must have 2 fields: + i) title and ii) summary. For the title field, generate a short title for the given text. 
For the summary field, summarize the given text in a maximum of From a98a9853be364890c87ff8320b56f27b788b1e95 Mon Sep 17 00:00:00 2001 From: Gokul Mohanarangan Date: Thu, 17 Aug 2023 14:42:45 +0530 Subject: [PATCH 08/14] PR review comments --- .pre-commit-config.yaml | 7 +++++++ server/env.example | 4 ++-- server/reflector/llm/base.py | 5 ++--- server/reflector/llm/llm_banana.py | 3 +-- server/reflector/llm/llm_modal.py | 3 +-- .../reflector/llm/{llm_oobagooda.py => llm_oobabooga.py} | 3 +-- server/reflector/llm/llm_openai.py | 4 +--- server/tests/test_processors_pipeline.py | 5 +---- server/tests/test_transcripts_rtc_ws.py | 3 +-- 9 files changed, 17 insertions(+), 20 deletions(-) rename server/reflector/llm/{llm_oobagooda.py => llm_oobabooga.py} (84%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7132e09c..2a73b7c2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,3 +29,10 @@ repos: hooks: - id: black files: ^server/(reflector|tests)/ + + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + name: isort (python) + files: ^server/(gpu|evaluate|reflector)/ diff --git a/server/env.example b/server/env.example index 11e0927b..2317a1bb 100644 --- a/server/env.example +++ b/server/env.example @@ -45,8 +45,8 @@ ## llm backend implementation ## ======================================================= -## Use oobagooda (default) -#LLM_BACKEND=oobagooda +## Use oobabooga (default) +#LLM_BACKEND=oobabooga #LLM_URL=http://xxx:7860/api/generate/v1 ## Using serverless modal.com (require reflector-gpu-modal deployed) diff --git a/server/reflector/llm/base.py b/server/reflector/llm/base.py index fddf185d..2d83913e 100644 --- a/server/reflector/llm/base.py +++ b/server/reflector/llm/base.py @@ -2,7 +2,6 @@ import importlib import json import re from time import monotonic -from typing import Union from reflector.logger import logger as reflector_logger from reflector.settings import settings @@ -47,7 +46,7 @@ class LLM: pass async def generate( - self, prompt: str, logger: reflector_logger, schema: str = None, **kwargs + self, prompt: str, logger: reflector_logger, schema: str | None = None, **kwargs ) -> dict: logger.info("LLM generate", prompt=repr(prompt)) try: @@ -63,7 +62,7 @@ class LLM: return result - async def _generate(self, prompt: str, schema: Union[str | None], **kwargs) -> str: + async def _generate(self, prompt: str, schema: str | None, **kwargs) -> str: raise NotImplementedError def _parse_json(self, result: str) -> dict: diff --git a/server/reflector/llm/llm_banana.py b/server/reflector/llm/llm_banana.py index 473769cc..e19171bf 100644 --- a/server/reflector/llm/llm_banana.py +++ b/server/reflector/llm/llm_banana.py @@ -1,5 +1,4 @@ import json -from typing import Union import httpx from reflector.llm.base import LLM @@ -16,7 +15,7 @@ class BananaLLM(LLM): "X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY, } - async def _generate(self, prompt: str, schema: Union[str | None], **kwargs): + async def _generate(self, prompt: str, schema: str | None, **kwargs): json_payload = {"prompt": prompt} if schema: json_payload["schema"] = json.dumps(schema) diff --git a/server/reflector/llm/llm_modal.py b/server/reflector/llm/llm_modal.py index c1fb856b..7cf7778b 100644 --- a/server/reflector/llm/llm_modal.py +++ b/server/reflector/llm/llm_modal.py @@ -1,5 +1,4 @@ import json -from typing import Union import httpx from reflector.llm.base import LLM @@ -26,7 +25,7 @@ class ModalLLM(LLM): ) response.raise_for_status() - async def 
_generate(self, prompt: str, schema: Union[str | None], **kwargs): + async def _generate(self, prompt: str, schema: str | None, **kwargs): json_payload = {"prompt": prompt} if schema: json_payload["schema"] = json.dumps(schema) diff --git a/server/reflector/llm/llm_oobagooda.py b/server/reflector/llm/llm_oobabooga.py similarity index 84% rename from server/reflector/llm/llm_oobagooda.py rename to server/reflector/llm/llm_oobabooga.py index 0ceb442d..394f0af4 100644 --- a/server/reflector/llm/llm_oobagooda.py +++ b/server/reflector/llm/llm_oobabooga.py @@ -1,5 +1,4 @@ import json -from typing import Union import httpx from reflector.llm.base import LLM @@ -7,7 +6,7 @@ from reflector.settings import settings class OobaboogaLLM(LLM): - async def _generate(self, prompt: str, schema: Union[str | None], **kwargs): + async def _generate(self, prompt: str, schema: str | None, **kwargs): json_payload = {"prompt": prompt} if schema: json_payload["schema"] = json.dumps(schema) diff --git a/server/reflector/llm/llm_openai.py b/server/reflector/llm/llm_openai.py index 9a74e03c..62cccf7e 100644 --- a/server/reflector/llm/llm_openai.py +++ b/server/reflector/llm/llm_openai.py @@ -1,5 +1,3 @@ -from typing import Union - import httpx from reflector.llm.base import LLM from reflector.logger import logger @@ -17,7 +15,7 @@ class OpenAILLM(LLM): self.max_tokens = settings.LLM_MAX_TOKENS logger.info(f"LLM use openai backend at {self.openai_url}") - async def _generate(self, prompt: str, schema: Union[str | None], **kwargs) -> str: + async def _generate(self, prompt: str, schema: str | None, **kwargs) -> str: headers = { "Content-Type": "application/json", "Authorization": f"Bearer {self.openai_key}", diff --git a/server/tests/test_processors_pipeline.py b/server/tests/test_processors_pipeline.py index 56cac96e..db0a39c5 100644 --- a/server/tests/test_processors_pipeline.py +++ b/server/tests/test_processors_pipeline.py @@ -9,16 +9,13 @@ async def test_basic_process(event_loop): from reflector.settings import settings from reflector.llm.base import LLM from pathlib import Path - from typing import Union # use an LLM test backend settings.LLM_BACKEND = "test" settings.TRANSCRIPT_BACKEND = "whisper" class LLMTest(LLM): - async def _generate( - self, prompt: str, schema: Union[str | None], **kwargs - ) -> str: + async def _generate(self, prompt: str, schema: str | None, **kwargs) -> str: return { "title": "TITLE", "summary": "SUMMARY", diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py index 5555d195..943955f8 100644 --- a/server/tests/test_transcripts_rtc_ws.py +++ b/server/tests/test_transcripts_rtc_ws.py @@ -7,7 +7,6 @@ import asyncio import json import threading from pathlib import Path -from typing import Union from unittest.mock import patch import pytest @@ -62,7 +61,7 @@ async def dummy_llm(): from reflector.llm.base import LLM class TestLLM(LLM): - async def _generate(self, prompt: str, schema: Union[str | None], **kwargs): + async def _generate(self, prompt: str, schema: str | None, **kwargs): return json.dumps({"title": "LLM TITLE", "summary": "LLM SUMMARY"}) with patch("reflector.llm.base.LLM.get_instance") as mock_llm: From 9103c8cca83de1459447aec5b3971b963264dce8 Mon Sep 17 00:00:00 2001 From: Gokul Mohanarangan Date: Thu, 17 Aug 2023 15:15:43 +0530 Subject: [PATCH 09/14] remove ast --- server/gpu/modal/reflector_llm.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/server/gpu/modal/reflector_llm.py 
b/server/gpu/modal/reflector_llm.py index 21306763..10cf4772 100644 --- a/server/gpu/modal/reflector_llm.py +++ b/server/gpu/modal/reflector_llm.py @@ -3,11 +3,11 @@ Reflector GPU backend - LLM =========================== """ - +import json import os from typing import Optional -from modal import asgi_app, Image, method, Secret, Stub +from modal import Image, Secret, Stub, asgi_app, method # LLM LLM_MODEL: str = "lmsys/vicuna-13b-v1.5" @@ -116,12 +116,11 @@ class LLM: # If a schema is given, conform to schema if schema: print(f"Schema {schema=}") - import ast import jsonformer jsonformer_llm = jsonformer.Jsonformer(model=self.model, tokenizer=self.tokenizer, - json_schema=ast.literal_eval(schema), + json_schema=json.loads(schema), prompt=prompt, max_string_token_length=self.gen_cfg.max_new_tokens) response = jsonformer_llm() @@ -154,7 +153,7 @@ class LLM: ) @asgi_app() def web(): - from fastapi import FastAPI, HTTPException, status, Depends + from fastapi import Depends, FastAPI, HTTPException, status from fastapi.security import OAuth2PasswordBearer from pydantic import BaseModel From b08724a191df2964db437e6d1b1a4217e0808064 Mon Sep 17 00:00:00 2001 From: Gokul Mohanarangan Date: Thu, 17 Aug 2023 20:57:31 +0530 Subject: [PATCH 10/14] correct schema typing from str to dict --- server/reflector/llm/base.py | 8 ++++++-- server/reflector/llm/llm_banana.py | 2 +- server/reflector/llm/llm_modal.py | 2 +- server/reflector/llm/llm_oobabooga.py | 2 +- server/reflector/llm/llm_openai.py | 2 +- server/tests/test_processors_pipeline.py | 2 +- server/tests/test_transcripts_rtc_ws.py | 2 +- 7 files changed, 12 insertions(+), 8 deletions(-) diff --git a/server/reflector/llm/base.py b/server/reflector/llm/base.py index 2d83913e..d046ffe7 100644 --- a/server/reflector/llm/base.py +++ b/server/reflector/llm/base.py @@ -46,7 +46,11 @@ class LLM: pass async def generate( - self, prompt: str, logger: reflector_logger, schema: str | None = None, **kwargs + self, + prompt: str, + logger: reflector_logger, + schema: dict | None = None, + **kwargs, ) -> dict: logger.info("LLM generate", prompt=repr(prompt)) try: @@ -62,7 +66,7 @@ class LLM: return result - async def _generate(self, prompt: str, schema: str | None, **kwargs) -> str: + async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str: raise NotImplementedError def _parse_json(self, result: str) -> dict: diff --git a/server/reflector/llm/llm_banana.py b/server/reflector/llm/llm_banana.py index e19171bf..07119613 100644 --- a/server/reflector/llm/llm_banana.py +++ b/server/reflector/llm/llm_banana.py @@ -15,7 +15,7 @@ class BananaLLM(LLM): "X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY, } - async def _generate(self, prompt: str, schema: str | None, **kwargs): + async def _generate(self, prompt: str, schema: dict | None, **kwargs): json_payload = {"prompt": prompt} if schema: json_payload["schema"] = json.dumps(schema) diff --git a/server/reflector/llm/llm_modal.py b/server/reflector/llm/llm_modal.py index 7cf7778b..ea9ff152 100644 --- a/server/reflector/llm/llm_modal.py +++ b/server/reflector/llm/llm_modal.py @@ -25,7 +25,7 @@ class ModalLLM(LLM): ) response.raise_for_status() - async def _generate(self, prompt: str, schema: str | None, **kwargs): + async def _generate(self, prompt: str, schema: dict | None, **kwargs): json_payload = {"prompt": prompt} if schema: json_payload["schema"] = json.dumps(schema) diff --git a/server/reflector/llm/llm_oobabooga.py b/server/reflector/llm/llm_oobabooga.py index 394f0af4..6c5a68ec 100644 --- 
a/server/reflector/llm/llm_oobabooga.py +++ b/server/reflector/llm/llm_oobabooga.py @@ -6,7 +6,7 @@ from reflector.settings import settings class OobaboogaLLM(LLM): - async def _generate(self, prompt: str, schema: str | None, **kwargs): + async def _generate(self, prompt: str, schema: dict | None, **kwargs): json_payload = {"prompt": prompt} if schema: json_payload["schema"] = json.dumps(schema) diff --git a/server/reflector/llm/llm_openai.py b/server/reflector/llm/llm_openai.py index 62cccf7e..7ed532b7 100644 --- a/server/reflector/llm/llm_openai.py +++ b/server/reflector/llm/llm_openai.py @@ -15,7 +15,7 @@ class OpenAILLM(LLM): self.max_tokens = settings.LLM_MAX_TOKENS logger.info(f"LLM use openai backend at {self.openai_url}") - async def _generate(self, prompt: str, schema: str | None, **kwargs) -> str: + async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str: headers = { "Content-Type": "application/json", "Authorization": f"Bearer {self.openai_key}", diff --git a/server/tests/test_processors_pipeline.py b/server/tests/test_processors_pipeline.py index db0a39c5..cc6a8574 100644 --- a/server/tests/test_processors_pipeline.py +++ b/server/tests/test_processors_pipeline.py @@ -15,7 +15,7 @@ async def test_basic_process(event_loop): settings.TRANSCRIPT_BACKEND = "whisper" class LLMTest(LLM): - async def _generate(self, prompt: str, schema: str | None, **kwargs) -> str: + async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str: return { "title": "TITLE", "summary": "SUMMARY", diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py index 943955f8..09f3d7e5 100644 --- a/server/tests/test_transcripts_rtc_ws.py +++ b/server/tests/test_transcripts_rtc_ws.py @@ -61,7 +61,7 @@ async def dummy_llm(): from reflector.llm.base import LLM class TestLLM(LLM): - async def _generate(self, prompt: str, schema: str | None, **kwargs): + async def _generate(self, prompt: str, schema: dict | None, **kwargs): return json.dumps({"title": "LLM TITLE", "summary": "LLM SUMMARY"}) with patch("reflector.llm.base.LLM.get_instance") as mock_llm: From 2d686da15c98e83731602dbc878e67633baf0316 Mon Sep 17 00:00:00 2001 From: Gokul Mohanarangan Date: Thu, 17 Aug 2023 21:26:20 +0530 Subject: [PATCH 11/14] pass schema as dict --- server/gpu/modal/reflector_llm.py | 7 +++++-- server/reflector/llm/llm_banana.py | 4 +--- server/reflector/llm/llm_modal.py | 4 +--- server/reflector/llm/llm_oobabooga.py | 4 +--- 4 files changed, 8 insertions(+), 11 deletions(-) diff --git a/server/gpu/modal/reflector_llm.py b/server/gpu/modal/reflector_llm.py index 10cf4772..fd8a4aae 100644 --- a/server/gpu/modal/reflector_llm.py +++ b/server/gpu/modal/reflector_llm.py @@ -172,13 +172,16 @@ def web(): class LLMRequest(BaseModel): prompt: str - schema: Optional[str] = None + schema: Optional[dict] = None @app.post("/llm", dependencies=[Depends(apikey_auth)]) async def llm( req: LLMRequest, ): - func = llmstub.generate.spawn(prompt=req.prompt, schema=req.schema) + if req.schema: + func = llmstub.generate.spawn(prompt=req.prompt, schema=json.dumps(req.schema)) + else: + func = llmstub.generate.spawn(prompt=req.prompt) result = func.get() return result diff --git a/server/reflector/llm/llm_banana.py b/server/reflector/llm/llm_banana.py index 07119613..56fc0e69 100644 --- a/server/reflector/llm/llm_banana.py +++ b/server/reflector/llm/llm_banana.py @@ -1,5 +1,3 @@ -import json - import httpx from reflector.llm.base import LLM from reflector.settings import settings @@ 
-18,7 +16,7 @@ class BananaLLM(LLM): async def _generate(self, prompt: str, schema: dict | None, **kwargs): json_payload = {"prompt": prompt} if schema: - json_payload["schema"] = json.dumps(schema) + json_payload["schema"] = schema async with httpx.AsyncClient() as client: response = await retry(client.post)( settings.LLM_URL, diff --git a/server/reflector/llm/llm_modal.py b/server/reflector/llm/llm_modal.py index ea9ff152..ce0de02a 100644 --- a/server/reflector/llm/llm_modal.py +++ b/server/reflector/llm/llm_modal.py @@ -1,5 +1,3 @@ -import json - import httpx from reflector.llm.base import LLM from reflector.settings import settings @@ -28,7 +26,7 @@ class ModalLLM(LLM): async def _generate(self, prompt: str, schema: dict | None, **kwargs): json_payload = {"prompt": prompt} if schema: - json_payload["schema"] = json.dumps(schema) + json_payload["schema"] = schema async with httpx.AsyncClient() as client: response = await retry(client.post)( self.llm_url, diff --git a/server/reflector/llm/llm_oobabooga.py b/server/reflector/llm/llm_oobabooga.py index 6c5a68ec..411014c5 100644 --- a/server/reflector/llm/llm_oobabooga.py +++ b/server/reflector/llm/llm_oobabooga.py @@ -1,5 +1,3 @@ -import json - import httpx from reflector.llm.base import LLM from reflector.settings import settings @@ -9,7 +7,7 @@ class OobaboogaLLM(LLM): async def _generate(self, prompt: str, schema: dict | None, **kwargs): json_payload = {"prompt": prompt} if schema: - json_payload["schema"] = json.dumps(schema) + json_payload["schema"] = schema async with httpx.AsyncClient() as client: response = await client.post( settings.LLM_URL, From 7809b60011eeae819df805f9219e40dd93f739a2 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Fri, 18 Aug 2023 10:08:27 +0200 Subject: [PATCH 12/14] server: remove print() statements --- server/reflector/views/transcripts.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index bfe61473..29db9ec7 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -136,9 +136,7 @@ class TranscriptController: query = transcripts.select() if user_id is not None: query = query.where(transcripts.c.user_id == user_id) - print(query) results = await database.fetch_all(query) - print(results) return results async def get_by_id( From 5c9adb2664e1f053bf6f6e0597c858af745ae6e9 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Fri, 18 Aug 2023 10:23:15 +0200 Subject: [PATCH 13/14] server: fixes tests --- server/tests/test_transcripts_rtc_ws.py | 1 - 1 file changed, 1 deletion(-) diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py index 28425cd5..8237d4ab 100644 --- a/server/tests/test_transcripts_rtc_ws.py +++ b/server/tests/test_transcripts_rtc_ws.py @@ -12,7 +12,6 @@ from unittest.mock import patch import pytest from httpx import AsyncClient from httpx_ws import aconnect_ws -from reflector.app import app from uvicorn import Config, Server From 2a3ad5657f75fa9de0bf57928c524f4f7a0a055c Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Fri, 18 Aug 2023 12:02:16 +0200 Subject: [PATCH 14/14] server: add /v1/me to get current user information sub, email and email_verified --- server/reflector/app.py | 17 ++++++++++------- server/reflector/views/user.py | 20 ++++++++++++++++++++ 2 files changed, 30 insertions(+), 7 deletions(-) create mode 100644 server/reflector/views/user.py diff --git a/server/reflector/app.py b/server/reflector/app.py index 
fa148240..14d55d7a 100644 --- a/server/reflector/app.py +++ b/server/reflector/app.py @@ -1,15 +1,17 @@ +from contextlib import asynccontextmanager + +import reflector.auth # noqa +import reflector.db # noqa from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from fastapi_pagination import add_pagination from fastapi.routing import APIRoute -import reflector.db # noqa -import reflector.auth # noqa -from reflector.views.rtc_offer import router as rtc_offer_router -from reflector.views.transcripts import router as transcripts_router -from reflector.events import subscribers_startup, subscribers_shutdown +from fastapi_pagination import add_pagination +from reflector.events import subscribers_shutdown, subscribers_startup from reflector.logger import logger from reflector.settings import settings -from contextlib import asynccontextmanager +from reflector.views.rtc_offer import router as rtc_offer_router +from reflector.views.transcripts import router as transcripts_router +from reflector.views.user import router as user_router try: import sentry_sdk @@ -50,6 +52,7 @@ app.add_middleware( # register views app.include_router(rtc_offer_router) app.include_router(transcripts_router, prefix="/v1") +app.include_router(user_router, prefix="/v1") add_pagination(app) diff --git a/server/reflector/views/user.py b/server/reflector/views/user.py new file mode 100644 index 00000000..4952c471 --- /dev/null +++ b/server/reflector/views/user.py @@ -0,0 +1,20 @@ +from typing import Annotated, Optional + +import reflector.auth as auth +from fastapi import APIRouter, Depends +from pydantic import BaseModel + +router = APIRouter() + + +class UserInfo(BaseModel): + sub: str + email: Optional[str] + email_verified: Optional[bool] + + +@router.get("/me") +async def user_me( + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +) -> UserInfo | None: + return user
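
A minimal usage sketch for the /v1/me endpoint added in PATCH 14/14, assuming the API is served locally and a bearer token accepted by reflector.auth is at hand; the base URL, port, and token value below are illustrative assumptions, not part of the patches:

    import asyncio

    import httpx

    API_BASE = "http://localhost:8000"  # assumption: local dev server
    TOKEN = "..."  # assumption: an access token accepted by reflector.auth


    async def main():
        async with httpx.AsyncClient(base_url=API_BASE) as client:
            # GET /v1/me returns the current user's sub, email and
            # email_verified fields; because the view depends on
            # auth.current_user_optional, it returns null when no valid
            # credentials are supplied instead of raising 401.
            response = await client.get(
                "/v1/me", headers={"Authorization": f"Bearer {TOKEN}"}
            )
            response.raise_for_status()
            print(response.json())


    if __name__ == "__main__":
        asyncio.run(main())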