mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-04-19 03:36:55 +00:00
fix: switch structured output to tool-call with reflection retry (#879)
* fix: switch structured output to tool-call with reflection retry

  Replace the two-pass StructuredOutputWorkflow (TreeSummarize → acomplete) with astructured_predict + reflection retry loop for structured LLM output.

  - Enable function-calling mode (is_function_calling_model=True)
  - Use astructured_predict with PromptTemplate for first attempt
  - On ValidationError/parse failure, retry with reflection feedback
  - Add min_length=10 to TopicResponse title/summary fields
  - Remove dead StructuredOutputWorkflow class and its event types
  - Rewrite tests to match new astructured_predict approach

* fix: include texts parameter in astructured_predict prompt

  The switch to astructured_predict dropped the texts parameter entirely, causing summary prompts (participants, subjects, action items) to be sent without the transcript content. Combine texts with the prompt before calling astructured_predict, mirroring what TreeSummarize did.

* fix: reduce TopicResponse min_length from 10 to 8 for title and summary

* ci: try fixing spawning job in github

* ci: fix for new arm64 builder
This commit is contained in:
@@ -1,13 +1,11 @@
|
||||
"""Tests for LLM parse error recovery using llama-index Workflow"""
|
||||
"""Tests for LLM structured output with astructured_predict + reflection retry"""
|
||||
|
||||
from time import monotonic
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
from workflows.errors import WorkflowRuntimeError, WorkflowTimeoutError
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from reflector.llm import LLM, LLMParseError, StructuredOutputWorkflow
|
||||
from reflector.llm import LLM, LLMParseError
|
||||
from reflector.utils.retry import RetryException
|
||||
|
||||
|
||||
@@ -19,51 +17,43 @@ class TestResponse(BaseModel):
|
||||
confidence: float = Field(description="Confidence score", ge=0, le=1)
|
||||
|
||||
|
||||
def make_completion_response(text: str):
|
||||
"""Create a mock CompletionResponse with .text attribute"""
|
||||
response = MagicMock()
|
||||
response.text = text
|
||||
return response
|
||||
|
||||
|
||||
class TestLLMParseErrorRecovery:
|
||||
"""Test parse error recovery with Workflow feedback loop"""
|
||||
"""Test parse error recovery with astructured_predict reflection loop"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_error_recovery_with_feedback(self, test_settings):
|
||||
"""Test that parse errors trigger retry with error feedback"""
|
||||
"""Test that parse errors trigger retry with reflection prompt"""
|
||||
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
||||
|
||||
with (
|
||||
patch("reflector.llm.TreeSummarize") as mock_summarize,
|
||||
patch("reflector.llm.Settings") as mock_settings,
|
||||
):
|
||||
mock_summarizer = MagicMock()
|
||||
mock_summarize.return_value = mock_summarizer
|
||||
# TreeSummarize returns plain text analysis (step 1)
|
||||
mock_summarizer.aget_response = AsyncMock(
|
||||
return_value="The analysis shows a test with summary and high confidence."
|
||||
call_count = {"count": 0}
|
||||
|
||||
async def astructured_predict_handler(output_cls, prompt_tmpl, **kwargs):
|
||||
call_count["count"] += 1
|
||||
if call_count["count"] == 1:
|
||||
# First call: raise ValidationError (missing fields)
|
||||
raise ValidationError.from_exception_data(
|
||||
title="TestResponse",
|
||||
line_errors=[
|
||||
{
|
||||
"type": "missing",
|
||||
"loc": ("summary",),
|
||||
"msg": "Field required",
|
||||
"input": {"title": "Test"},
|
||||
}
|
||||
],
|
||||
)
|
||||
else:
|
||||
# Second call: should have reflection in the prompt
|
||||
assert "reflection" in kwargs
|
||||
assert "could not be parsed" in kwargs["reflection"]
|
||||
assert "Error:" in kwargs["reflection"]
|
||||
return TestResponse(title="Test", summary="Summary", confidence=0.95)
|
||||
|
||||
with patch("reflector.llm.Settings") as mock_settings:
|
||||
mock_settings.llm.astructured_predict = AsyncMock(
|
||||
side_effect=astructured_predict_handler
|
||||
)
|
||||
|
||||
call_count = {"count": 0}
|
||||
|
||||
async def acomplete_handler(prompt, *args, **kwargs):
|
||||
call_count["count"] += 1
|
||||
if call_count["count"] == 1:
|
||||
# First JSON formatting call returns invalid JSON
|
||||
return make_completion_response('{"title": "Test"}')
|
||||
else:
|
||||
# Second call should have error feedback in prompt
|
||||
assert "Your previous response could not be parsed:" in prompt
|
||||
assert '{"title": "Test"}' in prompt
|
||||
assert "Error:" in prompt
|
||||
assert "Please try again" in prompt
|
||||
return make_completion_response(
|
||||
'{"title": "Test", "summary": "Summary", "confidence": 0.95}'
|
||||
)
|
||||
|
||||
mock_settings.llm.acomplete = AsyncMock(side_effect=acomplete_handler)
|
||||
|
||||
result = await llm.get_structured_response(
|
||||
prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
|
||||
)
|
||||
@@ -71,8 +61,6 @@ class TestLLMParseErrorRecovery:
|
||||
assert result.title == "Test"
|
||||
assert result.summary == "Summary"
|
||||
assert result.confidence == 0.95
|
||||
# TreeSummarize called once, Settings.llm.acomplete called twice
|
||||
assert mock_summarizer.aget_response.call_count == 1
|
||||
assert call_count["count"] == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -80,56 +68,61 @@ class TestLLMParseErrorRecovery:
|
||||
"""Test that parse error retry stops after max attempts"""
|
||||
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
||||
|
||||
with (
|
||||
patch("reflector.llm.TreeSummarize") as mock_summarize,
|
||||
patch("reflector.llm.Settings") as mock_settings,
|
||||
):
|
||||
mock_summarizer = MagicMock()
|
||||
mock_summarize.return_value = mock_summarizer
|
||||
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
|
||||
|
||||
# Always return invalid JSON from acomplete
|
||||
mock_settings.llm.acomplete = AsyncMock(
|
||||
return_value=make_completion_response(
|
||||
'{"invalid": "missing required fields"}'
|
||||
)
|
||||
# Always raise ValidationError
|
||||
async def always_fail(output_cls, prompt_tmpl, **kwargs):
|
||||
raise ValidationError.from_exception_data(
|
||||
title="TestResponse",
|
||||
line_errors=[
|
||||
{
|
||||
"type": "missing",
|
||||
"loc": ("summary",),
|
||||
"msg": "Field required",
|
||||
"input": {},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
with patch("reflector.llm.Settings") as mock_settings:
|
||||
mock_settings.llm.astructured_predict = AsyncMock(side_effect=always_fail)
|
||||
|
||||
with pytest.raises(LLMParseError, match="Failed to parse"):
|
||||
await llm.get_structured_response(
|
||||
prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
|
||||
)
|
||||
|
||||
expected_attempts = test_settings.LLM_PARSE_MAX_RETRIES + 1
|
||||
# TreeSummarize called once, acomplete called max_retries times
|
||||
assert mock_summarizer.aget_response.call_count == 1
|
||||
assert mock_settings.llm.acomplete.call_count == expected_attempts
|
||||
assert mock_settings.llm.astructured_predict.call_count == expected_attempts
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raw_response_logging_on_parse_error(self, test_settings, caplog):
|
||||
"""Test that raw response is logged when parse error occurs"""
|
||||
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
||||
|
||||
call_count = {"count": 0}
|
||||
|
||||
async def astructured_predict_handler(output_cls, prompt_tmpl, **kwargs):
|
||||
call_count["count"] += 1
|
||||
if call_count["count"] == 1:
|
||||
raise ValidationError.from_exception_data(
|
||||
title="TestResponse",
|
||||
line_errors=[
|
||||
{
|
||||
"type": "missing",
|
||||
"loc": ("summary",),
|
||||
"msg": "Field required",
|
||||
"input": {"title": "Test"},
|
||||
}
|
||||
],
|
||||
)
|
||||
return TestResponse(title="Test", summary="Summary", confidence=0.95)
|
||||
|
||||
with (
|
||||
patch("reflector.llm.TreeSummarize") as mock_summarize,
|
||||
patch("reflector.llm.Settings") as mock_settings,
|
||||
caplog.at_level("ERROR"),
|
||||
):
|
||||
mock_summarizer = MagicMock()
|
||||
mock_summarize.return_value = mock_summarizer
|
||||
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
|
||||
|
||||
call_count = {"count": 0}
|
||||
|
||||
async def acomplete_handler(*args, **kwargs):
|
||||
call_count["count"] += 1
|
||||
if call_count["count"] == 1:
|
||||
return make_completion_response('{"title": "Test"}') # Invalid
|
||||
return make_completion_response(
|
||||
'{"title": "Test", "summary": "Summary", "confidence": 0.95}'
|
||||
)
|
||||
|
||||
mock_settings.llm.acomplete = AsyncMock(side_effect=acomplete_handler)
|
||||
mock_settings.llm.astructured_predict = AsyncMock(
|
||||
side_effect=astructured_predict_handler
|
||||
)
|
||||
|
||||
result = await llm.get_structured_response(
|
||||
prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
|
||||
@@ -143,35 +136,45 @@ class TestLLMParseErrorRecovery:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_validation_errors_in_feedback(self, test_settings):
|
||||
"""Test that validation errors are included in feedback"""
|
||||
"""Test that validation errors are included in reflection feedback"""
|
||||
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
||||
|
||||
with (
|
||||
patch("reflector.llm.TreeSummarize") as mock_summarize,
|
||||
patch("reflector.llm.Settings") as mock_settings,
|
||||
):
|
||||
mock_summarizer = MagicMock()
|
||||
mock_summarize.return_value = mock_summarizer
|
||||
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
|
||||
call_count = {"count": 0}
|
||||
|
||||
call_count = {"count": 0}
|
||||
async def astructured_predict_handler(output_cls, prompt_tmpl, **kwargs):
|
||||
call_count["count"] += 1
|
||||
if call_count["count"] == 1:
|
||||
# Missing title and summary
|
||||
raise ValidationError.from_exception_data(
|
||||
title="TestResponse",
|
||||
line_errors=[
|
||||
{
|
||||
"type": "missing",
|
||||
"loc": ("title",),
|
||||
"msg": "Field required",
|
||||
"input": {},
|
||||
},
|
||||
{
|
||||
"type": "missing",
|
||||
"loc": ("summary",),
|
||||
"msg": "Field required",
|
||||
"input": {},
|
||||
},
|
||||
],
|
||||
)
|
||||
else:
|
||||
# Should have schema validation errors in reflection
|
||||
assert "reflection" in kwargs
|
||||
assert (
|
||||
"Schema validation errors" in kwargs["reflection"]
|
||||
or "error" in kwargs["reflection"].lower()
|
||||
)
|
||||
return TestResponse(title="Test", summary="Summary", confidence=0.95)
|
||||
|
||||
async def acomplete_handler(prompt, *args, **kwargs):
|
||||
call_count["count"] += 1
|
||||
if call_count["count"] == 1:
|
||||
# Missing title and summary
|
||||
return make_completion_response('{"confidence": 0.5}')
|
||||
else:
|
||||
# Should have schema validation errors in prompt
|
||||
assert (
|
||||
"Schema validation errors" in prompt
|
||||
or "error" in prompt.lower()
|
||||
)
|
||||
return make_completion_response(
|
||||
'{"title": "Test", "summary": "Summary", "confidence": 0.95}'
|
||||
)
|
||||
|
||||
mock_settings.llm.acomplete = AsyncMock(side_effect=acomplete_handler)
|
||||
with patch("reflector.llm.Settings") as mock_settings:
|
||||
mock_settings.llm.astructured_predict = AsyncMock(
|
||||
side_effect=astructured_predict_handler
|
||||
)
|
||||
|
||||
result = await llm.get_structured_response(
|
||||
prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
|
||||
@@ -185,17 +188,10 @@ class TestLLMParseErrorRecovery:
|
||||
"""Test that no retry happens when first attempt succeeds"""
|
||||
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
||||
|
||||
with (
|
||||
patch("reflector.llm.TreeSummarize") as mock_summarize,
|
||||
patch("reflector.llm.Settings") as mock_settings,
|
||||
):
|
||||
mock_summarizer = MagicMock()
|
||||
mock_summarize.return_value = mock_summarizer
|
||||
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
|
||||
|
||||
mock_settings.llm.acomplete = AsyncMock(
|
||||
return_value=make_completion_response(
|
||||
'{"title": "Test", "summary": "Summary", "confidence": 0.95}'
|
||||
with patch("reflector.llm.Settings") as mock_settings:
|
||||
mock_settings.llm.astructured_predict = AsyncMock(
|
||||
return_value=TestResponse(
|
||||
title="Test", summary="Summary", confidence=0.95
|
||||
)
|
||||
)
|
||||
|
||||
@@ -206,195 +202,28 @@ class TestLLMParseErrorRecovery:
|
||||
assert result.title == "Test"
|
||||
assert result.summary == "Summary"
|
||||
assert result.confidence == 0.95
|
||||
assert mock_summarizer.aget_response.call_count == 1
|
||||
assert mock_settings.llm.acomplete.call_count == 1
|
||||
|
||||
|
||||
class TestStructuredOutputWorkflow:
|
||||
"""Direct tests for the StructuredOutputWorkflow"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_workflow_retries_on_validation_error(self):
|
||||
"""Test workflow retries when validation fails"""
|
||||
workflow = StructuredOutputWorkflow(
|
||||
output_cls=TestResponse,
|
||||
max_retries=3,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
with (
|
||||
patch("reflector.llm.TreeSummarize") as mock_summarize,
|
||||
patch("reflector.llm.Settings") as mock_settings,
|
||||
):
|
||||
mock_summarizer = MagicMock()
|
||||
mock_summarize.return_value = mock_summarizer
|
||||
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
|
||||
|
||||
call_count = {"count": 0}
|
||||
|
||||
async def acomplete_handler(*args, **kwargs):
|
||||
call_count["count"] += 1
|
||||
if call_count["count"] < 2:
|
||||
return make_completion_response('{"title": "Only title"}')
|
||||
return make_completion_response(
|
||||
'{"title": "Test", "summary": "Summary", "confidence": 0.9}'
|
||||
)
|
||||
|
||||
mock_settings.llm.acomplete = AsyncMock(side_effect=acomplete_handler)
|
||||
|
||||
result = await workflow.run(
|
||||
prompt="Extract data",
|
||||
texts=["Some text"],
|
||||
tone_name=None,
|
||||
)
|
||||
|
||||
assert "success" in result
|
||||
assert result["success"].title == "Test"
|
||||
assert call_count["count"] == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_workflow_returns_error_after_max_retries(self):
|
||||
"""Test workflow returns error after exhausting retries"""
|
||||
workflow = StructuredOutputWorkflow(
|
||||
output_cls=TestResponse,
|
||||
max_retries=2,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
with (
|
||||
patch("reflector.llm.TreeSummarize") as mock_summarize,
|
||||
patch("reflector.llm.Settings") as mock_settings,
|
||||
):
|
||||
mock_summarizer = MagicMock()
|
||||
mock_summarize.return_value = mock_summarizer
|
||||
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
|
||||
|
||||
# Always return invalid JSON
|
||||
mock_settings.llm.acomplete = AsyncMock(
|
||||
return_value=make_completion_response('{"invalid": true}')
|
||||
)
|
||||
|
||||
result = await workflow.run(
|
||||
prompt="Extract data",
|
||||
texts=["Some text"],
|
||||
tone_name=None,
|
||||
)
|
||||
|
||||
assert "error" in result
|
||||
# TreeSummarize called once, acomplete called max_retries times
|
||||
assert mock_summarizer.aget_response.call_count == 1
|
||||
assert mock_settings.llm.acomplete.call_count == 2
|
||||
assert mock_settings.llm.astructured_predict.call_count == 1
|
||||
|
||||
|
||||
class TestNetworkErrorRetries:
|
||||
"""Test that network error retries are handled by OpenAILike, not Workflow"""
|
||||
"""Test that network errors are retried by the outer retry() wrapper"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_network_error_propagates_after_openai_retries(self, test_settings):
|
||||
"""Test that network errors are retried by OpenAILike and then propagate.
|
||||
|
||||
Network retries are handled by OpenAILike (max_retries=3), not by our
|
||||
StructuredOutputWorkflow. This test verifies that network errors propagate
|
||||
up after OpenAILike exhausts its retries.
|
||||
"""
|
||||
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
||||
|
||||
with (
|
||||
patch("reflector.llm.TreeSummarize") as mock_summarize,
|
||||
patch("reflector.llm.Settings") as mock_settings,
|
||||
):
|
||||
mock_summarizer = MagicMock()
|
||||
mock_summarize.return_value = mock_summarizer
|
||||
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
|
||||
|
||||
# Simulate network error from acomplete (after OpenAILike retries exhausted)
|
||||
network_error = ConnectionError("Connection refused")
|
||||
mock_settings.llm.acomplete = AsyncMock(side_effect=network_error)
|
||||
|
||||
# Network error wrapped in WorkflowRuntimeError
|
||||
with pytest.raises(WorkflowRuntimeError, match="Connection refused"):
|
||||
await llm.get_structured_response(
|
||||
prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
|
||||
)
|
||||
|
||||
# acomplete called only once - network error propagates, not retried by Workflow
|
||||
assert mock_settings.llm.acomplete.call_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_network_error_not_retried_by_workflow(self, test_settings):
|
||||
"""Test that Workflow does NOT retry network errors (OpenAILike handles those).
|
||||
|
||||
This verifies the separation of concerns:
|
||||
- StructuredOutputWorkflow: retries parse/validation errors
|
||||
- OpenAILike: retries network errors (internally, max_retries=3)
|
||||
"""
|
||||
workflow = StructuredOutputWorkflow(
|
||||
output_cls=TestResponse,
|
||||
max_retries=3,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
with (
|
||||
patch("reflector.llm.TreeSummarize") as mock_summarize,
|
||||
patch("reflector.llm.Settings") as mock_settings,
|
||||
):
|
||||
mock_summarizer = MagicMock()
|
||||
mock_summarize.return_value = mock_summarizer
|
||||
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
|
||||
|
||||
# Network error should propagate immediately, not trigger Workflow retry
|
||||
mock_settings.llm.acomplete = AsyncMock(
|
||||
side_effect=TimeoutError("Request timed out")
|
||||
)
|
||||
|
||||
# Network error wrapped in WorkflowRuntimeError
|
||||
with pytest.raises(WorkflowRuntimeError, match="Request timed out"):
|
||||
await workflow.run(
|
||||
prompt="Extract data",
|
||||
texts=["Some text"],
|
||||
tone_name=None,
|
||||
)
|
||||
|
||||
# Only called once - Workflow doesn't retry network errors
|
||||
assert mock_settings.llm.acomplete.call_count == 1
|
||||
|
||||
|
||||
class TestWorkflowTimeoutRetry:
|
||||
"""Test timeout retry mechanism in get_structured_response"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_retry_succeeds_on_retry(self, test_settings):
|
||||
"""Test that WorkflowTimeoutError triggers retry and succeeds"""
|
||||
async def test_network_error_retried_by_outer_wrapper(self, test_settings):
|
||||
"""Test that network errors trigger the outer retry wrapper"""
|
||||
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
||||
|
||||
call_count = {"count": 0}
|
||||
|
||||
async def workflow_run_side_effect(*args, **kwargs):
|
||||
async def astructured_predict_handler(output_cls, prompt_tmpl, **kwargs):
|
||||
call_count["count"] += 1
|
||||
if call_count["count"] == 1:
|
||||
raise WorkflowTimeoutError("Operation timed out after 120 seconds")
|
||||
return {
|
||||
"success": TestResponse(
|
||||
title="Test", summary="Summary", confidence=0.95
|
||||
)
|
||||
}
|
||||
raise ConnectionError("Connection refused")
|
||||
return TestResponse(title="Test", summary="Summary", confidence=0.95)
|
||||
|
||||
with (
|
||||
patch("reflector.llm.StructuredOutputWorkflow") as mock_workflow_class,
|
||||
patch("reflector.llm.TreeSummarize") as mock_summarize,
|
||||
patch("reflector.llm.Settings") as mock_settings,
|
||||
):
|
||||
mock_workflow = MagicMock()
|
||||
mock_workflow.run = AsyncMock(side_effect=workflow_run_side_effect)
|
||||
mock_workflow_class.return_value = mock_workflow
|
||||
|
||||
mock_summarizer = MagicMock()
|
||||
mock_summarize.return_value = mock_summarizer
|
||||
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
|
||||
mock_settings.llm.acomplete = AsyncMock(
|
||||
return_value=make_completion_response(
|
||||
'{"title": "Test", "summary": "Summary", "confidence": 0.95}'
|
||||
)
|
||||
with patch("reflector.llm.Settings") as mock_settings:
|
||||
mock_settings.llm.astructured_predict = AsyncMock(
|
||||
side_effect=astructured_predict_handler
|
||||
)
|
||||
|
||||
result = await llm.get_structured_response(
|
||||
@@ -402,36 +231,16 @@ class TestWorkflowTimeoutRetry:
|
||||
)
|
||||
|
||||
assert result.title == "Test"
|
||||
assert result.summary == "Summary"
|
||||
assert call_count["count"] == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_retry_exhausts_after_max_attempts(self, test_settings):
|
||||
"""Test that timeout retry stops after max attempts"""
|
||||
async def test_network_error_exhausts_retries(self, test_settings):
|
||||
"""Test that persistent network errors exhaust retry attempts"""
|
||||
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
||||
|
||||
call_count = {"count": 0}
|
||||
|
||||
async def workflow_run_side_effect(*args, **kwargs):
|
||||
call_count["count"] += 1
|
||||
raise WorkflowTimeoutError("Operation timed out after 120 seconds")
|
||||
|
||||
with (
|
||||
patch("reflector.llm.StructuredOutputWorkflow") as mock_workflow_class,
|
||||
patch("reflector.llm.TreeSummarize") as mock_summarize,
|
||||
patch("reflector.llm.Settings") as mock_settings,
|
||||
):
|
||||
mock_workflow = MagicMock()
|
||||
mock_workflow.run = AsyncMock(side_effect=workflow_run_side_effect)
|
||||
mock_workflow_class.return_value = mock_workflow
|
||||
|
||||
mock_summarizer = MagicMock()
|
||||
mock_summarize.return_value = mock_summarizer
|
||||
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
|
||||
mock_settings.llm.acomplete = AsyncMock(
|
||||
return_value=make_completion_response(
|
||||
'{"title": "Test", "summary": "Summary", "confidence": 0.95}'
|
||||
)
|
||||
with patch("reflector.llm.Settings") as mock_settings:
|
||||
mock_settings.llm.astructured_predict = AsyncMock(
|
||||
side_effect=ConnectionError("Connection refused")
|
||||
)
|
||||
|
||||
with pytest.raises(RetryException, match="Retry attempts exceeded"):
|
||||
@@ -439,41 +248,129 @@ class TestWorkflowTimeoutRetry:
|
||||
prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
|
||||
)
|
||||
|
||||
assert call_count["count"] == 3
|
||||
# 3 retry attempts
|
||||
assert mock_settings.llm.astructured_predict.call_count == 3
|
||||
|
||||
|
||||
class TestTextsInclusion:
|
||||
"""Test that texts parameter is included in the prompt sent to astructured_predict"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_retry_with_backoff(self, test_settings):
|
||||
"""Test that exponential backoff is applied between retries"""
|
||||
async def test_texts_included_in_prompt(self, test_settings):
|
||||
"""Test that texts content is appended to the prompt for astructured_predict"""
|
||||
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
||||
|
||||
call_times = []
|
||||
captured_prompts = []
|
||||
|
||||
async def workflow_run_side_effect(*args, **kwargs):
|
||||
call_times.append(monotonic())
|
||||
if len(call_times) < 3:
|
||||
raise WorkflowTimeoutError("Operation timed out after 120 seconds")
|
||||
return {
|
||||
"success": TestResponse(
|
||||
title="Test", summary="Summary", confidence=0.95
|
||||
async def capture_prompt(output_cls, prompt_tmpl, **kwargs):
|
||||
captured_prompts.append(kwargs.get("user_prompt", ""))
|
||||
return TestResponse(title="Test", summary="Summary", confidence=0.95)
|
||||
|
||||
with patch("reflector.llm.Settings") as mock_settings:
|
||||
mock_settings.llm.astructured_predict = AsyncMock(
|
||||
side_effect=capture_prompt
|
||||
)
|
||||
|
||||
await llm.get_structured_response(
|
||||
prompt="Identify all participants",
|
||||
texts=["Alice: Hello everyone", "Bob: Hi Alice"],
|
||||
output_cls=TestResponse,
|
||||
)
|
||||
|
||||
assert len(captured_prompts) == 1
|
||||
prompt_sent = captured_prompts[0]
|
||||
assert "Identify all participants" in prompt_sent
|
||||
assert "Alice: Hello everyone" in prompt_sent
|
||||
assert "Bob: Hi Alice" in prompt_sent
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_texts_uses_prompt_only(self, test_settings):
|
||||
"""Test that empty texts list sends only the prompt"""
|
||||
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
||||
|
||||
captured_prompts = []
|
||||
|
||||
async def capture_prompt(output_cls, prompt_tmpl, **kwargs):
|
||||
captured_prompts.append(kwargs.get("user_prompt", ""))
|
||||
return TestResponse(title="Test", summary="Summary", confidence=0.95)
|
||||
|
||||
with patch("reflector.llm.Settings") as mock_settings:
|
||||
mock_settings.llm.astructured_predict = AsyncMock(
|
||||
side_effect=capture_prompt
|
||||
)
|
||||
|
||||
await llm.get_structured_response(
|
||||
prompt="Identify all participants",
|
||||
texts=[],
|
||||
output_cls=TestResponse,
|
||||
)
|
||||
|
||||
assert len(captured_prompts) == 1
|
||||
assert captured_prompts[0] == "Identify all participants"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_texts_included_in_reflection_retry(self, test_settings):
|
||||
"""Test that texts are included in the prompt even during reflection retries"""
|
||||
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
||||
|
||||
captured_prompts = []
|
||||
call_count = {"count": 0}
|
||||
|
||||
async def capture_and_fail_first(output_cls, prompt_tmpl, **kwargs):
|
||||
call_count["count"] += 1
|
||||
captured_prompts.append(kwargs.get("user_prompt", ""))
|
||||
if call_count["count"] == 1:
|
||||
raise ValidationError.from_exception_data(
|
||||
title="TestResponse",
|
||||
line_errors=[
|
||||
{
|
||||
"type": "missing",
|
||||
"loc": ("summary",),
|
||||
"msg": "Field required",
|
||||
"input": {},
|
||||
}
|
||||
],
|
||||
)
|
||||
}
|
||||
return TestResponse(title="Test", summary="Summary", confidence=0.95)
|
||||
|
||||
with (
|
||||
patch("reflector.llm.StructuredOutputWorkflow") as mock_workflow_class,
|
||||
patch("reflector.llm.TreeSummarize") as mock_summarize,
|
||||
patch("reflector.llm.Settings") as mock_settings,
|
||||
):
|
||||
mock_workflow = MagicMock()
|
||||
mock_workflow.run = AsyncMock(side_effect=workflow_run_side_effect)
|
||||
mock_workflow_class.return_value = mock_workflow
|
||||
with patch("reflector.llm.Settings") as mock_settings:
|
||||
mock_settings.llm.astructured_predict = AsyncMock(
|
||||
side_effect=capture_and_fail_first
|
||||
)
|
||||
|
||||
mock_summarizer = MagicMock()
|
||||
mock_summarize.return_value = mock_summarizer
|
||||
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
|
||||
mock_settings.llm.acomplete = AsyncMock(
|
||||
return_value=make_completion_response(
|
||||
'{"title": "Test", "summary": "Summary", "confidence": 0.95}'
|
||||
)
|
||||
await llm.get_structured_response(
|
||||
prompt="Summarize this",
|
||||
texts=["The meeting covered project updates"],
|
||||
output_cls=TestResponse,
|
||||
)
|
||||
|
||||
# Both first attempt and reflection retry should include the texts
|
||||
assert len(captured_prompts) == 2
|
||||
for prompt_sent in captured_prompts:
|
||||
assert "Summarize this" in prompt_sent
|
||||
assert "The meeting covered project updates" in prompt_sent
|
||||
|
||||
|
||||
class TestReflectionRetryBackoff:
|
||||
"""Test the reflection retry timing behavior"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_value_error_triggers_reflection(self, test_settings):
|
||||
"""Test that ValueError (parse failure) also triggers reflection retry"""
|
||||
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
||||
|
||||
call_count = {"count": 0}
|
||||
|
||||
async def astructured_predict_handler(output_cls, prompt_tmpl, **kwargs):
|
||||
call_count["count"] += 1
|
||||
if call_count["count"] == 1:
|
||||
raise ValueError("Could not parse output")
|
||||
assert "reflection" in kwargs
|
||||
return TestResponse(title="Test", summary="Summary", confidence=0.95)
|
||||
|
||||
with patch("reflector.llm.Settings") as mock_settings:
|
||||
mock_settings.llm.astructured_predict = AsyncMock(
|
||||
side_effect=astructured_predict_handler
|
||||
)
|
||||
|
||||
result = await llm.get_structured_response(
|
||||
@@ -481,8 +378,20 @@ class TestWorkflowTimeoutRetry:
|
||||
)
|
||||
|
||||
assert result.title == "Test"
|
||||
if len(call_times) >= 2:
|
||||
time_between_calls = call_times[1] - call_times[0]
|
||||
assert (
|
||||
time_between_calls >= 1.5
|
||||
), f"Expected ~2s backoff, got {time_between_calls}s"
|
||||
assert call_count["count"] == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_format_validation_error_method(self, test_settings):
|
||||
"""Test _format_validation_error produces correct feedback"""
|
||||
# ValidationError
|
||||
try:
|
||||
TestResponse(title="x", summary="y", confidence=5.0) # confidence > 1
|
||||
except ValidationError as e:
|
||||
result = LLM._format_validation_error(e)
|
||||
assert "Schema validation errors" in result
|
||||
assert "confidence" in result
|
||||
|
||||
# ValueError
|
||||
result = LLM._format_validation_error(ValueError("bad input"))
|
||||
assert "Parse error:" in result
|
||||
assert "bad input" in result
|
||||
|
||||
Reference in New Issue
Block a user