PR review comments

Gokul Mohanarangan
2023-08-17 14:42:45 +05:30
parent 235ee73f46
commit a98a9853be
9 changed files with 17 additions and 20 deletions


@@ -29,3 +29,10 @@ repos:
    hooks:
      - id: black
        files: ^server/(reflector|tests)/
+  - repo: https://github.com/pycqa/isort
+    rev: 5.12.0
+    hooks:
+      - id: isort
+        name: isort (python)
+        files: ^server/(gpu|evaluate|reflector)/
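
The new hook checks import ordering in files under server/(gpu|evaluate|reflector)/. Purely as an illustration of the default grouping isort enforces (standard library, third-party, then first-party imports, separated by blank lines), here is how imports touched elsewhere in this commit are expected to be laid out; this sketch assumes isort's default profile, since no extra configuration is shown in the diff:

# Standard library imports form the first group
import json
from pathlib import Path

# Third-party packages form the second group
import httpx
import pytest

# First-party (project) imports come last
from reflector.llm.base import LLM
from reflector.settings import settings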


@@ -45,8 +45,8 @@
 ## llm backend implementation
 ## =======================================================
-## Use oobagooda (default)
-#LLM_BACKEND=oobagooda
+## Use oobabooga (default)
+#LLM_BACKEND=oobabooga
 #LLM_URL=http://xxx:7860/api/generate/v1
 ## Using serverless modal.com (require reflector-gpu-modal deployed)
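
These commented-out lines select the LLM backend; the base module (next file) imports importlib and the tests patch LLM.get_instance, which suggests the backend name maps to a module and class at runtime. A hedged sketch of what such a lookup could look like, assuming module paths like reflector.llm.oobabooga that are not actually shown in this diff (the real LLM.get_instance implementation may differ):

import importlib

from reflector.settings import settings


def load_llm_backend():
    # Hypothetical helper for illustration only; the module paths are assumed,
    # the class names are the ones that appear in this commit.
    backends = {
        "oobabooga": ("reflector.llm.oobabooga", "OobaboogaLLM"),
        "modal": ("reflector.llm.modal", "ModalLLM"),
        "banana": ("reflector.llm.banana", "BananaLLM"),
        "openai": ("reflector.llm.openai", "OpenAILLM"),
    }
    module_name, class_name = backends[settings.LLM_BACKEND]
    module = importlib.import_module(module_name)
    return getattr(module, class_name)()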


@@ -2,7 +2,6 @@ import importlib
 import json
 import re
 from time import monotonic
-from typing import Union

 from reflector.logger import logger as reflector_logger
 from reflector.settings import settings
@@ -47,7 +46,7 @@ class LLM:
         pass

     async def generate(
-        self, prompt: str, logger: reflector_logger, schema: str = None, **kwargs
+        self, prompt: str, logger: reflector_logger, schema: str | None = None, **kwargs
     ) -> dict:
         logger.info("LLM generate", prompt=repr(prompt))
         try:
@@ -63,7 +62,7 @@ class LLM:
         return result

-    async def _generate(self, prompt: str, schema: Union[str | None], **kwargs) -> str:
+    async def _generate(self, prompt: str, schema: str | None, **kwargs) -> str:
         raise NotImplementedError

     def _parse_json(self, result: str) -> dict:
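
After this change every backend overrides _generate with the plain `str | None` annotation instead of Union. A minimal sketch of such a backend, modelled on the banana/modal/oobabooga hunks below; the class name and URL are placeholders for illustration, not part of this commit:

import json

import httpx

from reflector.llm.base import LLM


class ExampleLLM(LLM):
    # Placeholder endpoint, mirroring the oobabooga-style URL from the .env sample
    url = "http://localhost:7860/api/generate/v1"

    async def _generate(self, prompt: str, schema: str | None, **kwargs) -> str:
        json_payload = {"prompt": prompt}
        if schema:
            # forward the schema as JSON, matching the existing backends
            json_payload["schema"] = json.dumps(schema)
        async with httpx.AsyncClient() as client:
            response = await client.post(self.url, json=json_payload)
            response.raise_for_status()
            return response.text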


@@ -1,5 +1,4 @@
 import json
-from typing import Union

 import httpx
 from reflector.llm.base import LLM
@@ -16,7 +15,7 @@ class BananaLLM(LLM):
"X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY, "X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY,
} }
async def _generate(self, prompt: str, schema: Union[str | None], **kwargs): async def _generate(self, prompt: str, schema: str | None, **kwargs):
json_payload = {"prompt": prompt} json_payload = {"prompt": prompt}
if schema: if schema:
json_payload["schema"] = json.dumps(schema) json_payload["schema"] = json.dumps(schema)


@@ -1,5 +1,4 @@
 import json
-from typing import Union

 import httpx
 from reflector.llm.base import LLM
@@ -26,7 +25,7 @@ class ModalLLM(LLM):
         )
         response.raise_for_status()

-    async def _generate(self, prompt: str, schema: Union[str | None], **kwargs):
+    async def _generate(self, prompt: str, schema: str | None, **kwargs):
         json_payload = {"prompt": prompt}
         if schema:
             json_payload["schema"] = json.dumps(schema)


@@ -1,5 +1,4 @@
 import json
-from typing import Union

 import httpx
 from reflector.llm.base import LLM
@@ -7,7 +6,7 @@ from reflector.settings import settings
 class OobaboogaLLM(LLM):
-    async def _generate(self, prompt: str, schema: Union[str | None], **kwargs):
+    async def _generate(self, prompt: str, schema: str | None, **kwargs):
         json_payload = {"prompt": prompt}
         if schema:
             json_payload["schema"] = json.dumps(schema)


@@ -1,5 +1,3 @@
-from typing import Union
-
 import httpx
 from reflector.llm.base import LLM
 from reflector.logger import logger
@@ -17,7 +15,7 @@ class OpenAILLM(LLM):
         self.max_tokens = settings.LLM_MAX_TOKENS
         logger.info(f"LLM use openai backend at {self.openai_url}")

-    async def _generate(self, prompt: str, schema: Union[str | None], **kwargs) -> str:
+    async def _generate(self, prompt: str, schema: str | None, **kwargs) -> str:
         headers = {
             "Content-Type": "application/json",
             "Authorization": f"Bearer {self.openai_key}",


@@ -9,16 +9,13 @@ async def test_basic_process(event_loop):
     from reflector.settings import settings
     from reflector.llm.base import LLM
     from pathlib import Path
-    from typing import Union

     # use an LLM test backend
     settings.LLM_BACKEND = "test"
     settings.TRANSCRIPT_BACKEND = "whisper"

     class LLMTest(LLM):
-        async def _generate(
-            self, prompt: str, schema: Union[str | None], **kwargs
-        ) -> str:
+        async def _generate(self, prompt: str, schema: str | None, **kwargs) -> str:
             return {
                 "title": "TITLE",
                 "summary": "SUMMARY",


@@ -7,7 +7,6 @@ import asyncio
 import json
 import threading
 from pathlib import Path
-from typing import Union
 from unittest.mock import patch

 import pytest
@@ -62,7 +61,7 @@ async def dummy_llm():
     from reflector.llm.base import LLM

     class TestLLM(LLM):
-        async def _generate(self, prompt: str, schema: Union[str | None], **kwargs):
+        async def _generate(self, prompt: str, schema: str | None, **kwargs):
             return json.dumps({"title": "LLM TITLE", "summary": "LLM SUMMARY"})

     with patch("reflector.llm.base.LLM.get_instance") as mock_llm: