diff --git a/server/reflector/llm/base.py b/server/reflector/llm/base.py
index 2d83913e..d046ffe7 100644
--- a/server/reflector/llm/base.py
+++ b/server/reflector/llm/base.py
@@ -46,7 +46,11 @@ class LLM:
         pass

     async def generate(
-        self, prompt: str, logger: reflector_logger, schema: str | None = None, **kwargs
+        self,
+        prompt: str,
+        logger: reflector_logger,
+        schema: dict | None = None,
+        **kwargs,
     ) -> dict:
         logger.info("LLM generate", prompt=repr(prompt))
         try:
@@ -62,7 +66,7 @@ class LLM:

         return result

-    async def _generate(self, prompt: str, schema: str | None, **kwargs) -> str:
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
         raise NotImplementedError

     def _parse_json(self, result: str) -> dict:
diff --git a/server/reflector/llm/llm_banana.py b/server/reflector/llm/llm_banana.py
index e19171bf..07119613 100644
--- a/server/reflector/llm/llm_banana.py
+++ b/server/reflector/llm/llm_banana.py
@@ -15,7 +15,7 @@ class BananaLLM(LLM):
             "X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY,
         }

-    async def _generate(self, prompt: str, schema: str | None, **kwargs):
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs):
         json_payload = {"prompt": prompt}
         if schema:
             json_payload["schema"] = json.dumps(schema)
diff --git a/server/reflector/llm/llm_modal.py b/server/reflector/llm/llm_modal.py
index 7cf7778b..ea9ff152 100644
--- a/server/reflector/llm/llm_modal.py
+++ b/server/reflector/llm/llm_modal.py
@@ -25,7 +25,7 @@ class ModalLLM(LLM):
         )
         response.raise_for_status()

-    async def _generate(self, prompt: str, schema: str | None, **kwargs):
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs):
         json_payload = {"prompt": prompt}
         if schema:
             json_payload["schema"] = json.dumps(schema)
diff --git a/server/reflector/llm/llm_oobabooga.py b/server/reflector/llm/llm_oobabooga.py
index 394f0af4..6c5a68ec 100644
--- a/server/reflector/llm/llm_oobabooga.py
+++ b/server/reflector/llm/llm_oobabooga.py
@@ -6,7 +6,7 @@ from reflector.settings import settings


 class OobaboogaLLM(LLM):
-    async def _generate(self, prompt: str, schema: str | None, **kwargs):
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs):
         json_payload = {"prompt": prompt}
         if schema:
             json_payload["schema"] = json.dumps(schema)
diff --git a/server/reflector/llm/llm_openai.py b/server/reflector/llm/llm_openai.py
index 62cccf7e..7ed532b7 100644
--- a/server/reflector/llm/llm_openai.py
+++ b/server/reflector/llm/llm_openai.py
@@ -15,7 +15,7 @@ class OpenAILLM(LLM):
         self.max_tokens = settings.LLM_MAX_TOKENS
         logger.info(f"LLM use openai backend at {self.openai_url}")

-    async def _generate(self, prompt: str, schema: str | None, **kwargs) -> str:
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
         headers = {
             "Content-Type": "application/json",
             "Authorization": f"Bearer {self.openai_key}",
diff --git a/server/tests/test_processors_pipeline.py b/server/tests/test_processors_pipeline.py
index db0a39c5..cc6a8574 100644
--- a/server/tests/test_processors_pipeline.py
+++ b/server/tests/test_processors_pipeline.py
@@ -15,7 +15,7 @@ async def test_basic_process(event_loop):
     settings.TRANSCRIPT_BACKEND = "whisper"

     class LLMTest(LLM):
-        async def _generate(self, prompt: str, schema: str | None, **kwargs) -> str:
+        async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
             return {
                 "title": "TITLE",
                 "summary": "SUMMARY",
diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py
index 943955f8..09f3d7e5 100644
--- a/server/tests/test_transcripts_rtc_ws.py
+++ b/server/tests/test_transcripts_rtc_ws.py
@@ -61,7 +61,7 @@ async def dummy_llm():
    from reflector.llm.base import LLM

    class TestLLM(LLM):
-        async def _generate(self, prompt: str, schema: str | None, **kwargs):
+        async def _generate(self, prompt: str, schema: dict | None, **kwargs):
            return json.dumps({"title": "LLM TITLE", "summary": "LLM SUMMARY"})

    with patch("reflector.llm.base.LLM.get_instance") as mock_llm:
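
For reviewers, a minimal usage sketch of the new signature (not part of the diff): the schema is now a plain dict, which each backend serializes itself via json.dumps(schema). The JSON-schema shape, the logger import path, and the wrapper function below are assumptions for illustration; only LLM.get_instance and generate(prompt, logger, schema) come from the code above.

    # Hypothetical caller after this change: pass the schema as a dict,
    # not a pre-serialized JSON string.
    from reflector.llm.base import LLM
    from reflector.logger import logger  # assumed module path for reflector_logger

    TITLE_SUMMARY_SCHEMA = {
        "type": "object",
        "properties": {
            "title": {"type": "string"},
            "summary": {"type": "string"},
        },
        "required": ["title", "summary"],
    }

    async def summarize(text: str) -> dict:
        llm = LLM.get_instance()  # same factory the tests patch above
        return await llm.generate(
            prompt=f"Summarize this transcript:\n{text}",
            logger=logger,
            schema=TITLE_SUMMARY_SCHEMA,  # dict; backends call json.dumps(schema)
        )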