mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-04-23 21:55:19 +00:00
fix: switch structured output to tool-call with reflection retry (#879)
* fix: switch structured output to tool-call with reflection retry Replace the two-pass StructuredOutputWorkflow (TreeSummarize → acomplete) with astructured_predict + reflection retry loop for structured LLM output. - Enable function-calling mode (is_function_calling_model=True) - Use astructured_predict with PromptTemplate for first attempt - On ValidationError/parse failure, retry with reflection feedback - Add min_length=10 to TopicResponse title/summary fields - Remove dead StructuredOutputWorkflow class and its event types - Rewrite tests to match new astructured_predict approach * fix: include texts parameter in astructured_predict prompt The switch to astructured_predict dropped the texts parameter entirely, causing summary prompts (participants, subjects, action items) to be sent without the transcript content. Combine texts with the prompt before calling astructured_predict, mirroring what TreeSummarize did. * fix: reduce TopicResponse min_length from 10 to 8 for title and summary * ci: try fixing spawning job in github * ci: fix for new arm64 builder
This commit is contained in:
@@ -1,42 +1,23 @@
|
||||
import logging
|
||||
from contextvars import ContextVar
|
||||
from typing import Generic, Type, TypeVar
|
||||
from typing import Type, TypeVar
|
||||
from uuid import uuid4
|
||||
|
||||
from llama_index.core import Settings
|
||||
from llama_index.core.output_parsers import PydanticOutputParser
|
||||
from llama_index.core.prompts import PromptTemplate
|
||||
from llama_index.core.response_synthesizers import TreeSummarize
|
||||
from llama_index.core.workflow import (
|
||||
Context,
|
||||
Event,
|
||||
StartEvent,
|
||||
StopEvent,
|
||||
Workflow,
|
||||
step,
|
||||
)
|
||||
from llama_index.llms.openai_like import OpenAILike
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from workflows.errors import WorkflowTimeoutError
|
||||
|
||||
from reflector.utils.retry import retry
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
OutputT = TypeVar("OutputT", bound=BaseModel)
|
||||
|
||||
# Session ID for LiteLLM request grouping - set per processing run
|
||||
llm_session_id: ContextVar[str | None] = ContextVar("llm_session_id", default=None)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
STRUCTURED_RESPONSE_PROMPT_TEMPLATE = """
|
||||
Based on the following analysis, provide the information in the requested JSON format:
|
||||
|
||||
Analysis:
|
||||
{analysis}
|
||||
|
||||
{format_instructions}
|
||||
"""
|
||||
|
||||
|
||||
class LLMParseError(Exception):
|
||||
"""Raised when LLM output cannot be parsed after retries."""
|
||||
@@ -50,157 +31,6 @@ class LLMParseError(Exception):
|
||||
)
|
||||
|
||||
|
||||
class ExtractionDone(Event):
|
||||
"""Event emitted when LLM JSON formatting completes."""
|
||||
|
||||
output: str
|
||||
|
||||
|
||||
class ValidationErrorEvent(Event):
|
||||
"""Event emitted when validation fails."""
|
||||
|
||||
error: str
|
||||
wrong_output: str
|
||||
|
||||
|
||||
class StructuredOutputWorkflow(Workflow, Generic[OutputT]):
|
||||
"""Workflow for structured output extraction with validation retry.
|
||||
|
||||
This workflow handles parse/validation retries only. Network error retries
|
||||
are handled internally by Settings.llm (OpenAILike max_retries=3).
|
||||
The caller should NOT wrap this workflow in additional retry logic.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
output_cls: Type[OutputT],
|
||||
max_retries: int = 3,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.output_cls: Type[OutputT] = output_cls
|
||||
self.max_retries = max_retries
|
||||
self.output_parser = PydanticOutputParser(output_cls)
|
||||
|
||||
@step
|
||||
async def extract(
|
||||
self, ctx: Context, ev: StartEvent | ValidationErrorEvent
|
||||
) -> StopEvent | ExtractionDone:
|
||||
"""Extract structured data from text using two-step LLM process.
|
||||
|
||||
Step 1 (first call only): TreeSummarize generates text analysis
|
||||
Step 2 (every call): Settings.llm.acomplete formats analysis as JSON
|
||||
"""
|
||||
current_retries = await ctx.store.get("retries", default=0)
|
||||
await ctx.store.set("retries", current_retries + 1)
|
||||
|
||||
if current_retries >= self.max_retries:
|
||||
last_error = await ctx.store.get("last_error", default=None)
|
||||
logger.error(
|
||||
f"Max retries ({self.max_retries}) reached for {self.output_cls.__name__}"
|
||||
)
|
||||
return StopEvent(result={"error": last_error, "attempts": current_retries})
|
||||
|
||||
if isinstance(ev, StartEvent):
|
||||
# First call: run TreeSummarize to get analysis, store in context
|
||||
prompt = ev.get("prompt")
|
||||
texts = ev.get("texts")
|
||||
tone_name = ev.get("tone_name")
|
||||
if not prompt or not isinstance(texts, list):
|
||||
raise ValueError(
|
||||
"StartEvent must contain 'prompt' (str) and 'texts' (list)"
|
||||
)
|
||||
|
||||
summarizer = TreeSummarize(verbose=False)
|
||||
analysis = await summarizer.aget_response(
|
||||
prompt, texts, tone_name=tone_name
|
||||
)
|
||||
await ctx.store.set("analysis", str(analysis))
|
||||
reflection = ""
|
||||
else:
|
||||
# Retry: reuse analysis from context
|
||||
analysis = await ctx.store.get("analysis")
|
||||
if not analysis:
|
||||
raise RuntimeError("Internal error: analysis not found in context")
|
||||
|
||||
wrong_output = ev.wrong_output
|
||||
if len(wrong_output) > 2000:
|
||||
wrong_output = wrong_output[:2000] + "... [truncated]"
|
||||
reflection = (
|
||||
f"\n\nYour previous response could not be parsed:\n{wrong_output}\n\n"
|
||||
f"Error:\n{ev.error}\n\n"
|
||||
"Please try again. Return ONLY valid JSON matching the schema above, "
|
||||
"with no markdown formatting or extra text."
|
||||
)
|
||||
|
||||
# Step 2: Format analysis as JSON using LLM completion
|
||||
format_instructions = self.output_parser.format(
|
||||
"Please structure the above information in the following JSON format:"
|
||||
)
|
||||
|
||||
json_prompt = STRUCTURED_RESPONSE_PROMPT_TEMPLATE.format(
|
||||
analysis=analysis,
|
||||
format_instructions=format_instructions + reflection,
|
||||
)
|
||||
|
||||
# Network retries handled by OpenAILike (max_retries=3)
|
||||
# response_format enables grammar-based constrained decoding on backends
|
||||
# that support it (DMR/llama.cpp, vLLM, Ollama, OpenAI).
|
||||
response = await Settings.llm.acomplete(
|
||||
json_prompt,
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": self.output_cls.__name__,
|
||||
"schema": self.output_cls.model_json_schema(),
|
||||
},
|
||||
},
|
||||
)
|
||||
return ExtractionDone(output=response.text)
|
||||
|
||||
@step
|
||||
async def validate(
|
||||
self, ctx: Context, ev: ExtractionDone
|
||||
) -> StopEvent | ValidationErrorEvent:
|
||||
"""Validate extracted output against Pydantic schema."""
|
||||
raw_output = ev.output
|
||||
retries = await ctx.store.get("retries", default=0)
|
||||
|
||||
try:
|
||||
parsed = self.output_parser.parse(raw_output)
|
||||
if retries > 1:
|
||||
logger.info(
|
||||
f"LLM parse succeeded on attempt {retries}/{self.max_retries} "
|
||||
f"for {self.output_cls.__name__}"
|
||||
)
|
||||
return StopEvent(result={"success": parsed})
|
||||
|
||||
except (ValidationError, ValueError) as e:
|
||||
error_msg = self._format_error(e, raw_output)
|
||||
await ctx.store.set("last_error", error_msg)
|
||||
|
||||
logger.error(
|
||||
f"LLM parse error (attempt {retries}/{self.max_retries}): "
|
||||
f"{type(e).__name__}: {e}\nRaw response: {raw_output[:500]}"
|
||||
)
|
||||
|
||||
return ValidationErrorEvent(
|
||||
error=error_msg,
|
||||
wrong_output=raw_output,
|
||||
)
|
||||
|
||||
def _format_error(self, error: Exception, raw_output: str) -> str:
|
||||
"""Format error for LLM feedback."""
|
||||
if isinstance(error, ValidationError):
|
||||
error_messages = []
|
||||
for err in error.errors():
|
||||
field = ".".join(str(loc) for loc in err["loc"])
|
||||
error_messages.append(f"- {err['msg']} in field '{field}'")
|
||||
return "Schema validation errors:\n" + "\n".join(error_messages)
|
||||
else:
|
||||
return f"Parse error: {str(error)}"
|
||||
|
||||
|
||||
class LLM:
|
||||
def __init__(
|
||||
self, settings, temperature: float = 0.4, max_tokens: int | None = None
|
||||
@@ -225,7 +55,7 @@ class LLM:
|
||||
api_key=self.api_key,
|
||||
context_window=self.context_window,
|
||||
is_chat_model=True,
|
||||
is_function_calling_model=False,
|
||||
is_function_calling_model=True,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
timeout=self.settings_obj.LLM_REQUEST_TIMEOUT,
|
||||
@@ -248,36 +78,91 @@ class LLM:
|
||||
tone_name: str | None = None,
|
||||
timeout: int | None = None,
|
||||
) -> T:
|
||||
"""Get structured output from LLM with validation retry via Workflow."""
|
||||
if timeout is None:
|
||||
timeout = self.settings_obj.LLM_STRUCTURED_RESPONSE_TIMEOUT
|
||||
"""Get structured output from LLM using tool-call with reflection retry.
|
||||
|
||||
async def run_workflow():
|
||||
workflow = StructuredOutputWorkflow(
|
||||
Uses astructured_predict (function-calling / tool-call mode) for the
|
||||
first attempt. On ValidationError or parse failure the wrong output
|
||||
and error are fed back as a reflection prompt and the call is retried
|
||||
up to LLM_PARSE_MAX_RETRIES times.
|
||||
|
||||
The outer retry() wrapper handles transient network errors with
|
||||
exponential back-off.
|
||||
"""
|
||||
max_retries = self.settings_obj.LLM_PARSE_MAX_RETRIES
|
||||
|
||||
async def _call_with_reflection():
|
||||
# Build full prompt: instruction + source texts
|
||||
if texts:
|
||||
texts_block = "\n\n".join(texts)
|
||||
full_prompt = f"{prompt}\n\n{texts_block}"
|
||||
else:
|
||||
full_prompt = prompt
|
||||
|
||||
prompt_tmpl = PromptTemplate("{user_prompt}")
|
||||
last_error: str | None = None
|
||||
|
||||
for attempt in range(1, max_retries + 2): # +2: first try + retries
|
||||
try:
|
||||
if attempt == 1:
|
||||
result = await Settings.llm.astructured_predict(
|
||||
output_cls, prompt_tmpl, user_prompt=full_prompt
|
||||
)
|
||||
else:
|
||||
reflection_tmpl = PromptTemplate(
|
||||
"{user_prompt}\n\n{reflection}"
|
||||
)
|
||||
result = await Settings.llm.astructured_predict(
|
||||
output_cls,
|
||||
reflection_tmpl,
|
||||
user_prompt=full_prompt,
|
||||
reflection=reflection,
|
||||
)
|
||||
|
||||
if attempt > 1:
|
||||
logger.info(
|
||||
f"LLM structured_predict succeeded on attempt "
|
||||
f"{attempt}/{max_retries + 1} for {output_cls.__name__}"
|
||||
)
|
||||
return result
|
||||
|
||||
except (ValidationError, ValueError) as e:
|
||||
wrong_output = str(e)
|
||||
if len(wrong_output) > 2000:
|
||||
wrong_output = wrong_output[:2000] + "... [truncated]"
|
||||
|
||||
last_error = self._format_validation_error(e)
|
||||
reflection = (
|
||||
f"Your previous response could not be parsed.\n\n"
|
||||
f"Error:\n{last_error}\n\n"
|
||||
"Please try again and return valid data matching the schema."
|
||||
)
|
||||
|
||||
logger.error(
|
||||
f"LLM parse error (attempt {attempt}/{max_retries + 1}): "
|
||||
f"{type(e).__name__}: {e}\n"
|
||||
f"Raw response: {wrong_output[:500]}"
|
||||
)
|
||||
|
||||
raise LLMParseError(
|
||||
output_cls=output_cls,
|
||||
max_retries=self.settings_obj.LLM_PARSE_MAX_RETRIES + 1,
|
||||
timeout=timeout,
|
||||
error_msg=last_error or "Max retries exceeded",
|
||||
attempts=max_retries + 1,
|
||||
)
|
||||
|
||||
result = await workflow.run(
|
||||
prompt=prompt,
|
||||
texts=texts,
|
||||
tone_name=tone_name,
|
||||
)
|
||||
|
||||
if "error" in result:
|
||||
error_msg = result["error"] or "Max retries exceeded"
|
||||
raise LLMParseError(
|
||||
output_cls=output_cls,
|
||||
error_msg=error_msg,
|
||||
attempts=result.get("attempts", 0),
|
||||
)
|
||||
|
||||
return result["success"]
|
||||
|
||||
return await retry(run_workflow)(
|
||||
return await retry(_call_with_reflection)(
|
||||
retry_attempts=3,
|
||||
retry_backoff_interval=1.0,
|
||||
retry_backoff_max=30.0,
|
||||
retry_ignore_exc_types=(WorkflowTimeoutError,),
|
||||
retry_ignore_exc_types=(ConnectionError, TimeoutError, OSError),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _format_validation_error(error: Exception) -> str:
|
||||
"""Format a validation/parse error for LLM reflection feedback."""
|
||||
if isinstance(error, ValidationError):
|
||||
error_messages = []
|
||||
for err in error.errors():
|
||||
field = ".".join(str(loc) for loc in err["loc"])
|
||||
error_messages.append(f"- {err['msg']} in field '{field}'")
|
||||
return "Schema validation errors:\n" + "\n".join(error_messages)
|
||||
return f"Parse error: {str(error)}"
|
||||
|
||||
@@ -14,10 +14,12 @@ class TopicResponse(BaseModel):
|
||||
title: str = Field(
|
||||
description="A descriptive title for the topic being discussed",
|
||||
validation_alias=AliasChoices("title", "Title"),
|
||||
min_length=8,
|
||||
)
|
||||
summary: str = Field(
|
||||
description="A concise 1-2 sentence summary of the discussion",
|
||||
validation_alias=AliasChoices("summary", "Summary"),
|
||||
min_length=8,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user