import logging
from typing import Type, TypeVar

from llama_index.core import Settings
from llama_index.core.output_parsers import PydanticOutputParser
from llama_index.core.program import LLMTextCompletionProgram
from llama_index.core.response_synthesizers import TreeSummarize
from llama_index.llms.openai_like import OpenAILike
from pydantic import BaseModel, ValidationError

T = TypeVar("T", bound=BaseModel)

STRUCTURED_RESPONSE_PROMPT_TEMPLATE = """
Based on the following analysis, provide the information in the requested JSON format:

Analysis:
{analysis}

{format_instructions}
"""


class LLM:
    """Thin wrapper around llama_index for OpenAI-compatible, non-function-calling LLMs."""

    def __init__(self, settings, temperature: float = 0.4, max_tokens: int = 2048):
        self.settings_obj = settings
        self.model_name = settings.LLM_MODEL
        self.url = settings.LLM_URL
        self.api_key = settings.LLM_API_KEY
        self.context_window = settings.LLM_CONTEXT_WINDOW
        self.temperature = temperature
        self.max_tokens = max_tokens

        # Configure llamaindex Settings
        self._configure_llamaindex()

    def _configure_llamaindex(self):
        """Configure llamaindex Settings with OpenAILike LLM"""
        Settings.llm = OpenAILike(
            model=self.model_name,
            api_base=self.url,
            api_key=self.api_key,
            context_window=self.context_window,
            is_chat_model=True,
            is_function_calling_model=False,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
        )

    async def get_response(
        self, prompt: str, texts: list[str], tone_name: str | None = None
    ) -> str:
        """Get a text response using TreeSummarize for non-function-calling models"""
        summarizer = TreeSummarize(verbose=False)
        response = await summarizer.aget_response(prompt, texts, tone_name=tone_name)
        return str(response).strip()

    async def get_structured_response(
        self,
        prompt: str,
        texts: list[str],
        output_cls: Type[T],
        tone_name: str | None = None,
    ) -> T:
        """Get structured output from LLM for non-function-calling models"""
        logger = logging.getLogger(__name__)

        # First pass: produce a free-text analysis with TreeSummarize.
        summarizer = TreeSummarize(verbose=True)
        response = await summarizer.aget_response(prompt, texts, tone_name=tone_name)

        # Second pass: ask the model to restate the analysis as JSON that
        # validates against the Pydantic schema of output_cls.
        output_parser = PydanticOutputParser(output_cls)

        program = LLMTextCompletionProgram.from_defaults(
            output_parser=output_parser,
            prompt_template_str=STRUCTURED_RESPONSE_PROMPT_TEMPLATE,
            verbose=False,
        )

        format_instructions = output_parser.format(
            "Please structure the above information in the following JSON format:"
        )

        try:
            output = await program.acall(
                analysis=str(response), format_instructions=format_instructions
            )
        except ValidationError as e:
            # Extract the raw JSON from the error details so the failing
            # payload is visible in the logs before re-raising.
            errors = e.errors()
            if errors and "input" in errors[0]:
                raw_json = errors[0]["input"]
                logger.error(
                    f"JSON validation failed for {output_cls.__name__}. "
                    f"Full raw JSON output:\n{raw_json}\n"
                    f"Validation errors: {errors}"
                )
            else:
                logger.error(
                    f"JSON validation failed for {output_cls.__name__}. "
                    f"Validation errors: {errors}"
                )
            raise

        return output
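

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module proper. Assumptions: an
    # OpenAI-compatible server is reachable at the URL below, and the ad-hoc
    # SimpleNamespace stands in for the application's settings object (which
    # must expose LLM_MODEL, LLM_URL, LLM_API_KEY, LLM_CONTEXT_WINDOW).
    import asyncio
    from types import SimpleNamespace

    class MeetingSummary(BaseModel):
        title: str
        key_points: list[str]

    demo_settings = SimpleNamespace(
        LLM_MODEL="my-model",                # hypothetical model name
        LLM_URL="http://localhost:8000/v1",  # hypothetical endpoint
        LLM_API_KEY="sk-none",               # placeholder key
        LLM_CONTEXT_WINDOW=8192,
    )

    async def demo() -> None:
        llm = LLM(demo_settings, temperature=0.2)
        # Plain-text response via TreeSummarize.
        text = await llm.get_response(
            "Summarize the discussion.", ["Speaker A proposed X. Speaker B agreed."]
        )
        # Structured response validated against the MeetingSummary schema.
        summary = await llm.get_structured_response(
            "Summarize the discussion.",
            ["Speaker A proposed X. Speaker B agreed."],
            output_cls=MeetingSummary,
        )
        print(text)
        print(summary.model_dump())

    asyncio.run(demo())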