reflector/server/reflector/llm.py
Sergey Mankovsky c155f66982 fix: improve hatchet workflow reliability (#900)
* Increase max connections

* Classify hard and transient hatchet errors

* Fan out partial success

* Force reprocessing of error transcripts

* Stop retrying on 402 payment required

* Avoid httpx/hatchet timeout race

* Add retry wrapper to get_response for transient errors

* Add retry backoff

* Return falsy results so get_response won't retry on empty string

* Skip error status in on_workflow_failure when transcript already ended

* Fix precommit issues

* Fail step on first fan-out failure instead of skipping
2026-03-06 17:07:26 +01:00

184 lines · 6.8 KiB · Python

import logging
from contextvars import ContextVar
from typing import Type, TypeVar
from uuid import uuid4

from llama_index.core import Settings
from llama_index.core.prompts import PromptTemplate
from llama_index.core.response_synthesizers import TreeSummarize
from llama_index.llms.openai_like import OpenAILike
from pydantic import BaseModel, ValidationError

from reflector.utils.retry import retry
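# NOTE: `retry(fn)` from reflector.utils.retry is assumed to return a callable
# that takes the retry configuration as keyword arguments (attempt count,
# exponential-backoff bounds, exception types to retry on). This interface is
# inferred from the call sites below, not from the helper's own definition.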

T = TypeVar("T", bound=BaseModel)

# Session ID for LiteLLM request grouping - set per processing run
llm_session_id: ContextVar[str | None] = ContextVar("llm_session_id", default=None)
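# Callers are expected to set llm_session_id at the start of a processing run;
# when unset, _configure_llamaindex falls back to a random "fallback-<uuid>" id.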

logger = logging.getLogger(__name__)


class LLMParseError(Exception):
    """Raised when LLM output cannot be parsed after retries."""

    def __init__(self, output_cls: Type[BaseModel], error_msg: str, attempts: int):
        self.output_cls = output_cls
        self.error_msg = error_msg
        self.attempts = attempts
        super().__init__(
            f"Failed to parse {output_cls.__name__} after {attempts} attempts: {error_msg}"
        )


class LLM:
    def __init__(
        self, settings, temperature: float = 0.4, max_tokens: int | None = None
    ):
        self.settings_obj = settings
        self.model_name = settings.LLM_MODEL
        self.url = settings.LLM_URL
        self.api_key = settings.LLM_API_KEY
        self.context_window = settings.LLM_CONTEXT_WINDOW
        self.temperature = temperature
        self.max_tokens = max_tokens
        self._configure_llamaindex()

    def _configure_llamaindex(self):
        """Configure llamaindex Settings with OpenAILike LLM"""
        session_id = llm_session_id.get() or f"fallback-{uuid4().hex}"
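        # NOTE: llama_index Settings is a process-global singleton, so this
        # assignment replaces the configured LLM for the whole process, not
        # just this instance.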
        Settings.llm = OpenAILike(
            model=self.model_name,
            api_base=self.url,
            api_key=self.api_key,
            context_window=self.context_window,
            is_chat_model=True,
            is_function_calling_model=True,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            timeout=self.settings_obj.LLM_REQUEST_TIMEOUT,
            additional_kwargs={"extra_body": {"litellm_session_id": session_id}},
        )

    async def get_response(
        self, prompt: str, texts: list[str], tone_name: str | None = None
    ) -> str:
        """Get a text response using TreeSummarize for non-function-calling models.

        Uses the same retry() wrapper as get_structured_response for transient
        network errors (connection, timeout, OSError) with exponential backoff.
        """

        async def _call():
            summarizer = TreeSummarize(verbose=False)
            response = await summarizer.aget_response(
                prompt, texts, tone_name=tone_name
            )
            return str(response).strip()

        return await retry(_call)(
            retry_attempts=3,
            retry_backoff_interval=1.0,
            retry_backoff_max=30.0,
            retry_ignore_exc_types=(ConnectionError, TimeoutError, OSError),
        )
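
    # Illustrative usage sketch (hypothetical caller; `settings` and `segments`
    # are assumptions, everything else is defined in this module):
    #
    #     llm = LLM(settings, temperature=0.4)
    #     summary = await llm.get_response(
    #         "Summarize the following segments:", texts=segments
    #     )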

    async def get_structured_response(
        self,
        prompt: str,
        texts: list[str],
        output_cls: Type[T],
        tone_name: str | None = None,
        timeout: int | None = None,
    ) -> T:
        """Get structured output from the LLM via tool calls, with reflection retry.

        Uses astructured_predict (function-calling / tool-call mode) for the
        first attempt. On ValidationError or parse failure, the formatted
        error is fed back as a reflection prompt and the call is retried up
        to LLM_PARSE_MAX_RETRIES times.

        The outer retry() wrapper handles transient network errors with
        exponential backoff.
        """
        max_retries = self.settings_obj.LLM_PARSE_MAX_RETRIES

        async def _call_with_reflection():
            # Build full prompt: instruction + source texts
            if texts:
                texts_block = "\n\n".join(texts)
                full_prompt = f"{prompt}\n\n{texts_block}"
            else:
                full_prompt = prompt

            prompt_tmpl = PromptTemplate("{user_prompt}")
            last_error: str | None = None

            for attempt in range(1, max_retries + 2):  # +2: first try + retries
                try:
                    if attempt == 1:
                        result = await Settings.llm.astructured_predict(
                            output_cls, prompt_tmpl, user_prompt=full_prompt
                        )
                    else:
                        reflection_tmpl = PromptTemplate(
                            "{user_prompt}\n\n{reflection}"
                        )
                        result = await Settings.llm.astructured_predict(
                            output_cls,
                            reflection_tmpl,
                            user_prompt=full_prompt,
                            reflection=reflection,
                        )
                    if attempt > 1:
                        logger.info(
                            f"LLM structured_predict succeeded on attempt "
                            f"{attempt}/{max_retries + 1} for {output_cls.__name__}"
                        )
                    return result
                except (ValidationError, ValueError) as e:
                    wrong_output = str(e)
                    if len(wrong_output) > 2000:
                        wrong_output = wrong_output[:2000] + "... [truncated]"
                    last_error = self._format_validation_error(e)
                    reflection = (
                        f"Your previous response could not be parsed.\n\n"
                        f"Error:\n{last_error}\n\n"
                        "Please try again and return valid data matching the schema."
                    )
                    logger.error(
                        f"LLM parse error (attempt {attempt}/{max_retries + 1}): "
                        f"{type(e).__name__}: {e}\n"
                        f"Raw response: {wrong_output[:500]}"
                    )

            raise LLMParseError(
                output_cls=output_cls,
                error_msg=last_error or "Max retries exceeded",
                attempts=max_retries + 1,
            )

        return await retry(_call_with_reflection)(
            retry_attempts=3,
            retry_backoff_interval=1.0,
            retry_backoff_max=30.0,
            retry_ignore_exc_types=(ConnectionError, TimeoutError, OSError),
        )

    @staticmethod
    def _format_validation_error(error: Exception) -> str:
        """Format a validation/parse error for LLM reflection feedback."""
        if isinstance(error, ValidationError):
            error_messages = []
            for err in error.errors():
                field = ".".join(str(loc) for loc in err["loc"])
                error_messages.append(f"- {err['msg']} in field '{field}'")
            return "Schema validation errors:\n" + "\n".join(error_messages)
        return f"Parse error: {str(error)}"