mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-03-21 22:56:47 +00:00
fix: switch structured output to tool-call with reflection retry (#879)
* fix: switch structured output to tool-call with reflection retry

  Replace the two-pass StructuredOutputWorkflow (TreeSummarize → acomplete) with astructured_predict + reflection retry loop for structured LLM output.

  - Enable function-calling mode (is_function_calling_model=True)
  - Use astructured_predict with PromptTemplate for first attempt
  - On ValidationError/parse failure, retry with reflection feedback
  - Add min_length=10 to TopicResponse title/summary fields
  - Remove dead StructuredOutputWorkflow class and its event types
  - Rewrite tests to match new astructured_predict approach

* fix: include texts parameter in astructured_predict prompt

  The switch to astructured_predict dropped the texts parameter entirely, causing summary prompts (participants, subjects, action items) to be sent without the transcript content. Combine texts with the prompt before calling astructured_predict, mirroring what TreeSummarize did.

* fix: reduce TopicResponse min_length from 10 to 8 for title and summary

* ci: try fixing spawning job in github

* ci: fix for new arm64 builder
This commit is contained in:
6
.github/workflows/test_server.yml
vendored
6
.github/workflows/test_server.yml
vendored
@@ -34,7 +34,7 @@ jobs:
|
|||||||
uv run -m pytest -v tests
|
uv run -m pytest -v tests
|
||||||
|
|
||||||
docker-amd64:
|
docker-amd64:
|
||||||
runs-on: linux-amd64
|
runs-on: [linux-amd64]
|
||||||
concurrency:
|
concurrency:
|
||||||
group: docker-amd64-${{ github.ref }}
|
group: docker-amd64-${{ github.ref }}
|
||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
@@ -52,12 +52,14 @@ jobs:
|
|||||||
github-token: ${{ secrets.GHA_CACHE_TOKEN }}
|
github-token: ${{ secrets.GHA_CACHE_TOKEN }}
|
||||||
|
|
||||||
docker-arm64:
|
docker-arm64:
|
||||||
runs-on: linux-arm64
|
runs-on: [linux-arm64]
|
||||||
concurrency:
|
concurrency:
|
||||||
group: docker-arm64-${{ github.ref }}
|
group: docker-arm64-${{ github.ref }}
|
||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
- name: Wait for Docker daemon
|
||||||
|
run: while ! docker version; do sleep 1; done
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v3
|
uses: docker/setup-buildx-action@v3
|
||||||
- name: Build ARM64
|
- name: Build ARM64
|
||||||
|
|||||||
@@ -1,42 +1,23 @@
|
|||||||
import logging
|
import logging
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from typing import Generic, Type, TypeVar
|
from typing import Type, TypeVar
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from llama_index.core import Settings
|
from llama_index.core import Settings
|
||||||
from llama_index.core.output_parsers import PydanticOutputParser
|
from llama_index.core.prompts import PromptTemplate
|
||||||
from llama_index.core.response_synthesizers import TreeSummarize
|
from llama_index.core.response_synthesizers import TreeSummarize
|
||||||
from llama_index.core.workflow import (
|
|
||||||
Context,
|
|
||||||
Event,
|
|
||||||
StartEvent,
|
|
||||||
StopEvent,
|
|
||||||
Workflow,
|
|
||||||
step,
|
|
||||||
)
|
|
||||||
from llama_index.llms.openai_like import OpenAILike
|
from llama_index.llms.openai_like import OpenAILike
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import BaseModel, ValidationError
|
||||||
from workflows.errors import WorkflowTimeoutError
|
|
||||||
|
|
||||||
from reflector.utils.retry import retry
|
from reflector.utils.retry import retry
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
OutputT = TypeVar("OutputT", bound=BaseModel)
|
|
||||||
|
|
||||||
# Session ID for LiteLLM request grouping - set per processing run
|
# Session ID for LiteLLM request grouping - set per processing run
|
||||||
llm_session_id: ContextVar[str | None] = ContextVar("llm_session_id", default=None)
|
llm_session_id: ContextVar[str | None] = ContextVar("llm_session_id", default=None)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
STRUCTURED_RESPONSE_PROMPT_TEMPLATE = """
|
|
||||||
Based on the following analysis, provide the information in the requested JSON format:
|
|
||||||
|
|
||||||
Analysis:
|
|
||||||
{analysis}
|
|
||||||
|
|
||||||
{format_instructions}
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class LLMParseError(Exception):
|
class LLMParseError(Exception):
|
||||||
"""Raised when LLM output cannot be parsed after retries."""
|
"""Raised when LLM output cannot be parsed after retries."""
|
||||||
@@ -50,157 +31,6 @@ class LLMParseError(Exception):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ExtractionDone(Event):
|
|
||||||
"""Event emitted when LLM JSON formatting completes."""
|
|
||||||
|
|
||||||
output: str
|
|
||||||
|
|
||||||
|
|
||||||
class ValidationErrorEvent(Event):
|
|
||||||
"""Event emitted when validation fails."""
|
|
||||||
|
|
||||||
error: str
|
|
||||||
wrong_output: str
|
|
||||||
|
|
||||||
|
|
||||||
class StructuredOutputWorkflow(Workflow, Generic[OutputT]):
|
|
||||||
"""Workflow for structured output extraction with validation retry.
|
|
||||||
|
|
||||||
This workflow handles parse/validation retries only. Network error retries
|
|
||||||
are handled internally by Settings.llm (OpenAILike max_retries=3).
|
|
||||||
The caller should NOT wrap this workflow in additional retry logic.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
output_cls: Type[OutputT],
|
|
||||||
max_retries: int = 3,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.output_cls: Type[OutputT] = output_cls
|
|
||||||
self.max_retries = max_retries
|
|
||||||
self.output_parser = PydanticOutputParser(output_cls)
|
|
||||||
|
|
||||||
@step
|
|
||||||
async def extract(
|
|
||||||
self, ctx: Context, ev: StartEvent | ValidationErrorEvent
|
|
||||||
) -> StopEvent | ExtractionDone:
|
|
||||||
"""Extract structured data from text using two-step LLM process.
|
|
||||||
|
|
||||||
Step 1 (first call only): TreeSummarize generates text analysis
|
|
||||||
Step 2 (every call): Settings.llm.acomplete formats analysis as JSON
|
|
||||||
"""
|
|
||||||
current_retries = await ctx.store.get("retries", default=0)
|
|
||||||
await ctx.store.set("retries", current_retries + 1)
|
|
||||||
|
|
||||||
if current_retries >= self.max_retries:
|
|
||||||
last_error = await ctx.store.get("last_error", default=None)
|
|
||||||
logger.error(
|
|
||||||
f"Max retries ({self.max_retries}) reached for {self.output_cls.__name__}"
|
|
||||||
)
|
|
||||||
return StopEvent(result={"error": last_error, "attempts": current_retries})
|
|
||||||
|
|
||||||
if isinstance(ev, StartEvent):
|
|
||||||
# First call: run TreeSummarize to get analysis, store in context
|
|
||||||
prompt = ev.get("prompt")
|
|
||||||
texts = ev.get("texts")
|
|
||||||
tone_name = ev.get("tone_name")
|
|
||||||
if not prompt or not isinstance(texts, list):
|
|
||||||
raise ValueError(
|
|
||||||
"StartEvent must contain 'prompt' (str) and 'texts' (list)"
|
|
||||||
)
|
|
||||||
|
|
||||||
summarizer = TreeSummarize(verbose=False)
|
|
||||||
analysis = await summarizer.aget_response(
|
|
||||||
prompt, texts, tone_name=tone_name
|
|
||||||
)
|
|
||||||
await ctx.store.set("analysis", str(analysis))
|
|
||||||
reflection = ""
|
|
||||||
else:
|
|
||||||
# Retry: reuse analysis from context
|
|
||||||
analysis = await ctx.store.get("analysis")
|
|
||||||
if not analysis:
|
|
||||||
raise RuntimeError("Internal error: analysis not found in context")
|
|
||||||
|
|
||||||
wrong_output = ev.wrong_output
|
|
||||||
if len(wrong_output) > 2000:
|
|
||||||
wrong_output = wrong_output[:2000] + "... [truncated]"
|
|
||||||
reflection = (
|
|
||||||
f"\n\nYour previous response could not be parsed:\n{wrong_output}\n\n"
|
|
||||||
f"Error:\n{ev.error}\n\n"
|
|
||||||
"Please try again. Return ONLY valid JSON matching the schema above, "
|
|
||||||
"with no markdown formatting or extra text."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Step 2: Format analysis as JSON using LLM completion
|
|
||||||
format_instructions = self.output_parser.format(
|
|
||||||
"Please structure the above information in the following JSON format:"
|
|
||||||
)
|
|
||||||
|
|
||||||
json_prompt = STRUCTURED_RESPONSE_PROMPT_TEMPLATE.format(
|
|
||||||
analysis=analysis,
|
|
||||||
format_instructions=format_instructions + reflection,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Network retries handled by OpenAILike (max_retries=3)
|
|
||||||
# response_format enables grammar-based constrained decoding on backends
|
|
||||||
# that support it (DMR/llama.cpp, vLLM, Ollama, OpenAI).
|
|
||||||
response = await Settings.llm.acomplete(
|
|
||||||
json_prompt,
|
|
||||||
response_format={
|
|
||||||
"type": "json_schema",
|
|
||||||
"json_schema": {
|
|
||||||
"name": self.output_cls.__name__,
|
|
||||||
"schema": self.output_cls.model_json_schema(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return ExtractionDone(output=response.text)
|
|
||||||
|
|
||||||
@step
|
|
||||||
async def validate(
|
|
||||||
self, ctx: Context, ev: ExtractionDone
|
|
||||||
) -> StopEvent | ValidationErrorEvent:
|
|
||||||
"""Validate extracted output against Pydantic schema."""
|
|
||||||
raw_output = ev.output
|
|
||||||
retries = await ctx.store.get("retries", default=0)
|
|
||||||
|
|
||||||
try:
|
|
||||||
parsed = self.output_parser.parse(raw_output)
|
|
||||||
if retries > 1:
|
|
||||||
logger.info(
|
|
||||||
f"LLM parse succeeded on attempt {retries}/{self.max_retries} "
|
|
||||||
f"for {self.output_cls.__name__}"
|
|
||||||
)
|
|
||||||
return StopEvent(result={"success": parsed})
|
|
||||||
|
|
||||||
except (ValidationError, ValueError) as e:
|
|
||||||
error_msg = self._format_error(e, raw_output)
|
|
||||||
await ctx.store.set("last_error", error_msg)
|
|
||||||
|
|
||||||
logger.error(
|
|
||||||
f"LLM parse error (attempt {retries}/{self.max_retries}): "
|
|
||||||
f"{type(e).__name__}: {e}\nRaw response: {raw_output[:500]}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return ValidationErrorEvent(
|
|
||||||
error=error_msg,
|
|
||||||
wrong_output=raw_output,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _format_error(self, error: Exception, raw_output: str) -> str:
|
|
||||||
"""Format error for LLM feedback."""
|
|
||||||
if isinstance(error, ValidationError):
|
|
||||||
error_messages = []
|
|
||||||
for err in error.errors():
|
|
||||||
field = ".".join(str(loc) for loc in err["loc"])
|
|
||||||
error_messages.append(f"- {err['msg']} in field '{field}'")
|
|
||||||
return "Schema validation errors:\n" + "\n".join(error_messages)
|
|
||||||
else:
|
|
||||||
return f"Parse error: {str(error)}"
|
|
||||||
|
|
||||||
|
|
||||||
class LLM:
|
class LLM:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, settings, temperature: float = 0.4, max_tokens: int | None = None
|
self, settings, temperature: float = 0.4, max_tokens: int | None = None
|
||||||
@@ -225,7 +55,7 @@ class LLM:
|
|||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
context_window=self.context_window,
|
context_window=self.context_window,
|
||||||
is_chat_model=True,
|
is_chat_model=True,
|
||||||
is_function_calling_model=False,
|
is_function_calling_model=True,
|
||||||
temperature=self.temperature,
|
temperature=self.temperature,
|
||||||
max_tokens=self.max_tokens,
|
max_tokens=self.max_tokens,
|
||||||
timeout=self.settings_obj.LLM_REQUEST_TIMEOUT,
|
timeout=self.settings_obj.LLM_REQUEST_TIMEOUT,
|
||||||
@@ -248,36 +78,91 @@ class LLM:
|
|||||||
tone_name: str | None = None,
|
tone_name: str | None = None,
|
||||||
timeout: int | None = None,
|
timeout: int | None = None,
|
||||||
) -> T:
|
) -> T:
|
||||||
"""Get structured output from LLM with validation retry via Workflow."""
|
"""Get structured output from LLM using tool-call with reflection retry.
|
||||||
if timeout is None:
|
|
||||||
timeout = self.settings_obj.LLM_STRUCTURED_RESPONSE_TIMEOUT
|
|
||||||
|
|
||||||
async def run_workflow():
|
Uses astructured_predict (function-calling / tool-call mode) for the
|
||||||
workflow = StructuredOutputWorkflow(
|
first attempt. On ValidationError or parse failure the wrong output
|
||||||
|
and error are fed back as a reflection prompt and the call is retried
|
||||||
|
up to LLM_PARSE_MAX_RETRIES times.
|
||||||
|
|
||||||
|
The outer retry() wrapper handles transient network errors with
|
||||||
|
exponential back-off.
|
||||||
|
"""
|
||||||
|
max_retries = self.settings_obj.LLM_PARSE_MAX_RETRIES
|
||||||
|
|
||||||
|
async def _call_with_reflection():
|
||||||
|
# Build full prompt: instruction + source texts
|
||||||
|
if texts:
|
||||||
|
texts_block = "\n\n".join(texts)
|
||||||
|
full_prompt = f"{prompt}\n\n{texts_block}"
|
||||||
|
else:
|
||||||
|
full_prompt = prompt
|
||||||
|
|
||||||
|
prompt_tmpl = PromptTemplate("{user_prompt}")
|
||||||
|
last_error: str | None = None
|
||||||
|
|
||||||
|
for attempt in range(1, max_retries + 2): # +2: first try + retries
|
||||||
|
try:
|
||||||
|
if attempt == 1:
|
||||||
|
result = await Settings.llm.astructured_predict(
|
||||||
|
output_cls, prompt_tmpl, user_prompt=full_prompt
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
reflection_tmpl = PromptTemplate(
|
||||||
|
"{user_prompt}\n\n{reflection}"
|
||||||
|
)
|
||||||
|
result = await Settings.llm.astructured_predict(
|
||||||
|
output_cls,
|
||||||
|
reflection_tmpl,
|
||||||
|
user_prompt=full_prompt,
|
||||||
|
reflection=reflection,
|
||||||
|
)
|
||||||
|
|
||||||
|
if attempt > 1:
|
||||||
|
logger.info(
|
||||||
|
f"LLM structured_predict succeeded on attempt "
|
||||||
|
f"{attempt}/{max_retries + 1} for {output_cls.__name__}"
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
except (ValidationError, ValueError) as e:
|
||||||
|
wrong_output = str(e)
|
||||||
|
if len(wrong_output) > 2000:
|
||||||
|
wrong_output = wrong_output[:2000] + "... [truncated]"
|
||||||
|
|
||||||
|
last_error = self._format_validation_error(e)
|
||||||
|
reflection = (
|
||||||
|
f"Your previous response could not be parsed.\n\n"
|
||||||
|
f"Error:\n{last_error}\n\n"
|
||||||
|
"Please try again and return valid data matching the schema."
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.error(
|
||||||
|
f"LLM parse error (attempt {attempt}/{max_retries + 1}): "
|
||||||
|
f"{type(e).__name__}: {e}\n"
|
||||||
|
f"Raw response: {wrong_output[:500]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
raise LLMParseError(
|
||||||
output_cls=output_cls,
|
output_cls=output_cls,
|
||||||
max_retries=self.settings_obj.LLM_PARSE_MAX_RETRIES + 1,
|
error_msg=last_error or "Max retries exceeded",
|
||||||
timeout=timeout,
|
attempts=max_retries + 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await workflow.run(
|
return await retry(_call_with_reflection)(
|
||||||
prompt=prompt,
|
|
||||||
texts=texts,
|
|
||||||
tone_name=tone_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
if "error" in result:
|
|
||||||
error_msg = result["error"] or "Max retries exceeded"
|
|
||||||
raise LLMParseError(
|
|
||||||
output_cls=output_cls,
|
|
||||||
error_msg=error_msg,
|
|
||||||
attempts=result.get("attempts", 0),
|
|
||||||
)
|
|
||||||
|
|
||||||
return result["success"]
|
|
||||||
|
|
||||||
return await retry(run_workflow)(
|
|
||||||
retry_attempts=3,
|
retry_attempts=3,
|
||||||
retry_backoff_interval=1.0,
|
retry_backoff_interval=1.0,
|
||||||
retry_backoff_max=30.0,
|
retry_backoff_max=30.0,
|
||||||
retry_ignore_exc_types=(WorkflowTimeoutError,),
|
retry_ignore_exc_types=(ConnectionError, TimeoutError, OSError),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _format_validation_error(error: Exception) -> str:
|
||||||
|
"""Format a validation/parse error for LLM reflection feedback."""
|
||||||
|
if isinstance(error, ValidationError):
|
||||||
|
error_messages = []
|
||||||
|
for err in error.errors():
|
||||||
|
field = ".".join(str(loc) for loc in err["loc"])
|
||||||
|
error_messages.append(f"- {err['msg']} in field '{field}'")
|
||||||
|
return "Schema validation errors:\n" + "\n".join(error_messages)
|
||||||
|
return f"Parse error: {str(error)}"
|
||||||
|
|||||||
@@ -14,10 +14,12 @@ class TopicResponse(BaseModel):
|
|||||||
title: str = Field(
|
title: str = Field(
|
||||||
description="A descriptive title for the topic being discussed",
|
description="A descriptive title for the topic being discussed",
|
||||||
validation_alias=AliasChoices("title", "Title"),
|
validation_alias=AliasChoices("title", "Title"),
|
||||||
|
min_length=8,
|
||||||
)
|
)
|
||||||
summary: str = Field(
|
summary: str = Field(
|
||||||
description="A concise 1-2 sentence summary of the discussion",
|
description="A concise 1-2 sentence summary of the discussion",
|
||||||
validation_alias=AliasChoices("summary", "Summary"),
|
validation_alias=AliasChoices("summary", "Summary"),
|
||||||
|
min_length=8,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,13 +1,11 @@
|
|||||||
"""Tests for LLM parse error recovery using llama-index Workflow"""
|
"""Tests for LLM structured output with astructured_predict + reflection retry"""
|
||||||
|
|
||||||
from time import monotonic
|
from unittest.mock import AsyncMock, patch
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, ValidationError
|
||||||
from workflows.errors import WorkflowRuntimeError, WorkflowTimeoutError
|
|
||||||
|
|
||||||
from reflector.llm import LLM, LLMParseError, StructuredOutputWorkflow
|
from reflector.llm import LLM, LLMParseError
|
||||||
from reflector.utils.retry import RetryException
|
from reflector.utils.retry import RetryException
|
||||||
|
|
||||||
|
|
||||||
@@ -19,51 +17,43 @@ class TestResponse(BaseModel):
|
|||||||
confidence: float = Field(description="Confidence score", ge=0, le=1)
|
confidence: float = Field(description="Confidence score", ge=0, le=1)
|
||||||
|
|
||||||
|
|
||||||
def make_completion_response(text: str):
|
|
||||||
"""Create a mock CompletionResponse with .text attribute"""
|
|
||||||
response = MagicMock()
|
|
||||||
response.text = text
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
class TestLLMParseErrorRecovery:
|
class TestLLMParseErrorRecovery:
|
||||||
"""Test parse error recovery with Workflow feedback loop"""
|
"""Test parse error recovery with astructured_predict reflection loop"""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_parse_error_recovery_with_feedback(self, test_settings):
|
async def test_parse_error_recovery_with_feedback(self, test_settings):
|
||||||
"""Test that parse errors trigger retry with error feedback"""
|
"""Test that parse errors trigger retry with reflection prompt"""
|
||||||
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
||||||
|
|
||||||
with (
|
call_count = {"count": 0}
|
||||||
patch("reflector.llm.TreeSummarize") as mock_summarize,
|
|
||||||
patch("reflector.llm.Settings") as mock_settings,
|
async def astructured_predict_handler(output_cls, prompt_tmpl, **kwargs):
|
||||||
):
|
call_count["count"] += 1
|
||||||
mock_summarizer = MagicMock()
|
if call_count["count"] == 1:
|
||||||
mock_summarize.return_value = mock_summarizer
|
# First call: raise ValidationError (missing fields)
|
||||||
# TreeSummarize returns plain text analysis (step 1)
|
raise ValidationError.from_exception_data(
|
||||||
mock_summarizer.aget_response = AsyncMock(
|
title="TestResponse",
|
||||||
return_value="The analysis shows a test with summary and high confidence."
|
line_errors=[
|
||||||
|
{
|
||||||
|
"type": "missing",
|
||||||
|
"loc": ("summary",),
|
||||||
|
"msg": "Field required",
|
||||||
|
"input": {"title": "Test"},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Second call: should have reflection in the prompt
|
||||||
|
assert "reflection" in kwargs
|
||||||
|
assert "could not be parsed" in kwargs["reflection"]
|
||||||
|
assert "Error:" in kwargs["reflection"]
|
||||||
|
return TestResponse(title="Test", summary="Summary", confidence=0.95)
|
||||||
|
|
||||||
|
with patch("reflector.llm.Settings") as mock_settings:
|
||||||
|
mock_settings.llm.astructured_predict = AsyncMock(
|
||||||
|
side_effect=astructured_predict_handler
|
||||||
)
|
)
|
||||||
|
|
||||||
call_count = {"count": 0}
|
|
||||||
|
|
||||||
async def acomplete_handler(prompt, *args, **kwargs):
|
|
||||||
call_count["count"] += 1
|
|
||||||
if call_count["count"] == 1:
|
|
||||||
# First JSON formatting call returns invalid JSON
|
|
||||||
return make_completion_response('{"title": "Test"}')
|
|
||||||
else:
|
|
||||||
# Second call should have error feedback in prompt
|
|
||||||
assert "Your previous response could not be parsed:" in prompt
|
|
||||||
assert '{"title": "Test"}' in prompt
|
|
||||||
assert "Error:" in prompt
|
|
||||||
assert "Please try again" in prompt
|
|
||||||
return make_completion_response(
|
|
||||||
'{"title": "Test", "summary": "Summary", "confidence": 0.95}'
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_settings.llm.acomplete = AsyncMock(side_effect=acomplete_handler)
|
|
||||||
|
|
||||||
result = await llm.get_structured_response(
|
result = await llm.get_structured_response(
|
||||||
prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
|
prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
|
||||||
)
|
)
|
||||||
@@ -71,8 +61,6 @@ class TestLLMParseErrorRecovery:
|
|||||||
assert result.title == "Test"
|
assert result.title == "Test"
|
||||||
assert result.summary == "Summary"
|
assert result.summary == "Summary"
|
||||||
assert result.confidence == 0.95
|
assert result.confidence == 0.95
|
||||||
# TreeSummarize called once, Settings.llm.acomplete called twice
|
|
||||||
assert mock_summarizer.aget_response.call_count == 1
|
|
||||||
assert call_count["count"] == 2
|
assert call_count["count"] == 2
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -80,56 +68,61 @@ class TestLLMParseErrorRecovery:
|
|||||||
"""Test that parse error retry stops after max attempts"""
|
"""Test that parse error retry stops after max attempts"""
|
||||||
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
||||||
|
|
||||||
with (
|
# Always raise ValidationError
|
||||||
patch("reflector.llm.TreeSummarize") as mock_summarize,
|
async def always_fail(output_cls, prompt_tmpl, **kwargs):
|
||||||
patch("reflector.llm.Settings") as mock_settings,
|
raise ValidationError.from_exception_data(
|
||||||
):
|
title="TestResponse",
|
||||||
mock_summarizer = MagicMock()
|
line_errors=[
|
||||||
mock_summarize.return_value = mock_summarizer
|
{
|
||||||
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
|
"type": "missing",
|
||||||
|
"loc": ("summary",),
|
||||||
# Always return invalid JSON from acomplete
|
"msg": "Field required",
|
||||||
mock_settings.llm.acomplete = AsyncMock(
|
"input": {},
|
||||||
return_value=make_completion_response(
|
}
|
||||||
'{"invalid": "missing required fields"}'
|
],
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with patch("reflector.llm.Settings") as mock_settings:
|
||||||
|
mock_settings.llm.astructured_predict = AsyncMock(side_effect=always_fail)
|
||||||
|
|
||||||
with pytest.raises(LLMParseError, match="Failed to parse"):
|
with pytest.raises(LLMParseError, match="Failed to parse"):
|
||||||
await llm.get_structured_response(
|
await llm.get_structured_response(
|
||||||
prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
|
prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
|
||||||
)
|
)
|
||||||
|
|
||||||
expected_attempts = test_settings.LLM_PARSE_MAX_RETRIES + 1
|
expected_attempts = test_settings.LLM_PARSE_MAX_RETRIES + 1
|
||||||
# TreeSummarize called once, acomplete called max_retries times
|
assert mock_settings.llm.astructured_predict.call_count == expected_attempts
|
||||||
assert mock_summarizer.aget_response.call_count == 1
|
|
||||||
assert mock_settings.llm.acomplete.call_count == expected_attempts
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_raw_response_logging_on_parse_error(self, test_settings, caplog):
|
async def test_raw_response_logging_on_parse_error(self, test_settings, caplog):
|
||||||
"""Test that raw response is logged when parse error occurs"""
|
"""Test that raw response is logged when parse error occurs"""
|
||||||
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
||||||
|
|
||||||
|
call_count = {"count": 0}
|
||||||
|
|
||||||
|
async def astructured_predict_handler(output_cls, prompt_tmpl, **kwargs):
|
||||||
|
call_count["count"] += 1
|
||||||
|
if call_count["count"] == 1:
|
||||||
|
raise ValidationError.from_exception_data(
|
||||||
|
title="TestResponse",
|
||||||
|
line_errors=[
|
||||||
|
{
|
||||||
|
"type": "missing",
|
||||||
|
"loc": ("summary",),
|
||||||
|
"msg": "Field required",
|
||||||
|
"input": {"title": "Test"},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
return TestResponse(title="Test", summary="Summary", confidence=0.95)
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("reflector.llm.TreeSummarize") as mock_summarize,
|
|
||||||
patch("reflector.llm.Settings") as mock_settings,
|
patch("reflector.llm.Settings") as mock_settings,
|
||||||
caplog.at_level("ERROR"),
|
caplog.at_level("ERROR"),
|
||||||
):
|
):
|
||||||
mock_summarizer = MagicMock()
|
mock_settings.llm.astructured_predict = AsyncMock(
|
||||||
mock_summarize.return_value = mock_summarizer
|
side_effect=astructured_predict_handler
|
||||||
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
|
)
|
||||||
|
|
||||||
call_count = {"count": 0}
|
|
||||||
|
|
||||||
async def acomplete_handler(*args, **kwargs):
|
|
||||||
call_count["count"] += 1
|
|
||||||
if call_count["count"] == 1:
|
|
||||||
return make_completion_response('{"title": "Test"}') # Invalid
|
|
||||||
return make_completion_response(
|
|
||||||
'{"title": "Test", "summary": "Summary", "confidence": 0.95}'
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_settings.llm.acomplete = AsyncMock(side_effect=acomplete_handler)
|
|
||||||
|
|
||||||
result = await llm.get_structured_response(
|
result = await llm.get_structured_response(
|
||||||
prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
|
prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
|
||||||
@@ -143,35 +136,45 @@ class TestLLMParseErrorRecovery:
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_multiple_validation_errors_in_feedback(self, test_settings):
|
async def test_multiple_validation_errors_in_feedback(self, test_settings):
|
||||||
"""Test that validation errors are included in feedback"""
|
"""Test that validation errors are included in reflection feedback"""
|
||||||
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
||||||
|
|
||||||
with (
|
call_count = {"count": 0}
|
||||||
patch("reflector.llm.TreeSummarize") as mock_summarize,
|
|
||||||
patch("reflector.llm.Settings") as mock_settings,
|
|
||||||
):
|
|
||||||
mock_summarizer = MagicMock()
|
|
||||||
mock_summarize.return_value = mock_summarizer
|
|
||||||
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
|
|
||||||
|
|
||||||
call_count = {"count": 0}
|
async def astructured_predict_handler(output_cls, prompt_tmpl, **kwargs):
|
||||||
|
call_count["count"] += 1
|
||||||
|
if call_count["count"] == 1:
|
||||||
|
# Missing title and summary
|
||||||
|
raise ValidationError.from_exception_data(
|
||||||
|
title="TestResponse",
|
||||||
|
line_errors=[
|
||||||
|
{
|
||||||
|
"type": "missing",
|
||||||
|
"loc": ("title",),
|
||||||
|
"msg": "Field required",
|
||||||
|
"input": {},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "missing",
|
||||||
|
"loc": ("summary",),
|
||||||
|
"msg": "Field required",
|
||||||
|
"input": {},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Should have schema validation errors in reflection
|
||||||
|
assert "reflection" in kwargs
|
||||||
|
assert (
|
||||||
|
"Schema validation errors" in kwargs["reflection"]
|
||||||
|
or "error" in kwargs["reflection"].lower()
|
||||||
|
)
|
||||||
|
return TestResponse(title="Test", summary="Summary", confidence=0.95)
|
||||||
|
|
||||||
async def acomplete_handler(prompt, *args, **kwargs):
|
with patch("reflector.llm.Settings") as mock_settings:
|
||||||
call_count["count"] += 1
|
mock_settings.llm.astructured_predict = AsyncMock(
|
||||||
if call_count["count"] == 1:
|
side_effect=astructured_predict_handler
|
||||||
# Missing title and summary
|
)
|
||||||
return make_completion_response('{"confidence": 0.5}')
|
|
||||||
else:
|
|
||||||
# Should have schema validation errors in prompt
|
|
||||||
assert (
|
|
||||||
"Schema validation errors" in prompt
|
|
||||||
or "error" in prompt.lower()
|
|
||||||
)
|
|
||||||
return make_completion_response(
|
|
||||||
'{"title": "Test", "summary": "Summary", "confidence": 0.95}'
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_settings.llm.acomplete = AsyncMock(side_effect=acomplete_handler)
|
|
||||||
|
|
||||||
result = await llm.get_structured_response(
|
result = await llm.get_structured_response(
|
||||||
prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
|
prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
|
||||||
@@ -185,17 +188,10 @@ class TestLLMParseErrorRecovery:
|
|||||||
"""Test that no retry happens when first attempt succeeds"""
|
"""Test that no retry happens when first attempt succeeds"""
|
||||||
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
||||||
|
|
||||||
with (
|
with patch("reflector.llm.Settings") as mock_settings:
|
||||||
patch("reflector.llm.TreeSummarize") as mock_summarize,
|
mock_settings.llm.astructured_predict = AsyncMock(
|
||||||
patch("reflector.llm.Settings") as mock_settings,
|
return_value=TestResponse(
|
||||||
):
|
title="Test", summary="Summary", confidence=0.95
|
||||||
mock_summarizer = MagicMock()
|
|
||||||
mock_summarize.return_value = mock_summarizer
|
|
||||||
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
|
|
||||||
|
|
||||||
mock_settings.llm.acomplete = AsyncMock(
|
|
||||||
return_value=make_completion_response(
|
|
||||||
'{"title": "Test", "summary": "Summary", "confidence": 0.95}'
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -206,195 +202,28 @@ class TestLLMParseErrorRecovery:
|
|||||||
assert result.title == "Test"
|
assert result.title == "Test"
|
||||||
assert result.summary == "Summary"
|
assert result.summary == "Summary"
|
||||||
assert result.confidence == 0.95
|
assert result.confidence == 0.95
|
||||||
assert mock_summarizer.aget_response.call_count == 1
|
assert mock_settings.llm.astructured_predict.call_count == 1
|
||||||
assert mock_settings.llm.acomplete.call_count == 1
|
|
||||||
|
|
||||||
|
|
||||||
class TestStructuredOutputWorkflow:
|
|
||||||
"""Direct tests for the StructuredOutputWorkflow"""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_workflow_retries_on_validation_error(self):
|
|
||||||
"""Test workflow retries when validation fails"""
|
|
||||||
workflow = StructuredOutputWorkflow(
|
|
||||||
output_cls=TestResponse,
|
|
||||||
max_retries=3,
|
|
||||||
timeout=30,
|
|
||||||
)
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch("reflector.llm.TreeSummarize") as mock_summarize,
|
|
||||||
patch("reflector.llm.Settings") as mock_settings,
|
|
||||||
):
|
|
||||||
mock_summarizer = MagicMock()
|
|
||||||
mock_summarize.return_value = mock_summarizer
|
|
||||||
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
|
|
||||||
|
|
||||||
call_count = {"count": 0}
|
|
||||||
|
|
||||||
async def acomplete_handler(*args, **kwargs):
|
|
||||||
call_count["count"] += 1
|
|
||||||
if call_count["count"] < 2:
|
|
||||||
return make_completion_response('{"title": "Only title"}')
|
|
||||||
return make_completion_response(
|
|
||||||
'{"title": "Test", "summary": "Summary", "confidence": 0.9}'
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_settings.llm.acomplete = AsyncMock(side_effect=acomplete_handler)
|
|
||||||
|
|
||||||
result = await workflow.run(
|
|
||||||
prompt="Extract data",
|
|
||||||
texts=["Some text"],
|
|
||||||
tone_name=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert "success" in result
|
|
||||||
assert result["success"].title == "Test"
|
|
||||||
assert call_count["count"] == 2
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_workflow_returns_error_after_max_retries(self):
|
|
||||||
"""Test workflow returns error after exhausting retries"""
|
|
||||||
workflow = StructuredOutputWorkflow(
|
|
||||||
output_cls=TestResponse,
|
|
||||||
max_retries=2,
|
|
||||||
timeout=30,
|
|
||||||
)
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch("reflector.llm.TreeSummarize") as mock_summarize,
|
|
||||||
patch("reflector.llm.Settings") as mock_settings,
|
|
||||||
):
|
|
||||||
mock_summarizer = MagicMock()
|
|
||||||
mock_summarize.return_value = mock_summarizer
|
|
||||||
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
|
|
||||||
|
|
||||||
# Always return invalid JSON
|
|
||||||
mock_settings.llm.acomplete = AsyncMock(
|
|
||||||
return_value=make_completion_response('{"invalid": true}')
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await workflow.run(
|
|
||||||
prompt="Extract data",
|
|
||||||
texts=["Some text"],
|
|
||||||
tone_name=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert "error" in result
|
|
||||||
# TreeSummarize called once, acomplete called max_retries times
|
|
||||||
assert mock_summarizer.aget_response.call_count == 1
|
|
||||||
assert mock_settings.llm.acomplete.call_count == 2
|
|
||||||
|
|
||||||
|
|
||||||
class TestNetworkErrorRetries:
|
class TestNetworkErrorRetries:
|
||||||
"""Test that network error retries are handled by OpenAILike, not Workflow"""
|
"""Test that network errors are retried by the outer retry() wrapper"""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_network_error_propagates_after_openai_retries(self, test_settings):
|
async def test_network_error_retried_by_outer_wrapper(self, test_settings):
|
||||||
"""Test that network errors are retried by OpenAILike and then propagate.
|
"""Test that network errors trigger the outer retry wrapper"""
|
||||||
|
|
||||||
Network retries are handled by OpenAILike (max_retries=3), not by our
|
|
||||||
StructuredOutputWorkflow. This test verifies that network errors propagate
|
|
||||||
up after OpenAILike exhausts its retries.
|
|
||||||
"""
|
|
||||||
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch("reflector.llm.TreeSummarize") as mock_summarize,
|
|
||||||
patch("reflector.llm.Settings") as mock_settings,
|
|
||||||
):
|
|
||||||
mock_summarizer = MagicMock()
|
|
||||||
mock_summarize.return_value = mock_summarizer
|
|
||||||
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
|
|
||||||
|
|
||||||
# Simulate network error from acomplete (after OpenAILike retries exhausted)
|
|
||||||
network_error = ConnectionError("Connection refused")
|
|
||||||
mock_settings.llm.acomplete = AsyncMock(side_effect=network_error)
|
|
||||||
|
|
||||||
# Network error wrapped in WorkflowRuntimeError
|
|
||||||
with pytest.raises(WorkflowRuntimeError, match="Connection refused"):
|
|
||||||
await llm.get_structured_response(
|
|
||||||
prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
|
|
||||||
)
|
|
||||||
|
|
||||||
# acomplete called only once - network error propagates, not retried by Workflow
|
|
||||||
assert mock_settings.llm.acomplete.call_count == 1
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_network_error_not_retried_by_workflow(self, test_settings):
|
|
||||||
"""Test that Workflow does NOT retry network errors (OpenAILike handles those).
|
|
||||||
|
|
||||||
This verifies the separation of concerns:
|
|
||||||
- StructuredOutputWorkflow: retries parse/validation errors
|
|
||||||
- OpenAILike: retries network errors (internally, max_retries=3)
|
|
||||||
"""
|
|
||||||
workflow = StructuredOutputWorkflow(
|
|
||||||
output_cls=TestResponse,
|
|
||||||
max_retries=3,
|
|
||||||
timeout=30,
|
|
||||||
)
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch("reflector.llm.TreeSummarize") as mock_summarize,
|
|
||||||
patch("reflector.llm.Settings") as mock_settings,
|
|
||||||
):
|
|
||||||
mock_summarizer = MagicMock()
|
|
||||||
mock_summarize.return_value = mock_summarizer
|
|
||||||
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
|
|
||||||
|
|
||||||
# Network error should propagate immediately, not trigger Workflow retry
|
|
||||||
mock_settings.llm.acomplete = AsyncMock(
|
|
||||||
side_effect=TimeoutError("Request timed out")
|
|
||||||
)
|
|
||||||
|
|
||||||
# Network error wrapped in WorkflowRuntimeError
|
|
||||||
with pytest.raises(WorkflowRuntimeError, match="Request timed out"):
|
|
||||||
await workflow.run(
|
|
||||||
prompt="Extract data",
|
|
||||||
texts=["Some text"],
|
|
||||||
tone_name=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Only called once - Workflow doesn't retry network errors
|
|
||||||
assert mock_settings.llm.acomplete.call_count == 1
|
|
||||||
|
|
||||||
|
|
||||||
class TestWorkflowTimeoutRetry:
|
|
||||||
"""Test timeout retry mechanism in get_structured_response"""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_timeout_retry_succeeds_on_retry(self, test_settings):
|
|
||||||
"""Test that WorkflowTimeoutError triggers retry and succeeds"""
|
|
||||||
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
||||||
|
|
||||||
call_count = {"count": 0}
|
call_count = {"count": 0}
|
||||||
|
|
||||||
async def workflow_run_side_effect(*args, **kwargs):
|
async def astructured_predict_handler(output_cls, prompt_tmpl, **kwargs):
|
||||||
call_count["count"] += 1
|
call_count["count"] += 1
|
||||||
if call_count["count"] == 1:
|
if call_count["count"] == 1:
|
||||||
raise WorkflowTimeoutError("Operation timed out after 120 seconds")
|
raise ConnectionError("Connection refused")
|
||||||
return {
|
return TestResponse(title="Test", summary="Summary", confidence=0.95)
|
||||||
"success": TestResponse(
|
|
||||||
title="Test", summary="Summary", confidence=0.95
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
with (
|
with patch("reflector.llm.Settings") as mock_settings:
|
||||||
patch("reflector.llm.StructuredOutputWorkflow") as mock_workflow_class,
|
mock_settings.llm.astructured_predict = AsyncMock(
|
||||||
patch("reflector.llm.TreeSummarize") as mock_summarize,
|
side_effect=astructured_predict_handler
|
||||||
patch("reflector.llm.Settings") as mock_settings,
|
|
||||||
):
|
|
||||||
mock_workflow = MagicMock()
|
|
||||||
mock_workflow.run = AsyncMock(side_effect=workflow_run_side_effect)
|
|
||||||
mock_workflow_class.return_value = mock_workflow
|
|
||||||
|
|
||||||
mock_summarizer = MagicMock()
|
|
||||||
mock_summarize.return_value = mock_summarizer
|
|
||||||
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
|
|
||||||
mock_settings.llm.acomplete = AsyncMock(
|
|
||||||
return_value=make_completion_response(
|
|
||||||
'{"title": "Test", "summary": "Summary", "confidence": 0.95}'
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await llm.get_structured_response(
|
result = await llm.get_structured_response(
|
||||||
@@ -402,36 +231,16 @@ class TestWorkflowTimeoutRetry:
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert result.title == "Test"
|
assert result.title == "Test"
|
||||||
assert result.summary == "Summary"
|
|
||||||
assert call_count["count"] == 2
|
assert call_count["count"] == 2
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_timeout_retry_exhausts_after_max_attempts(self, test_settings):
|
async def test_network_error_exhausts_retries(self, test_settings):
|
||||||
"""Test that timeout retry stops after max attempts"""
|
"""Test that persistent network errors exhaust retry attempts"""
|
||||||
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
||||||
|
|
||||||
call_count = {"count": 0}
|
with patch("reflector.llm.Settings") as mock_settings:
|
||||||
|
mock_settings.llm.astructured_predict = AsyncMock(
|
||||||
async def workflow_run_side_effect(*args, **kwargs):
|
side_effect=ConnectionError("Connection refused")
|
||||||
call_count["count"] += 1
|
|
||||||
raise WorkflowTimeoutError("Operation timed out after 120 seconds")
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch("reflector.llm.StructuredOutputWorkflow") as mock_workflow_class,
|
|
||||||
patch("reflector.llm.TreeSummarize") as mock_summarize,
|
|
||||||
patch("reflector.llm.Settings") as mock_settings,
|
|
||||||
):
|
|
||||||
mock_workflow = MagicMock()
|
|
||||||
mock_workflow.run = AsyncMock(side_effect=workflow_run_side_effect)
|
|
||||||
mock_workflow_class.return_value = mock_workflow
|
|
||||||
|
|
||||||
mock_summarizer = MagicMock()
|
|
||||||
mock_summarize.return_value = mock_summarizer
|
|
||||||
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
|
|
||||||
mock_settings.llm.acomplete = AsyncMock(
|
|
||||||
return_value=make_completion_response(
|
|
||||||
'{"title": "Test", "summary": "Summary", "confidence": 0.95}'
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(RetryException, match="Retry attempts exceeded"):
|
with pytest.raises(RetryException, match="Retry attempts exceeded"):
|
||||||
@@ -439,41 +248,129 @@ class TestWorkflowTimeoutRetry:
|
|||||||
prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
|
prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
|
||||||
)
|
)
|
||||||
|
|
||||||
assert call_count["count"] == 3
|
# 3 retry attempts
|
||||||
|
assert mock_settings.llm.astructured_predict.call_count == 3
|
||||||
|
|
||||||
|
|
||||||
|
class TestTextsInclusion:
|
||||||
|
"""Test that texts parameter is included in the prompt sent to astructured_predict"""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_timeout_retry_with_backoff(self, test_settings):
|
async def test_texts_included_in_prompt(self, test_settings):
|
||||||
"""Test that exponential backoff is applied between retries"""
|
"""Test that texts content is appended to the prompt for astructured_predict"""
|
||||||
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
||||||
|
|
||||||
call_times = []
|
captured_prompts = []
|
||||||
|
|
||||||
async def workflow_run_side_effect(*args, **kwargs):
|
async def capture_prompt(output_cls, prompt_tmpl, **kwargs):
|
||||||
call_times.append(monotonic())
|
captured_prompts.append(kwargs.get("user_prompt", ""))
|
||||||
if len(call_times) < 3:
|
return TestResponse(title="Test", summary="Summary", confidence=0.95)
|
||||||
raise WorkflowTimeoutError("Operation timed out after 120 seconds")
|
|
||||||
return {
|
with patch("reflector.llm.Settings") as mock_settings:
|
||||||
"success": TestResponse(
|
mock_settings.llm.astructured_predict = AsyncMock(
|
||||||
title="Test", summary="Summary", confidence=0.95
|
side_effect=capture_prompt
|
||||||
|
)
|
||||||
|
|
||||||
|
await llm.get_structured_response(
|
||||||
|
prompt="Identify all participants",
|
||||||
|
texts=["Alice: Hello everyone", "Bob: Hi Alice"],
|
||||||
|
output_cls=TestResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(captured_prompts) == 1
|
||||||
|
prompt_sent = captured_prompts[0]
|
||||||
|
assert "Identify all participants" in prompt_sent
|
||||||
|
assert "Alice: Hello everyone" in prompt_sent
|
||||||
|
assert "Bob: Hi Alice" in prompt_sent
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_texts_uses_prompt_only(self, test_settings):
|
||||||
|
"""Test that empty texts list sends only the prompt"""
|
||||||
|
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
||||||
|
|
||||||
|
captured_prompts = []
|
||||||
|
|
||||||
|
async def capture_prompt(output_cls, prompt_tmpl, **kwargs):
|
||||||
|
captured_prompts.append(kwargs.get("user_prompt", ""))
|
||||||
|
return TestResponse(title="Test", summary="Summary", confidence=0.95)
|
||||||
|
|
||||||
|
with patch("reflector.llm.Settings") as mock_settings:
|
||||||
|
mock_settings.llm.astructured_predict = AsyncMock(
|
||||||
|
side_effect=capture_prompt
|
||||||
|
)
|
||||||
|
|
||||||
|
await llm.get_structured_response(
|
||||||
|
prompt="Identify all participants",
|
||||||
|
texts=[],
|
||||||
|
output_cls=TestResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(captured_prompts) == 1
|
||||||
|
assert captured_prompts[0] == "Identify all participants"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_texts_included_in_reflection_retry(self, test_settings):
|
||||||
|
"""Test that texts are included in the prompt even during reflection retries"""
|
||||||
|
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
||||||
|
|
||||||
|
captured_prompts = []
|
||||||
|
call_count = {"count": 0}
|
||||||
|
|
||||||
|
async def capture_and_fail_first(output_cls, prompt_tmpl, **kwargs):
|
||||||
|
call_count["count"] += 1
|
||||||
|
captured_prompts.append(kwargs.get("user_prompt", ""))
|
||||||
|
if call_count["count"] == 1:
|
||||||
|
raise ValidationError.from_exception_data(
|
||||||
|
title="TestResponse",
|
||||||
|
line_errors=[
|
||||||
|
{
|
||||||
|
"type": "missing",
|
||||||
|
"loc": ("summary",),
|
||||||
|
"msg": "Field required",
|
||||||
|
"input": {},
|
||||||
|
}
|
||||||
|
],
|
||||||
)
|
)
|
||||||
}
|
return TestResponse(title="Test", summary="Summary", confidence=0.95)
|
||||||
|
|
||||||
with (
|
with patch("reflector.llm.Settings") as mock_settings:
|
||||||
patch("reflector.llm.StructuredOutputWorkflow") as mock_workflow_class,
|
mock_settings.llm.astructured_predict = AsyncMock(
|
||||||
patch("reflector.llm.TreeSummarize") as mock_summarize,
|
side_effect=capture_and_fail_first
|
||||||
patch("reflector.llm.Settings") as mock_settings,
|
)
|
||||||
):
|
|
||||||
mock_workflow = MagicMock()
|
|
||||||
mock_workflow.run = AsyncMock(side_effect=workflow_run_side_effect)
|
|
||||||
mock_workflow_class.return_value = mock_workflow
|
|
||||||
|
|
||||||
mock_summarizer = MagicMock()
|
await llm.get_structured_response(
|
||||||
mock_summarize.return_value = mock_summarizer
|
prompt="Summarize this",
|
||||||
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
|
texts=["The meeting covered project updates"],
|
||||||
mock_settings.llm.acomplete = AsyncMock(
|
output_cls=TestResponse,
|
||||||
return_value=make_completion_response(
|
)
|
||||||
'{"title": "Test", "summary": "Summary", "confidence": 0.95}'
|
|
||||||
)
|
# Both first attempt and reflection retry should include the texts
|
||||||
|
assert len(captured_prompts) == 2
|
||||||
|
for prompt_sent in captured_prompts:
|
||||||
|
assert "Summarize this" in prompt_sent
|
||||||
|
assert "The meeting covered project updates" in prompt_sent
|
||||||
|
|
||||||
|
|
||||||
|
class TestReflectionRetryBackoff:
|
||||||
|
"""Test the reflection retry timing behavior"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_value_error_triggers_reflection(self, test_settings):
|
||||||
|
"""Test that ValueError (parse failure) also triggers reflection retry"""
|
||||||
|
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
||||||
|
|
||||||
|
call_count = {"count": 0}
|
||||||
|
|
||||||
|
async def astructured_predict_handler(output_cls, prompt_tmpl, **kwargs):
|
||||||
|
call_count["count"] += 1
|
||||||
|
if call_count["count"] == 1:
|
||||||
|
raise ValueError("Could not parse output")
|
||||||
|
assert "reflection" in kwargs
|
||||||
|
return TestResponse(title="Test", summary="Summary", confidence=0.95)
|
||||||
|
|
||||||
|
with patch("reflector.llm.Settings") as mock_settings:
|
||||||
|
mock_settings.llm.astructured_predict = AsyncMock(
|
||||||
|
side_effect=astructured_predict_handler
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await llm.get_structured_response(
|
result = await llm.get_structured_response(
|
||||||
@@ -481,8 +378,20 @@ class TestWorkflowTimeoutRetry:
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert result.title == "Test"
|
assert result.title == "Test"
|
||||||
if len(call_times) >= 2:
|
assert call_count["count"] == 2
|
||||||
time_between_calls = call_times[1] - call_times[0]
|
|
||||||
assert (
|
@pytest.mark.asyncio
|
||||||
time_between_calls >= 1.5
|
async def test_format_validation_error_method(self, test_settings):
|
||||||
), f"Expected ~2s backoff, got {time_between_calls}s"
|
"""Test _format_validation_error produces correct feedback"""
|
||||||
|
# ValidationError
|
||||||
|
try:
|
||||||
|
TestResponse(title="x", summary="y", confidence=5.0) # confidence > 1
|
||||||
|
except ValidationError as e:
|
||||||
|
result = LLM._format_validation_error(e)
|
||||||
|
assert "Schema validation errors" in result
|
||||||
|
assert "confidence" in result
|
||||||
|
|
||||||
|
# ValueError
|
||||||
|
result = LLM._format_validation_error(ValueError("bad input"))
|
||||||
|
assert "Parse error:" in result
|
||||||
|
assert "bad input" in result
|
||||||
|
|||||||
Reference in New Issue
Block a user