From a98a9853be364890c87ff8320b56f27b788b1e95 Mon Sep 17 00:00:00 2001
From: Gokul Mohanarangan
Date: Thu, 17 Aug 2023 14:42:45 +0530
Subject: [PATCH] PR review comments

---
 .pre-commit-config.yaml                                  | 7 +++++++
 server/env.example                                       | 4 ++--
 server/reflector/llm/base.py                             | 5 ++---
 server/reflector/llm/llm_banana.py                       | 3 +--
 server/reflector/llm/llm_modal.py                        | 3 +--
 .../reflector/llm/{llm_oobagooda.py => llm_oobabooga.py} | 3 +--
 server/reflector/llm/llm_openai.py                       | 4 +---
 server/tests/test_processors_pipeline.py                 | 5 +----
 server/tests/test_transcripts_rtc_ws.py                  | 3 +--
 9 files changed, 17 insertions(+), 20 deletions(-)
 rename server/reflector/llm/{llm_oobagooda.py => llm_oobabooga.py} (84%)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7132e09c..2a73b7c2 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -29,3 +29,10 @@ repos:
     hooks:
       - id: black
         files: ^server/(reflector|tests)/
+
+  - repo: https://github.com/pycqa/isort
+    rev: 5.12.0
+    hooks:
+      - id: isort
+        name: isort (python)
+        files: ^server/(gpu|evaluate|reflector)/
diff --git a/server/env.example b/server/env.example
index 11e0927b..2317a1bb 100644
--- a/server/env.example
+++ b/server/env.example
@@ -45,8 +45,8 @@
 ## llm backend implementation
 ## =======================================================
 
-## Use oobagooda (default)
-#LLM_BACKEND=oobagooda
+## Use oobabooga (default)
+#LLM_BACKEND=oobabooga
 #LLM_URL=http://xxx:7860/api/generate/v1
 
 ## Using serverless modal.com (require reflector-gpu-modal deployed)
diff --git a/server/reflector/llm/base.py b/server/reflector/llm/base.py
index fddf185d..2d83913e 100644
--- a/server/reflector/llm/base.py
+++ b/server/reflector/llm/base.py
@@ -2,7 +2,6 @@ import importlib
 import json
 import re
 from time import monotonic
-from typing import Union
 
 from reflector.logger import logger as reflector_logger
 from reflector.settings import settings
@@ -47,7 +46,7 @@ class LLM:
         pass
 
     async def generate(
-        self, prompt: str, logger: reflector_logger, schema: str = None, **kwargs
+        self, prompt: str, logger: reflector_logger, schema: str | None = None, **kwargs
     ) -> dict:
         logger.info("LLM generate", prompt=repr(prompt))
         try:
@@ -63,7 +62,7 @@
 
         return result
 
-    async def _generate(self, prompt: str, schema: Union[str | None], **kwargs) -> str:
+    async def _generate(self, prompt: str, schema: str | None, **kwargs) -> str:
         raise NotImplementedError
 
     def _parse_json(self, result: str) -> dict:
diff --git a/server/reflector/llm/llm_banana.py b/server/reflector/llm/llm_banana.py
index 473769cc..e19171bf 100644
--- a/server/reflector/llm/llm_banana.py
+++ b/server/reflector/llm/llm_banana.py
@@ -1,5 +1,4 @@
 import json
-from typing import Union
 
 import httpx
 from reflector.llm.base import LLM
@@ -16,7 +15,7 @@ class BananaLLM(LLM):
             "X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY,
         }
 
-    async def _generate(self, prompt: str, schema: Union[str | None], **kwargs):
+    async def _generate(self, prompt: str, schema: str | None, **kwargs):
         json_payload = {"prompt": prompt}
         if schema:
             json_payload["schema"] = json.dumps(schema)
diff --git a/server/reflector/llm/llm_modal.py b/server/reflector/llm/llm_modal.py
index c1fb856b..7cf7778b 100644
--- a/server/reflector/llm/llm_modal.py
+++ b/server/reflector/llm/llm_modal.py
@@ -1,5 +1,4 @@
 import json
-from typing import Union
 
 import httpx
 from reflector.llm.base import LLM
@@ -26,7 +25,7 @@ class ModalLLM(LLM):
         )
         response.raise_for_status()
 
-    async def _generate(self, prompt: str, schema: Union[str | None], **kwargs):
+    async def _generate(self, prompt: str, schema: str | None, **kwargs):
         json_payload = {"prompt": prompt}
         if schema:
             json_payload["schema"] = json.dumps(schema)
diff --git a/server/reflector/llm/llm_oobagooda.py b/server/reflector/llm/llm_oobabooga.py
similarity index 84%
rename from server/reflector/llm/llm_oobagooda.py
rename to server/reflector/llm/llm_oobabooga.py
index 0ceb442d..394f0af4 100644
--- a/server/reflector/llm/llm_oobagooda.py
+++ b/server/reflector/llm/llm_oobabooga.py
@@ -1,5 +1,4 @@
 import json
-from typing import Union
 
 import httpx
 from reflector.llm.base import LLM
@@ -7,7 +6,7 @@ from reflector.settings import settings
 
 
 class OobaboogaLLM(LLM):
-    async def _generate(self, prompt: str, schema: Union[str | None], **kwargs):
+    async def _generate(self, prompt: str, schema: str | None, **kwargs):
         json_payload = {"prompt": prompt}
         if schema:
             json_payload["schema"] = json.dumps(schema)
diff --git a/server/reflector/llm/llm_openai.py b/server/reflector/llm/llm_openai.py
index 9a74e03c..62cccf7e 100644
--- a/server/reflector/llm/llm_openai.py
+++ b/server/reflector/llm/llm_openai.py
@@ -1,5 +1,3 @@
-from typing import Union
-
 import httpx
 from reflector.llm.base import LLM
 from reflector.logger import logger
@@ -17,7 +15,7 @@ class OpenAILLM(LLM):
         self.max_tokens = settings.LLM_MAX_TOKENS
         logger.info(f"LLM use openai backend at {self.openai_url}")
 
-    async def _generate(self, prompt: str, schema: Union[str | None], **kwargs) -> str:
+    async def _generate(self, prompt: str, schema: str | None, **kwargs) -> str:
         headers = {
             "Content-Type": "application/json",
             "Authorization": f"Bearer {self.openai_key}",
diff --git a/server/tests/test_processors_pipeline.py b/server/tests/test_processors_pipeline.py
index 56cac96e..db0a39c5 100644
--- a/server/tests/test_processors_pipeline.py
+++ b/server/tests/test_processors_pipeline.py
@@ -9,16 +9,13 @@ async def test_basic_process(event_loop):
     from reflector.settings import settings
     from reflector.llm.base import LLM
     from pathlib import Path
-    from typing import Union
 
     # use an LLM test backend
     settings.LLM_BACKEND = "test"
     settings.TRANSCRIPT_BACKEND = "whisper"
 
     class LLMTest(LLM):
-        async def _generate(
-            self, prompt: str, schema: Union[str | None], **kwargs
-        ) -> str:
+        async def _generate(self, prompt: str, schema: str | None, **kwargs) -> str:
             return {
                 "title": "TITLE",
                 "summary": "SUMMARY",
diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py
index 5555d195..943955f8 100644
--- a/server/tests/test_transcripts_rtc_ws.py
+++ b/server/tests/test_transcripts_rtc_ws.py
@@ -7,7 +7,6 @@ import asyncio
 import json
 import threading
 from pathlib import Path
-from typing import Union
 from unittest.mock import patch
 
 import pytest
@@ -62,7 +61,7 @@ async def dummy_llm():
     from reflector.llm.base import LLM
 
     class TestLLM(LLM):
-        async def _generate(self, prompt: str, schema: Union[str | None], **kwargs):
+        async def _generate(self, prompt: str, schema: str | None, **kwargs):
             return json.dumps({"title": "LLM TITLE", "summary": "LLM SUMMARY"})
 
     with patch("reflector.llm.base.LLM.get_instance") as mock_llm: