fix: switch structured output to tool-call with reflection retry (#879)

* fix: switch structured output to tool-call with reflection retry

Replace the two-pass StructuredOutputWorkflow (TreeSummarize → acomplete)
with astructured_predict + reflection retry loop for structured LLM output.

- Enable function-calling mode (is_function_calling_model=True)
- Use astructured_predict with PromptTemplate for first attempt
- On ValidationError/parse failure, retry with reflection feedback
- Add min_length=10 to TopicResponse title/summary fields
- Remove dead StructuredOutputWorkflow class and its event types
- Rewrite tests to match new astructured_predict approach

* fix: include texts parameter in astructured_predict prompt

The switch to astructured_predict dropped the texts parameter entirely,
causing summary prompts (participants, subjects, action items) to be
sent without the transcript content. Combine texts with the prompt
before calling astructured_predict, mirroring what TreeSummarize did.

* fix: reduce TopicResponse min_length from 10 to 8 for title and summary

* ci: try fixing spawning job in github

* ci: fix for new arm64 builder
This commit is contained in:
2026-02-24 18:28:11 -06:00
committed by GitHub
parent 815e87056d
commit 5d547586ef
4 changed files with 345 additions and 547 deletions

View File

@@ -34,7 +34,7 @@ jobs:
uv run -m pytest -v tests
docker-amd64:
runs-on: linux-amd64
runs-on: [linux-amd64]
concurrency:
group: docker-amd64-${{ github.ref }}
cancel-in-progress: true
@@ -52,12 +52,14 @@ jobs:
github-token: ${{ secrets.GHA_CACHE_TOKEN }}
docker-arm64:
runs-on: linux-arm64
runs-on: [linux-arm64]
concurrency:
group: docker-arm64-${{ github.ref }}
cancel-in-progress: true
steps:
- uses: actions/checkout@v4
- name: Wait for Docker daemon
run: while ! docker version; do sleep 1; done
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build ARM64

View File

@@ -1,42 +1,23 @@
import logging
from contextvars import ContextVar
from typing import Generic, Type, TypeVar
from typing import Type, TypeVar
from uuid import uuid4
from llama_index.core import Settings
from llama_index.core.output_parsers import PydanticOutputParser
from llama_index.core.prompts import PromptTemplate
from llama_index.core.response_synthesizers import TreeSummarize
from llama_index.core.workflow import (
Context,
Event,
StartEvent,
StopEvent,
Workflow,
step,
)
from llama_index.llms.openai_like import OpenAILike
from pydantic import BaseModel, ValidationError
from workflows.errors import WorkflowTimeoutError
from reflector.utils.retry import retry
T = TypeVar("T", bound=BaseModel)
OutputT = TypeVar("OutputT", bound=BaseModel)
# Session ID for LiteLLM request grouping - set per processing run
llm_session_id: ContextVar[str | None] = ContextVar("llm_session_id", default=None)
logger = logging.getLogger(__name__)
STRUCTURED_RESPONSE_PROMPT_TEMPLATE = """
Based on the following analysis, provide the information in the requested JSON format:
Analysis:
{analysis}
{format_instructions}
"""
class LLMParseError(Exception):
"""Raised when LLM output cannot be parsed after retries."""
@@ -50,157 +31,6 @@ class LLMParseError(Exception):
)
class ExtractionDone(Event):
"""Event emitted when LLM JSON formatting completes."""
output: str
class ValidationErrorEvent(Event):
"""Event emitted when validation fails."""
error: str
wrong_output: str
class StructuredOutputWorkflow(Workflow, Generic[OutputT]):
"""Workflow for structured output extraction with validation retry.
This workflow handles parse/validation retries only. Network error retries
are handled internally by Settings.llm (OpenAILike max_retries=3).
The caller should NOT wrap this workflow in additional retry logic.
"""
def __init__(
self,
output_cls: Type[OutputT],
max_retries: int = 3,
**kwargs,
):
super().__init__(**kwargs)
self.output_cls: Type[OutputT] = output_cls
self.max_retries = max_retries
self.output_parser = PydanticOutputParser(output_cls)
@step
async def extract(
self, ctx: Context, ev: StartEvent | ValidationErrorEvent
) -> StopEvent | ExtractionDone:
"""Extract structured data from text using two-step LLM process.
Step 1 (first call only): TreeSummarize generates text analysis
Step 2 (every call): Settings.llm.acomplete formats analysis as JSON
"""
current_retries = await ctx.store.get("retries", default=0)
await ctx.store.set("retries", current_retries + 1)
if current_retries >= self.max_retries:
last_error = await ctx.store.get("last_error", default=None)
logger.error(
f"Max retries ({self.max_retries}) reached for {self.output_cls.__name__}"
)
return StopEvent(result={"error": last_error, "attempts": current_retries})
if isinstance(ev, StartEvent):
# First call: run TreeSummarize to get analysis, store in context
prompt = ev.get("prompt")
texts = ev.get("texts")
tone_name = ev.get("tone_name")
if not prompt or not isinstance(texts, list):
raise ValueError(
"StartEvent must contain 'prompt' (str) and 'texts' (list)"
)
summarizer = TreeSummarize(verbose=False)
analysis = await summarizer.aget_response(
prompt, texts, tone_name=tone_name
)
await ctx.store.set("analysis", str(analysis))
reflection = ""
else:
# Retry: reuse analysis from context
analysis = await ctx.store.get("analysis")
if not analysis:
raise RuntimeError("Internal error: analysis not found in context")
wrong_output = ev.wrong_output
if len(wrong_output) > 2000:
wrong_output = wrong_output[:2000] + "... [truncated]"
reflection = (
f"\n\nYour previous response could not be parsed:\n{wrong_output}\n\n"
f"Error:\n{ev.error}\n\n"
"Please try again. Return ONLY valid JSON matching the schema above, "
"with no markdown formatting or extra text."
)
# Step 2: Format analysis as JSON using LLM completion
format_instructions = self.output_parser.format(
"Please structure the above information in the following JSON format:"
)
json_prompt = STRUCTURED_RESPONSE_PROMPT_TEMPLATE.format(
analysis=analysis,
format_instructions=format_instructions + reflection,
)
# Network retries handled by OpenAILike (max_retries=3)
# response_format enables grammar-based constrained decoding on backends
# that support it (DMR/llama.cpp, vLLM, Ollama, OpenAI).
response = await Settings.llm.acomplete(
json_prompt,
response_format={
"type": "json_schema",
"json_schema": {
"name": self.output_cls.__name__,
"schema": self.output_cls.model_json_schema(),
},
},
)
return ExtractionDone(output=response.text)
@step
async def validate(
self, ctx: Context, ev: ExtractionDone
) -> StopEvent | ValidationErrorEvent:
"""Validate extracted output against Pydantic schema."""
raw_output = ev.output
retries = await ctx.store.get("retries", default=0)
try:
parsed = self.output_parser.parse(raw_output)
if retries > 1:
logger.info(
f"LLM parse succeeded on attempt {retries}/{self.max_retries} "
f"for {self.output_cls.__name__}"
)
return StopEvent(result={"success": parsed})
except (ValidationError, ValueError) as e:
error_msg = self._format_error(e, raw_output)
await ctx.store.set("last_error", error_msg)
logger.error(
f"LLM parse error (attempt {retries}/{self.max_retries}): "
f"{type(e).__name__}: {e}\nRaw response: {raw_output[:500]}"
)
return ValidationErrorEvent(
error=error_msg,
wrong_output=raw_output,
)
def _format_error(self, error: Exception, raw_output: str) -> str:
"""Format error for LLM feedback."""
if isinstance(error, ValidationError):
error_messages = []
for err in error.errors():
field = ".".join(str(loc) for loc in err["loc"])
error_messages.append(f"- {err['msg']} in field '{field}'")
return "Schema validation errors:\n" + "\n".join(error_messages)
else:
return f"Parse error: {str(error)}"
class LLM:
def __init__(
self, settings, temperature: float = 0.4, max_tokens: int | None = None
@@ -225,7 +55,7 @@ class LLM:
api_key=self.api_key,
context_window=self.context_window,
is_chat_model=True,
is_function_calling_model=False,
is_function_calling_model=True,
temperature=self.temperature,
max_tokens=self.max_tokens,
timeout=self.settings_obj.LLM_REQUEST_TIMEOUT,
@@ -248,36 +78,91 @@ class LLM:
tone_name: str | None = None,
timeout: int | None = None,
) -> T:
"""Get structured output from LLM with validation retry via Workflow."""
if timeout is None:
timeout = self.settings_obj.LLM_STRUCTURED_RESPONSE_TIMEOUT
"""Get structured output from LLM using tool-call with reflection retry.
async def run_workflow():
workflow = StructuredOutputWorkflow(
Uses astructured_predict (function-calling / tool-call mode) for the
first attempt. On ValidationError or parse failure the wrong output
and error are fed back as a reflection prompt and the call is retried
up to LLM_PARSE_MAX_RETRIES times.
The outer retry() wrapper handles transient network errors with
exponential back-off.
"""
max_retries = self.settings_obj.LLM_PARSE_MAX_RETRIES
async def _call_with_reflection():
# Build full prompt: instruction + source texts
if texts:
texts_block = "\n\n".join(texts)
full_prompt = f"{prompt}\n\n{texts_block}"
else:
full_prompt = prompt
prompt_tmpl = PromptTemplate("{user_prompt}")
last_error: str | None = None
for attempt in range(1, max_retries + 2): # +2: first try + retries
try:
if attempt == 1:
result = await Settings.llm.astructured_predict(
output_cls, prompt_tmpl, user_prompt=full_prompt
)
else:
reflection_tmpl = PromptTemplate(
"{user_prompt}\n\n{reflection}"
)
result = await Settings.llm.astructured_predict(
output_cls,
reflection_tmpl,
user_prompt=full_prompt,
reflection=reflection,
)
if attempt > 1:
logger.info(
f"LLM structured_predict succeeded on attempt "
f"{attempt}/{max_retries + 1} for {output_cls.__name__}"
)
return result
except (ValidationError, ValueError) as e:
wrong_output = str(e)
if len(wrong_output) > 2000:
wrong_output = wrong_output[:2000] + "... [truncated]"
last_error = self._format_validation_error(e)
reflection = (
f"Your previous response could not be parsed.\n\n"
f"Error:\n{last_error}\n\n"
"Please try again and return valid data matching the schema."
)
logger.error(
f"LLM parse error (attempt {attempt}/{max_retries + 1}): "
f"{type(e).__name__}: {e}\n"
f"Raw response: {wrong_output[:500]}"
)
raise LLMParseError(
output_cls=output_cls,
max_retries=self.settings_obj.LLM_PARSE_MAX_RETRIES + 1,
timeout=timeout,
error_msg=last_error or "Max retries exceeded",
attempts=max_retries + 1,
)
result = await workflow.run(
prompt=prompt,
texts=texts,
tone_name=tone_name,
)
if "error" in result:
error_msg = result["error"] or "Max retries exceeded"
raise LLMParseError(
output_cls=output_cls,
error_msg=error_msg,
attempts=result.get("attempts", 0),
)
return result["success"]
return await retry(run_workflow)(
return await retry(_call_with_reflection)(
retry_attempts=3,
retry_backoff_interval=1.0,
retry_backoff_max=30.0,
retry_ignore_exc_types=(WorkflowTimeoutError,),
retry_ignore_exc_types=(ConnectionError, TimeoutError, OSError),
)
@staticmethod
def _format_validation_error(error: Exception) -> str:
"""Format a validation/parse error for LLM reflection feedback."""
if isinstance(error, ValidationError):
error_messages = []
for err in error.errors():
field = ".".join(str(loc) for loc in err["loc"])
error_messages.append(f"- {err['msg']} in field '{field}'")
return "Schema validation errors:\n" + "\n".join(error_messages)
return f"Parse error: {str(error)}"

View File

@@ -14,10 +14,12 @@ class TopicResponse(BaseModel):
title: str = Field(
description="A descriptive title for the topic being discussed",
validation_alias=AliasChoices("title", "Title"),
min_length=8,
)
summary: str = Field(
description="A concise 1-2 sentence summary of the discussion",
validation_alias=AliasChoices("summary", "Summary"),
min_length=8,
)

View File

@@ -1,13 +1,11 @@
"""Tests for LLM parse error recovery using llama-index Workflow"""
"""Tests for LLM structured output with astructured_predict + reflection retry"""
from time import monotonic
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, patch
import pytest
from pydantic import BaseModel, Field
from workflows.errors import WorkflowRuntimeError, WorkflowTimeoutError
from pydantic import BaseModel, Field, ValidationError
from reflector.llm import LLM, LLMParseError, StructuredOutputWorkflow
from reflector.llm import LLM, LLMParseError
from reflector.utils.retry import RetryException
@@ -19,51 +17,43 @@ class TestResponse(BaseModel):
confidence: float = Field(description="Confidence score", ge=0, le=1)
def make_completion_response(text: str):
"""Create a mock CompletionResponse with .text attribute"""
response = MagicMock()
response.text = text
return response
class TestLLMParseErrorRecovery:
"""Test parse error recovery with Workflow feedback loop"""
"""Test parse error recovery with astructured_predict reflection loop"""
@pytest.mark.asyncio
async def test_parse_error_recovery_with_feedback(self, test_settings):
"""Test that parse errors trigger retry with error feedback"""
"""Test that parse errors trigger retry with reflection prompt"""
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
with (
patch("reflector.llm.TreeSummarize") as mock_summarize,
patch("reflector.llm.Settings") as mock_settings,
):
mock_summarizer = MagicMock()
mock_summarize.return_value = mock_summarizer
# TreeSummarize returns plain text analysis (step 1)
mock_summarizer.aget_response = AsyncMock(
return_value="The analysis shows a test with summary and high confidence."
call_count = {"count": 0}
async def astructured_predict_handler(output_cls, prompt_tmpl, **kwargs):
call_count["count"] += 1
if call_count["count"] == 1:
# First call: raise ValidationError (missing fields)
raise ValidationError.from_exception_data(
title="TestResponse",
line_errors=[
{
"type": "missing",
"loc": ("summary",),
"msg": "Field required",
"input": {"title": "Test"},
}
],
)
else:
# Second call: should have reflection in the prompt
assert "reflection" in kwargs
assert "could not be parsed" in kwargs["reflection"]
assert "Error:" in kwargs["reflection"]
return TestResponse(title="Test", summary="Summary", confidence=0.95)
with patch("reflector.llm.Settings") as mock_settings:
mock_settings.llm.astructured_predict = AsyncMock(
side_effect=astructured_predict_handler
)
call_count = {"count": 0}
async def acomplete_handler(prompt, *args, **kwargs):
call_count["count"] += 1
if call_count["count"] == 1:
# First JSON formatting call returns invalid JSON
return make_completion_response('{"title": "Test"}')
else:
# Second call should have error feedback in prompt
assert "Your previous response could not be parsed:" in prompt
assert '{"title": "Test"}' in prompt
assert "Error:" in prompt
assert "Please try again" in prompt
return make_completion_response(
'{"title": "Test", "summary": "Summary", "confidence": 0.95}'
)
mock_settings.llm.acomplete = AsyncMock(side_effect=acomplete_handler)
result = await llm.get_structured_response(
prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
)
@@ -71,8 +61,6 @@ class TestLLMParseErrorRecovery:
assert result.title == "Test"
assert result.summary == "Summary"
assert result.confidence == 0.95
# TreeSummarize called once, Settings.llm.acomplete called twice
assert mock_summarizer.aget_response.call_count == 1
assert call_count["count"] == 2
@pytest.mark.asyncio
@@ -80,56 +68,61 @@ class TestLLMParseErrorRecovery:
"""Test that parse error retry stops after max attempts"""
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
with (
patch("reflector.llm.TreeSummarize") as mock_summarize,
patch("reflector.llm.Settings") as mock_settings,
):
mock_summarizer = MagicMock()
mock_summarize.return_value = mock_summarizer
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
# Always return invalid JSON from acomplete
mock_settings.llm.acomplete = AsyncMock(
return_value=make_completion_response(
'{"invalid": "missing required fields"}'
)
# Always raise ValidationError
async def always_fail(output_cls, prompt_tmpl, **kwargs):
raise ValidationError.from_exception_data(
title="TestResponse",
line_errors=[
{
"type": "missing",
"loc": ("summary",),
"msg": "Field required",
"input": {},
}
],
)
with patch("reflector.llm.Settings") as mock_settings:
mock_settings.llm.astructured_predict = AsyncMock(side_effect=always_fail)
with pytest.raises(LLMParseError, match="Failed to parse"):
await llm.get_structured_response(
prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
)
expected_attempts = test_settings.LLM_PARSE_MAX_RETRIES + 1
# TreeSummarize called once, acomplete called max_retries times
assert mock_summarizer.aget_response.call_count == 1
assert mock_settings.llm.acomplete.call_count == expected_attempts
assert mock_settings.llm.astructured_predict.call_count == expected_attempts
@pytest.mark.asyncio
async def test_raw_response_logging_on_parse_error(self, test_settings, caplog):
"""Test that raw response is logged when parse error occurs"""
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
call_count = {"count": 0}
async def astructured_predict_handler(output_cls, prompt_tmpl, **kwargs):
call_count["count"] += 1
if call_count["count"] == 1:
raise ValidationError.from_exception_data(
title="TestResponse",
line_errors=[
{
"type": "missing",
"loc": ("summary",),
"msg": "Field required",
"input": {"title": "Test"},
}
],
)
return TestResponse(title="Test", summary="Summary", confidence=0.95)
with (
patch("reflector.llm.TreeSummarize") as mock_summarize,
patch("reflector.llm.Settings") as mock_settings,
caplog.at_level("ERROR"),
):
mock_summarizer = MagicMock()
mock_summarize.return_value = mock_summarizer
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
call_count = {"count": 0}
async def acomplete_handler(*args, **kwargs):
call_count["count"] += 1
if call_count["count"] == 1:
return make_completion_response('{"title": "Test"}') # Invalid
return make_completion_response(
'{"title": "Test", "summary": "Summary", "confidence": 0.95}'
)
mock_settings.llm.acomplete = AsyncMock(side_effect=acomplete_handler)
mock_settings.llm.astructured_predict = AsyncMock(
side_effect=astructured_predict_handler
)
result = await llm.get_structured_response(
prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
@@ -143,35 +136,45 @@ class TestLLMParseErrorRecovery:
@pytest.mark.asyncio
async def test_multiple_validation_errors_in_feedback(self, test_settings):
"""Test that validation errors are included in feedback"""
"""Test that validation errors are included in reflection feedback"""
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
with (
patch("reflector.llm.TreeSummarize") as mock_summarize,
patch("reflector.llm.Settings") as mock_settings,
):
mock_summarizer = MagicMock()
mock_summarize.return_value = mock_summarizer
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
call_count = {"count": 0}
call_count = {"count": 0}
async def astructured_predict_handler(output_cls, prompt_tmpl, **kwargs):
call_count["count"] += 1
if call_count["count"] == 1:
# Missing title and summary
raise ValidationError.from_exception_data(
title="TestResponse",
line_errors=[
{
"type": "missing",
"loc": ("title",),
"msg": "Field required",
"input": {},
},
{
"type": "missing",
"loc": ("summary",),
"msg": "Field required",
"input": {},
},
],
)
else:
# Should have schema validation errors in reflection
assert "reflection" in kwargs
assert (
"Schema validation errors" in kwargs["reflection"]
or "error" in kwargs["reflection"].lower()
)
return TestResponse(title="Test", summary="Summary", confidence=0.95)
async def acomplete_handler(prompt, *args, **kwargs):
call_count["count"] += 1
if call_count["count"] == 1:
# Missing title and summary
return make_completion_response('{"confidence": 0.5}')
else:
# Should have schema validation errors in prompt
assert (
"Schema validation errors" in prompt
or "error" in prompt.lower()
)
return make_completion_response(
'{"title": "Test", "summary": "Summary", "confidence": 0.95}'
)
mock_settings.llm.acomplete = AsyncMock(side_effect=acomplete_handler)
with patch("reflector.llm.Settings") as mock_settings:
mock_settings.llm.astructured_predict = AsyncMock(
side_effect=astructured_predict_handler
)
result = await llm.get_structured_response(
prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
@@ -185,17 +188,10 @@ class TestLLMParseErrorRecovery:
"""Test that no retry happens when first attempt succeeds"""
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
with (
patch("reflector.llm.TreeSummarize") as mock_summarize,
patch("reflector.llm.Settings") as mock_settings,
):
mock_summarizer = MagicMock()
mock_summarize.return_value = mock_summarizer
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
mock_settings.llm.acomplete = AsyncMock(
return_value=make_completion_response(
'{"title": "Test", "summary": "Summary", "confidence": 0.95}'
with patch("reflector.llm.Settings") as mock_settings:
mock_settings.llm.astructured_predict = AsyncMock(
return_value=TestResponse(
title="Test", summary="Summary", confidence=0.95
)
)
@@ -206,195 +202,28 @@ class TestLLMParseErrorRecovery:
assert result.title == "Test"
assert result.summary == "Summary"
assert result.confidence == 0.95
assert mock_summarizer.aget_response.call_count == 1
assert mock_settings.llm.acomplete.call_count == 1
class TestStructuredOutputWorkflow:
"""Direct tests for the StructuredOutputWorkflow"""
@pytest.mark.asyncio
async def test_workflow_retries_on_validation_error(self):
"""Test workflow retries when validation fails"""
workflow = StructuredOutputWorkflow(
output_cls=TestResponse,
max_retries=3,
timeout=30,
)
with (
patch("reflector.llm.TreeSummarize") as mock_summarize,
patch("reflector.llm.Settings") as mock_settings,
):
mock_summarizer = MagicMock()
mock_summarize.return_value = mock_summarizer
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
call_count = {"count": 0}
async def acomplete_handler(*args, **kwargs):
call_count["count"] += 1
if call_count["count"] < 2:
return make_completion_response('{"title": "Only title"}')
return make_completion_response(
'{"title": "Test", "summary": "Summary", "confidence": 0.9}'
)
mock_settings.llm.acomplete = AsyncMock(side_effect=acomplete_handler)
result = await workflow.run(
prompt="Extract data",
texts=["Some text"],
tone_name=None,
)
assert "success" in result
assert result["success"].title == "Test"
assert call_count["count"] == 2
@pytest.mark.asyncio
async def test_workflow_returns_error_after_max_retries(self):
"""Test workflow returns error after exhausting retries"""
workflow = StructuredOutputWorkflow(
output_cls=TestResponse,
max_retries=2,
timeout=30,
)
with (
patch("reflector.llm.TreeSummarize") as mock_summarize,
patch("reflector.llm.Settings") as mock_settings,
):
mock_summarizer = MagicMock()
mock_summarize.return_value = mock_summarizer
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
# Always return invalid JSON
mock_settings.llm.acomplete = AsyncMock(
return_value=make_completion_response('{"invalid": true}')
)
result = await workflow.run(
prompt="Extract data",
texts=["Some text"],
tone_name=None,
)
assert "error" in result
# TreeSummarize called once, acomplete called max_retries times
assert mock_summarizer.aget_response.call_count == 1
assert mock_settings.llm.acomplete.call_count == 2
assert mock_settings.llm.astructured_predict.call_count == 1
class TestNetworkErrorRetries:
"""Test that network error retries are handled by OpenAILike, not Workflow"""
"""Test that network errors are retried by the outer retry() wrapper"""
@pytest.mark.asyncio
async def test_network_error_propagates_after_openai_retries(self, test_settings):
"""Test that network errors are retried by OpenAILike and then propagate.
Network retries are handled by OpenAILike (max_retries=3), not by our
StructuredOutputWorkflow. This test verifies that network errors propagate
up after OpenAILike exhausts its retries.
"""
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
with (
patch("reflector.llm.TreeSummarize") as mock_summarize,
patch("reflector.llm.Settings") as mock_settings,
):
mock_summarizer = MagicMock()
mock_summarize.return_value = mock_summarizer
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
# Simulate network error from acomplete (after OpenAILike retries exhausted)
network_error = ConnectionError("Connection refused")
mock_settings.llm.acomplete = AsyncMock(side_effect=network_error)
# Network error wrapped in WorkflowRuntimeError
with pytest.raises(WorkflowRuntimeError, match="Connection refused"):
await llm.get_structured_response(
prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
)
# acomplete called only once - network error propagates, not retried by Workflow
assert mock_settings.llm.acomplete.call_count == 1
@pytest.mark.asyncio
async def test_network_error_not_retried_by_workflow(self, test_settings):
"""Test that Workflow does NOT retry network errors (OpenAILike handles those).
This verifies the separation of concerns:
- StructuredOutputWorkflow: retries parse/validation errors
- OpenAILike: retries network errors (internally, max_retries=3)
"""
workflow = StructuredOutputWorkflow(
output_cls=TestResponse,
max_retries=3,
timeout=30,
)
with (
patch("reflector.llm.TreeSummarize") as mock_summarize,
patch("reflector.llm.Settings") as mock_settings,
):
mock_summarizer = MagicMock()
mock_summarize.return_value = mock_summarizer
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
# Network error should propagate immediately, not trigger Workflow retry
mock_settings.llm.acomplete = AsyncMock(
side_effect=TimeoutError("Request timed out")
)
# Network error wrapped in WorkflowRuntimeError
with pytest.raises(WorkflowRuntimeError, match="Request timed out"):
await workflow.run(
prompt="Extract data",
texts=["Some text"],
tone_name=None,
)
# Only called once - Workflow doesn't retry network errors
assert mock_settings.llm.acomplete.call_count == 1
class TestWorkflowTimeoutRetry:
"""Test timeout retry mechanism in get_structured_response"""
@pytest.mark.asyncio
async def test_timeout_retry_succeeds_on_retry(self, test_settings):
"""Test that WorkflowTimeoutError triggers retry and succeeds"""
async def test_network_error_retried_by_outer_wrapper(self, test_settings):
"""Test that network errors trigger the outer retry wrapper"""
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
call_count = {"count": 0}
async def workflow_run_side_effect(*args, **kwargs):
async def astructured_predict_handler(output_cls, prompt_tmpl, **kwargs):
call_count["count"] += 1
if call_count["count"] == 1:
raise WorkflowTimeoutError("Operation timed out after 120 seconds")
return {
"success": TestResponse(
title="Test", summary="Summary", confidence=0.95
)
}
raise ConnectionError("Connection refused")
return TestResponse(title="Test", summary="Summary", confidence=0.95)
with (
patch("reflector.llm.StructuredOutputWorkflow") as mock_workflow_class,
patch("reflector.llm.TreeSummarize") as mock_summarize,
patch("reflector.llm.Settings") as mock_settings,
):
mock_workflow = MagicMock()
mock_workflow.run = AsyncMock(side_effect=workflow_run_side_effect)
mock_workflow_class.return_value = mock_workflow
mock_summarizer = MagicMock()
mock_summarize.return_value = mock_summarizer
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
mock_settings.llm.acomplete = AsyncMock(
return_value=make_completion_response(
'{"title": "Test", "summary": "Summary", "confidence": 0.95}'
)
with patch("reflector.llm.Settings") as mock_settings:
mock_settings.llm.astructured_predict = AsyncMock(
side_effect=astructured_predict_handler
)
result = await llm.get_structured_response(
@@ -402,36 +231,16 @@ class TestWorkflowTimeoutRetry:
)
assert result.title == "Test"
assert result.summary == "Summary"
assert call_count["count"] == 2
@pytest.mark.asyncio
async def test_timeout_retry_exhausts_after_max_attempts(self, test_settings):
"""Test that timeout retry stops after max attempts"""
async def test_network_error_exhausts_retries(self, test_settings):
"""Test that persistent network errors exhaust retry attempts"""
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
call_count = {"count": 0}
async def workflow_run_side_effect(*args, **kwargs):
call_count["count"] += 1
raise WorkflowTimeoutError("Operation timed out after 120 seconds")
with (
patch("reflector.llm.StructuredOutputWorkflow") as mock_workflow_class,
patch("reflector.llm.TreeSummarize") as mock_summarize,
patch("reflector.llm.Settings") as mock_settings,
):
mock_workflow = MagicMock()
mock_workflow.run = AsyncMock(side_effect=workflow_run_side_effect)
mock_workflow_class.return_value = mock_workflow
mock_summarizer = MagicMock()
mock_summarize.return_value = mock_summarizer
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
mock_settings.llm.acomplete = AsyncMock(
return_value=make_completion_response(
'{"title": "Test", "summary": "Summary", "confidence": 0.95}'
)
with patch("reflector.llm.Settings") as mock_settings:
mock_settings.llm.astructured_predict = AsyncMock(
side_effect=ConnectionError("Connection refused")
)
with pytest.raises(RetryException, match="Retry attempts exceeded"):
@@ -439,41 +248,129 @@ class TestWorkflowTimeoutRetry:
prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
)
assert call_count["count"] == 3
# 3 retry attempts
assert mock_settings.llm.astructured_predict.call_count == 3
class TestTextsInclusion:
"""Test that texts parameter is included in the prompt sent to astructured_predict"""
@pytest.mark.asyncio
async def test_timeout_retry_with_backoff(self, test_settings):
"""Test that exponential backoff is applied between retries"""
async def test_texts_included_in_prompt(self, test_settings):
"""Test that texts content is appended to the prompt for astructured_predict"""
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
call_times = []
captured_prompts = []
async def workflow_run_side_effect(*args, **kwargs):
call_times.append(monotonic())
if len(call_times) < 3:
raise WorkflowTimeoutError("Operation timed out after 120 seconds")
return {
"success": TestResponse(
title="Test", summary="Summary", confidence=0.95
async def capture_prompt(output_cls, prompt_tmpl, **kwargs):
captured_prompts.append(kwargs.get("user_prompt", ""))
return TestResponse(title="Test", summary="Summary", confidence=0.95)
with patch("reflector.llm.Settings") as mock_settings:
mock_settings.llm.astructured_predict = AsyncMock(
side_effect=capture_prompt
)
await llm.get_structured_response(
prompt="Identify all participants",
texts=["Alice: Hello everyone", "Bob: Hi Alice"],
output_cls=TestResponse,
)
assert len(captured_prompts) == 1
prompt_sent = captured_prompts[0]
assert "Identify all participants" in prompt_sent
assert "Alice: Hello everyone" in prompt_sent
assert "Bob: Hi Alice" in prompt_sent
@pytest.mark.asyncio
async def test_empty_texts_uses_prompt_only(self, test_settings):
"""Test that empty texts list sends only the prompt"""
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
captured_prompts = []
async def capture_prompt(output_cls, prompt_tmpl, **kwargs):
captured_prompts.append(kwargs.get("user_prompt", ""))
return TestResponse(title="Test", summary="Summary", confidence=0.95)
with patch("reflector.llm.Settings") as mock_settings:
mock_settings.llm.astructured_predict = AsyncMock(
side_effect=capture_prompt
)
await llm.get_structured_response(
prompt="Identify all participants",
texts=[],
output_cls=TestResponse,
)
assert len(captured_prompts) == 1
assert captured_prompts[0] == "Identify all participants"
@pytest.mark.asyncio
async def test_texts_included_in_reflection_retry(self, test_settings):
"""Test that texts are included in the prompt even during reflection retries"""
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
captured_prompts = []
call_count = {"count": 0}
async def capture_and_fail_first(output_cls, prompt_tmpl, **kwargs):
call_count["count"] += 1
captured_prompts.append(kwargs.get("user_prompt", ""))
if call_count["count"] == 1:
raise ValidationError.from_exception_data(
title="TestResponse",
line_errors=[
{
"type": "missing",
"loc": ("summary",),
"msg": "Field required",
"input": {},
}
],
)
}
return TestResponse(title="Test", summary="Summary", confidence=0.95)
with (
patch("reflector.llm.StructuredOutputWorkflow") as mock_workflow_class,
patch("reflector.llm.TreeSummarize") as mock_summarize,
patch("reflector.llm.Settings") as mock_settings,
):
mock_workflow = MagicMock()
mock_workflow.run = AsyncMock(side_effect=workflow_run_side_effect)
mock_workflow_class.return_value = mock_workflow
with patch("reflector.llm.Settings") as mock_settings:
mock_settings.llm.astructured_predict = AsyncMock(
side_effect=capture_and_fail_first
)
mock_summarizer = MagicMock()
mock_summarize.return_value = mock_summarizer
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
mock_settings.llm.acomplete = AsyncMock(
return_value=make_completion_response(
'{"title": "Test", "summary": "Summary", "confidence": 0.95}'
)
await llm.get_structured_response(
prompt="Summarize this",
texts=["The meeting covered project updates"],
output_cls=TestResponse,
)
# Both first attempt and reflection retry should include the texts
assert len(captured_prompts) == 2
for prompt_sent in captured_prompts:
assert "Summarize this" in prompt_sent
assert "The meeting covered project updates" in prompt_sent
class TestReflectionRetryBackoff:
"""Test the reflection retry timing behavior"""
@pytest.mark.asyncio
async def test_value_error_triggers_reflection(self, test_settings):
"""Test that ValueError (parse failure) also triggers reflection retry"""
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
call_count = {"count": 0}
async def astructured_predict_handler(output_cls, prompt_tmpl, **kwargs):
call_count["count"] += 1
if call_count["count"] == 1:
raise ValueError("Could not parse output")
assert "reflection" in kwargs
return TestResponse(title="Test", summary="Summary", confidence=0.95)
with patch("reflector.llm.Settings") as mock_settings:
mock_settings.llm.astructured_predict = AsyncMock(
side_effect=astructured_predict_handler
)
result = await llm.get_structured_response(
@@ -481,8 +378,20 @@ class TestWorkflowTimeoutRetry:
)
assert result.title == "Test"
if len(call_times) >= 2:
time_between_calls = call_times[1] - call_times[0]
assert (
time_between_calls >= 1.5
), f"Expected ~2s backoff, got {time_between_calls}s"
assert call_count["count"] == 2
@pytest.mark.asyncio
async def test_format_validation_error_method(self, test_settings):
"""Test _format_validation_error produces correct feedback"""
# ValidationError
try:
TestResponse(title="x", summary="y", confidence=5.0) # confidence > 1
except ValidationError as e:
result = LLM._format_validation_error(e)
assert "Schema validation errors" in result
assert "confidence" in result
# ValueError
result = LLM._format_validation_error(ValueError("bad input"))
assert "Parse error:" in result
assert "bad input" in result