From eb13a7bd64e83e5cadbd2151a34af3d733aaebd0 Mon Sep 17 00:00:00 2001
From: Gokul Mohanarangan
Date: Thu, 17 Aug 2023 09:23:14 +0530
Subject: [PATCH] make schema optional argument

---
 server/gpu/modal/reflector_llm.py            |  2 +-
 server/reflector/llm/base.py                 | 18 +++++++++++-------
 server/reflector/llm/llm_banana.py           |  9 +++++----
 server/reflector/llm/llm_modal.py            | 20 ++++++++++----------
 server/reflector/llm/llm_oobagooda.py        |  7 ++++---
 server/reflector/llm/llm_openai.py           |  6 ++++--
 .../processors/transcript_topic_detector.py  |  3 +--
 server/tests/test_processors_pipeline.py     |  5 ++++-
 server/tests/test_transcripts_rtc_ws.py      | 15 ++++++++-------
 9 files changed, 48 insertions(+), 37 deletions(-)

diff --git a/server/gpu/modal/reflector_llm.py b/server/gpu/modal/reflector_llm.py
index d83d5036..2f96e330 100644
--- a/server/gpu/modal/reflector_llm.py
+++ b/server/gpu/modal/reflector_llm.py
@@ -5,9 +5,9 @@
 Reflector GPU backend - LLM
 """
 import os
+from typing import Optional

 from modal import asgi_app, Image, method, Secret, Stub
-from pydantic.typing import Optional

 # LLM
 LLM_MODEL: str = "lmsys/vicuna-13b-v1.5"
diff --git a/server/reflector/llm/base.py b/server/reflector/llm/base.py
index 5e86b553..fddf185d 100644
--- a/server/reflector/llm/base.py
+++ b/server/reflector/llm/base.py
@@ -1,10 +1,12 @@
-from reflector.settings import settings
-from reflector.utils.retry import retry
-from reflector.logger import logger as reflector_logger
-from time import monotonic
 import importlib
 import json
 import re
+from time import monotonic
+from typing import Union
+
+from reflector.logger import logger as reflector_logger
+from reflector.settings import settings
+from reflector.utils.retry import retry


 class LLM:
@@ -44,10 +46,12 @@ class LLM:
     async def _warmup(self, logger: reflector_logger):
         pass

-    async def generate(self, prompt: str, logger: reflector_logger, **kwargs) -> dict:
+    async def generate(
+        self, prompt: str, logger: reflector_logger, schema: str = None, **kwargs
+    ) -> dict:
         logger.info("LLM generate", prompt=repr(prompt))
         try:
-            result = await retry(self._generate)(prompt=prompt, **kwargs)
+            result = await retry(self._generate)(prompt=prompt, schema=schema, **kwargs)
         except Exception:
             logger.exception("Failed to call llm after retrying")
             raise
@@ -59,7 +63,7 @@ class LLM:

         return result

-    async def _generate(self, prompt: str, **kwargs) -> str:
+    async def _generate(self, prompt: str, schema: Union[str, None], **kwargs) -> str:
         raise NotImplementedError

     def _parse_json(self, result: str) -> dict:
diff --git a/server/reflector/llm/llm_banana.py b/server/reflector/llm/llm_banana.py
index bdaf091f..473769cc 100644
--- a/server/reflector/llm/llm_banana.py
+++ b/server/reflector/llm/llm_banana.py
@@ -1,4 +1,5 @@
 import json
+from typing import Union

 import httpx
 from reflector.llm.base import LLM
@@ -15,10 +16,10 @@ class BananaLLM(LLM):
             "X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY,
         }

-    async def _generate(self, prompt: str, **kwargs):
+    async def _generate(self, prompt: str, schema: Union[str, None], **kwargs):
         json_payload = {"prompt": prompt}
-        if "schema" in kwargs:
-            json_payload["schema"] = json.dumps(kwargs["schema"])
+        if schema:
+            json_payload["schema"] = json.dumps(schema)
         async with httpx.AsyncClient() as client:
             response = await retry(client.post)(
                 settings.LLM_URL,
@@ -29,7 +30,7 @@ class BananaLLM(LLM):
         )
         response.raise_for_status()
         text = response.json()["text"]
-        if "schema" not in json_payload:
+        if not schema:
             text = text[len(prompt) :]
         return text
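
Aside: the core of this change is the new backend contract. _generate() now
receives schema as an explicit (possibly None) argument instead of digging it
out of **kwargs, and the HTTP backends JSON-encode it into the request payload
only when it is provided. A minimal sketch of a backend written against that
contract follows; it assumes the reflector package is importable, and the
EchoLLM class is hypothetical, mirroring the test doubles later in this patch.

    import json
    from typing import Union

    from reflector.llm.base import LLM


    class EchoLLM(LLM):
        async def _generate(
            self, prompt: str, schema: Union[str, None], **kwargs
        ) -> str:
            if schema:
                # Schema-guided call: return JSON shaped like the caller's schema.
                return json.dumps({"title": "TITLE", "summary": "SUMMARY"})
            # Free-form call: return a plain text completion.
            return prompt + " ..."

Callers such as TranscriptTopicDetectorProcessor pass the schema through
LLM.generate(prompt, schema=..., logger=...), which forwards it to _generate()
through the retry wrapper.
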
diff --git a/server/reflector/llm/llm_modal.py b/server/reflector/llm/llm_modal.py
index 692dd095..c1fb856b 100644
--- a/server/reflector/llm/llm_modal.py
+++ b/server/reflector/llm/llm_modal.py
@@ -1,4 +1,5 @@
 import json
+from typing import Union

 import httpx
 from reflector.llm.base import LLM
@@ -25,10 +26,10 @@ class ModalLLM(LLM):
         )
         response.raise_for_status()

-    async def _generate(self, prompt: str, **kwargs):
+    async def _generate(self, prompt: str, schema: Union[str, None], **kwargs):
         json_payload = {"prompt": prompt}
-        if "schema" in kwargs:
-            json_payload["schema"] = json.dumps(kwargs["schema"])
+        if schema:
+            json_payload["schema"] = json.dumps(schema)
         async with httpx.AsyncClient() as client:
             response = await retry(client.post)(
                 self.llm_url,
@@ -39,7 +40,7 @@ class ModalLLM(LLM):
         )
         response.raise_for_status()
         text = response.json()["text"]
-        if "schema" not in json_payload:
+        if not schema:
             text = text[len(prompt) :]
         return text

@@ -54,13 +55,12 @@ if __name__ == "__main__":
         result = await llm.generate("Hello, my name is", logger=logger)
         print(result)

-        kwargs = {
-            "schema": {
-                "type": "object",
-                "properties": {"name": {"type": "string"}},
-            }
+        schema = {
+            "type": "object",
+            "properties": {"name": {"type": "string"}},
         }
-        result = await llm.generate("Hello, my name is", kwargs=kwargs, logger=logger)
+
+        result = await llm.generate("Hello, my name is", schema=schema, logger=logger)
         print(result)

     import asyncio
diff --git a/server/reflector/llm/llm_oobagooda.py b/server/reflector/llm/llm_oobagooda.py
index 85306135..0ceb442d 100644
--- a/server/reflector/llm/llm_oobagooda.py
+++ b/server/reflector/llm/llm_oobagooda.py
@@ -1,4 +1,5 @@
 import json
+from typing import Union

 import httpx
 from reflector.llm.base import LLM
@@ -6,10 +7,10 @@ from reflector.settings import settings


 class OobaboogaLLM(LLM):
-    async def _generate(self, prompt: str, **kwargs):
+    async def _generate(self, prompt: str, schema: Union[str, None], **kwargs):
         json_payload = {"prompt": prompt}
-        if "schema" in kwargs:
-            json_payload["schema"] = json.dumps(kwargs["schema"])
+        if schema:
+            json_payload["schema"] = json.dumps(schema)
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 settings.LLM_URL,
diff --git a/server/reflector/llm/llm_openai.py b/server/reflector/llm/llm_openai.py
index dd438704..9a74e03c 100644
--- a/server/reflector/llm/llm_openai.py
+++ b/server/reflector/llm/llm_openai.py
@@ -1,7 +1,9 @@
+from typing import Union
+
+import httpx
 from reflector.llm.base import LLM
 from reflector.logger import logger
 from reflector.settings import settings
-import httpx


 class OpenAILLM(LLM):
@@ -15,7 +17,7 @@ class OpenAILLM(LLM):
         self.max_tokens = settings.LLM_MAX_TOKENS
         logger.info(f"LLM use openai backend at {self.openai_url}")

-    async def _generate(self, prompt: str, **kwargs) -> str:
+    async def _generate(self, prompt: str, schema: Union[str, None], **kwargs) -> str:
         headers = {
             "Content-Type": "application/json",
             "Authorization": f"Bearer {self.openai_key}",
diff --git a/server/reflector/processors/transcript_topic_detector.py b/server/reflector/processors/transcript_topic_detector.py
index 430e3992..9ae21a72 100644
--- a/server/reflector/processors/transcript_topic_detector.py
+++ b/server/reflector/processors/transcript_topic_detector.py
@@ -14,7 +14,6 @@ class TranscriptTopicDetectorProcessor(Processor):
     PROMPT = """
     ### Human:
-    Generate information based on the given schema:
     For the title field, generate a short title for the given text.
     For the summary field, summarize the given text in a maximum of
@@ -62,7 +61,7 @@ class TranscriptTopicDetectorProcessor(Processor):
         self.logger.info(f"Topic detector got {len(text)} length transcript")
         prompt = self.PROMPT.format(input_text=text)
         result = await retry(self.llm.generate)(
-            prompt=prompt, kwargs=self.kwargs, logger=self.logger
+            prompt=prompt, schema=self.topic_detector_schema, logger=self.logger
         )
         summary = TitleSummary(
             title=result["title"],
diff --git a/server/tests/test_processors_pipeline.py b/server/tests/test_processors_pipeline.py
index 95c296de..56cac96e 100644
--- a/server/tests/test_processors_pipeline.py
+++ b/server/tests/test_processors_pipeline.py
@@ -9,13 +9,16 @@ async def test_basic_process(event_loop):
     from reflector.settings import settings
     from reflector.llm.base import LLM
     from pathlib import Path
+    from typing import Union

     # use an LLM test backend
     settings.LLM_BACKEND = "test"
     settings.TRANSCRIPT_BACKEND = "whisper"

     class LLMTest(LLM):
-        async def _generate(self, prompt: str, **kwargs) -> str:
+        async def _generate(
+            self, prompt: str, schema: Union[str, None], **kwargs
+        ) -> str:
             return {
                 "title": "TITLE",
                 "summary": "SUMMARY",
diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py
index 70ee209b..23c7813f 100644
--- a/server/tests/test_transcripts_rtc_ws.py
+++ b/server/tests/test_transcripts_rtc_ws.py
@@ -3,17 +3,18 @@
 # FIXME test websocket connection after RTC is finished still send the full events
 # FIXME try with locked session, RTC should not work

-import pytest
+import asyncio
 import json
+import threading
+from pathlib import Path
+from typing import Union
 from unittest.mock import patch

-from httpx import AsyncClient
+import pytest
+from httpx import AsyncClient
+from httpx_ws import aconnect_ws
 from reflector.app import app
 from uvicorn import Config, Server
-import threading
-import asyncio
-from pathlib import Path
-from httpx_ws import aconnect_ws


 class ThreadedUvicorn:
@@ -61,7 +62,7 @@ async def dummy_llm():
     from reflector.llm.base import LLM

     class TestLLM(LLM):
-        async def _generate(self, prompt: str, **kwargs):
+        async def _generate(self, prompt: str, schema: Union[str, None], **kwargs):
             return json.dumps({"title": "LLM TITLE", "summary": "LLM SUMMARY"})

     with patch("reflector.llm.base.LLM.get_instance") as mock_llm: