make schema an optional argument

Gokul Mohanarangan
2023-08-17 09:23:14 +05:30
parent 5f79e04642
commit eb13a7bd64
9 changed files with 48 additions and 37 deletions
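With this change, callers hand the JSON schema to the LLM as an explicit, optional `schema` keyword instead of burying it inside a `kwargs` dict that the backends had to unpack again. A minimal before/after sketch of a call site, mirroring the call sites touched in the diff below (the `llm`, `prompt`, `schema` and `logger` objects are assumed to already exist):

# before: the schema travelled inside a kwargs dict
result = await llm.generate(prompt, kwargs={"schema": schema}, logger=logger)

# after: schema is a first-class optional argument; omit it for plain text generation
result = await llm.generate(prompt, schema=schema, logger=logger)
result = await llm.generate(prompt, logger=logger)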

View File

@@ -5,9 +5,9 @@ Reflector GPU backend - LLM
 """
 import os
+from typing import Optional
 from modal import asgi_app, Image, method, Secret, Stub
-from pydantic.typing import Optional
 # LLM
 LLM_MODEL: str = "lmsys/vicuna-13b-v1.5"

View File

@@ -1,10 +1,12 @@
-from reflector.settings import settings
-from reflector.utils.retry import retry
-from reflector.logger import logger as reflector_logger
-from time import monotonic
 import importlib
 import json
 import re
+from time import monotonic
+from typing import Union
+from reflector.logger import logger as reflector_logger
+from reflector.settings import settings
+from reflector.utils.retry import retry
 class LLM:
@@ -44,10 +46,12 @@ class LLM:
     async def _warmup(self, logger: reflector_logger):
         pass
-    async def generate(self, prompt: str, logger: reflector_logger, **kwargs) -> dict:
+    async def generate(
+        self, prompt: str, logger: reflector_logger, schema: str = None, **kwargs
+    ) -> dict:
         logger.info("LLM generate", prompt=repr(prompt))
         try:
-            result = await retry(self._generate)(prompt=prompt, **kwargs)
+            result = await retry(self._generate)(prompt=prompt, schema=schema, **kwargs)
         except Exception:
             logger.exception("Failed to call llm after retrying")
             raise
@@ -59,7 +63,7 @@ class LLM:
         return result
-    async def _generate(self, prompt: str, **kwargs) -> str:
+    async def _generate(self, prompt: str, schema: Union[str | None], **kwargs) -> str:
         raise NotImplementedError
     def _parse_json(self, result: str) -> dict:
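Against the updated base class, a concrete backend now receives the schema as an explicit parameter of `_generate` rather than probing `**kwargs`. A minimal sketch of what a new backend would look like (the `EchoLLM` class and its canned response are made up for illustration; only the signatures follow this diff):

from typing import Union

from reflector.llm.base import LLM


class EchoLLM(LLM):
    async def _generate(self, prompt: str, schema: Union[str, None], **kwargs) -> str:
        # a real backend would forward json.dumps(schema) to its HTTP API,
        # as BananaLLM and ModalLLM do below
        if schema:
            return '{"title": "TITLE", "summary": "SUMMARY"}'
        return prompt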

View File

@@ -1,4 +1,5 @@
 import json
+from typing import Union
 import httpx
 from reflector.llm.base import LLM
@@ -15,10 +16,10 @@ class BananaLLM(LLM):
             "X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY,
         }
-    async def _generate(self, prompt: str, **kwargs):
+    async def _generate(self, prompt: str, schema: Union[str | None], **kwargs):
         json_payload = {"prompt": prompt}
-        if "schema" in kwargs:
-            json_payload["schema"] = json.dumps(kwargs["schema"])
+        if schema:
+            json_payload["schema"] = json.dumps(schema)
         async with httpx.AsyncClient() as client:
             response = await retry(client.post)(
                 settings.LLM_URL,
@@ -29,7 +30,7 @@ class BananaLLM(LLM):
             )
             response.raise_for_status()
             text = response.json()["text"]
-            if "schema" not in json_payload:
+            if not schema:
                 text = text[len(prompt) :]
             return text

View File

@@ -1,4 +1,5 @@
 import json
+from typing import Union
 import httpx
 from reflector.llm.base import LLM
@@ -25,10 +26,10 @@ class ModalLLM(LLM):
             )
             response.raise_for_status()
-    async def _generate(self, prompt: str, **kwargs):
+    async def _generate(self, prompt: str, schema: Union[str | None], **kwargs):
         json_payload = {"prompt": prompt}
-        if "schema" in kwargs:
-            json_payload["schema"] = json.dumps(kwargs["schema"])
+        if schema:
+            json_payload["schema"] = json.dumps(schema)
         async with httpx.AsyncClient() as client:
             response = await retry(client.post)(
                 self.llm_url,
@@ -39,7 +40,7 @@ class ModalLLM(LLM):
             )
             response.raise_for_status()
             text = response.json()["text"]
-            if "schema" not in json_payload:
+            if not schema:
                 text = text[len(prompt) :]
             return text
@@ -54,13 +55,12 @@ if __name__ == "__main__":
         result = await llm.generate("Hello, my name is", logger=logger)
         print(result)
-        kwargs = {
-            "schema": {
-                "type": "object",
-                "properties": {"name": {"type": "string"}},
-            }
-        }
+        schema = {
+            "type": "object",
+            "properties": {"name": {"type": "string"}},
+        }
-        result = await llm.generate("Hello, my name is", kwargs=kwargs, logger=logger)
+        result = await llm.generate("Hello, my name is", schema=schema, logger=logger)
         print(result)
     import asyncio

View File

@@ -1,4 +1,5 @@
 import json
+from typing import Union
 import httpx
 from reflector.llm.base import LLM
@@ -6,10 +7,10 @@ from reflector.settings import settings
 class OobaboogaLLM(LLM):
-    async def _generate(self, prompt: str, **kwargs):
+    async def _generate(self, prompt: str, schema: Union[str | None], **kwargs):
         json_payload = {"prompt": prompt}
-        if "schema" in kwargs:
-            json_payload["schema"] = json.dumps(kwargs["schema"])
+        if schema:
+            json_payload["schema"] = json.dumps(schema)
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 settings.LLM_URL,

View File

@@ -1,7 +1,9 @@
+from typing import Union
+import httpx
 from reflector.llm.base import LLM
 from reflector.logger import logger
 from reflector.settings import settings
-import httpx
 class OpenAILLM(LLM):
@@ -15,7 +17,7 @@ class OpenAILLM(LLM):
         self.max_tokens = settings.LLM_MAX_TOKENS
         logger.info(f"LLM use openai backend at {self.openai_url}")
-    async def _generate(self, prompt: str, **kwargs) -> str:
+    async def _generate(self, prompt: str, schema: Union[str | None], **kwargs) -> str:
         headers = {
             "Content-Type": "application/json",
             "Authorization": f"Bearer {self.openai_key}",

View File

@@ -14,7 +14,6 @@ class TranscriptTopicDetectorProcessor(Processor):
     PROMPT = """
     ### Human:
-    Generate information based on the given schema:
     For the title field, generate a short title for the given text.
     For the summary field, summarize the given text in a maximum of
@@ -62,7 +61,7 @@ class TranscriptTopicDetectorProcessor(Processor):
         self.logger.info(f"Topic detector got {len(text)} length transcript")
         prompt = self.PROMPT.format(input_text=text)
         result = await retry(self.llm.generate)(
-            prompt=prompt, kwargs=self.kwargs, logger=self.logger
+            prompt=prompt, schema=self.topic_detector_schema, logger=self.logger
         )
         summary = TitleSummary(
             title=result["title"],
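`self.topic_detector_schema` itself is not part of this diff; presumably it is a JSON schema roughly along these lines, inferred from the prompt wording and the title/summary fields used above (hypothetical, not taken from the repository):

topic_detector_schema = {
    "type": "object",
    "properties": {
        "title": {"type": "string"},
        "summary": {"type": "string"},
    },
}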

View File

@@ -9,13 +9,16 @@ async def test_basic_process(event_loop):
     from reflector.settings import settings
     from reflector.llm.base import LLM
     from pathlib import Path
+    from typing import Union
     # use an LLM test backend
     settings.LLM_BACKEND = "test"
     settings.TRANSCRIPT_BACKEND = "whisper"
     class LLMTest(LLM):
-        async def _generate(self, prompt: str, **kwargs) -> str:
+        async def _generate(
+            self, prompt: str, schema: Union[str | None], **kwargs
+        ) -> str:
             return {
                 "title": "TITLE",
                 "summary": "SUMMARY",

View File

@@ -3,17 +3,18 @@
 # FIXME test websocket connection after RTC is finished still send the full events
 # FIXME try with locked session, RTC should not work
-import pytest
+import asyncio
 import json
+import threading
+from pathlib import Path
+from typing import Union
 from unittest.mock import patch
-from httpx import AsyncClient
+import pytest
+from httpx import AsyncClient
+from httpx_ws import aconnect_ws
 from reflector.app import app
 from uvicorn import Config, Server
-import threading
-import asyncio
-from pathlib import Path
-from httpx_ws import aconnect_ws
 class ThreadedUvicorn:
@@ -61,7 +62,7 @@ async def dummy_llm():
     from reflector.llm.base import LLM
     class TestLLM(LLM):
-        async def _generate(self, prompt: str, **kwargs):
+        async def _generate(self, prompt: str, schema: Union[str | None], **kwargs):
             return json.dumps({"title": "LLM TITLE", "summary": "LLM SUMMARY"})
     with patch("reflector.llm.base.LLM.get_instance") as mock_llm: