"""Tests for LLM parse error recovery using llama-index Workflow"""
|
|
|
|
from time import monotonic
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
from pydantic import BaseModel, Field
|
|
from workflows.errors import WorkflowRuntimeError, WorkflowTimeoutError
|
|
|
|
from reflector.llm import LLM, LLMParseError, StructuredOutputWorkflow
|
|
from reflector.utils.retry import RetryException
|
|
|
|
|
|
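# Note: pytest does not collect ``TestResponse`` as a test class despite the
# "Test" prefix, because pydantic's BaseModel defines ``__init__`` and pytest
# skips such classes with a collection warning, so it is safe to use as a plain
# output schema here.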
class TestResponse(BaseModel):
    """Test response model for structured output"""

    title: str = Field(description="A title")
    summary: str = Field(description="A summary")
    confidence: float = Field(description="Confidence score", ge=0, le=1)

|
def make_completion_response(text: str):
|
|
"""Create a mock CompletionResponse with .text attribute"""
|
|
response = MagicMock()
|
|
response.text = text
|
|
return response
|
|
|
|
|
|
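# A MagicMock is sufficient here because ``.text`` is the only attribute these
# tests read from the completion. If a closer-to-real object is preferred, the
# llama-index response type could be used instead (a sketch, assuming recent
# llama-index versions export it from this path):
#
#     from llama_index.core.llms import CompletionResponse
#     return CompletionResponse(text=text)
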
class TestLLMParseErrorRecovery:
    """Test parse error recovery with Workflow feedback loop"""

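    # Shape assumed by these tests: TreeSummarize produces a free-text analysis
    # first, then Settings.llm.acomplete formats it as JSON; on a parse or
    # validation failure the formatting prompt is re-sent with error feedback.
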
    @pytest.mark.asyncio
    async def test_parse_error_recovery_with_feedback(self, test_settings):
        """Test that parse errors trigger retry with error feedback"""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)

        with (
            patch("reflector.llm.TreeSummarize") as mock_summarize,
            patch("reflector.llm.Settings") as mock_settings,
        ):
            mock_summarizer = MagicMock()
            mock_summarize.return_value = mock_summarizer
            # TreeSummarize returns plain text analysis (step 1)
            mock_summarizer.aget_response = AsyncMock(
                return_value="The analysis shows a test with summary and high confidence."
            )

            call_count = {"count": 0}

            async def acomplete_handler(prompt, *args, **kwargs):
                call_count["count"] += 1
                if call_count["count"] == 1:
                    # First JSON formatting call returns invalid JSON
                    return make_completion_response('{"title": "Test"}')
                else:
                    # Second call should have error feedback in prompt
                    assert "Your previous response could not be parsed:" in prompt
                    assert '{"title": "Test"}' in prompt
                    assert "Error:" in prompt
                    assert "Please try again" in prompt
                    return make_completion_response(
                        '{"title": "Test", "summary": "Summary", "confidence": 0.95}'
                    )

            mock_settings.llm.acomplete = AsyncMock(side_effect=acomplete_handler)

            result = await llm.get_structured_response(
                prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
            )

            assert result.title == "Test"
            assert result.summary == "Summary"
            assert result.confidence == 0.95
            # TreeSummarize called once, Settings.llm.acomplete called twice
            assert mock_summarizer.aget_response.call_count == 1
            assert call_count["count"] == 2

    @pytest.mark.asyncio
    async def test_max_parse_retry_attempts(self, test_settings):
        """Test that parse error retry stops after max attempts"""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)

        with (
            patch("reflector.llm.TreeSummarize") as mock_summarize,
            patch("reflector.llm.Settings") as mock_settings,
        ):
            mock_summarizer = MagicMock()
            mock_summarize.return_value = mock_summarizer
            mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")

            # Always return invalid JSON from acomplete
            mock_settings.llm.acomplete = AsyncMock(
                return_value=make_completion_response(
                    '{"invalid": "missing required fields"}'
                )
            )

            with pytest.raises(LLMParseError, match="Failed to parse"):
                await llm.get_structured_response(
                    prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
                )

            expected_attempts = test_settings.LLM_PARSE_MAX_RETRIES + 1
            # TreeSummarize called once; acomplete called for the initial attempt
            # plus LLM_PARSE_MAX_RETRIES retries
            assert mock_summarizer.aget_response.call_count == 1
            assert mock_settings.llm.acomplete.call_count == expected_attempts

    @pytest.mark.asyncio
    async def test_raw_response_logging_on_parse_error(self, test_settings, caplog):
        """Test that raw response is logged when parse error occurs"""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)

        with (
            patch("reflector.llm.TreeSummarize") as mock_summarize,
            patch("reflector.llm.Settings") as mock_settings,
            caplog.at_level("ERROR"),
        ):
            mock_summarizer = MagicMock()
            mock_summarize.return_value = mock_summarizer
            mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")

            call_count = {"count": 0}

            async def acomplete_handler(*args, **kwargs):
                call_count["count"] += 1
                if call_count["count"] == 1:
                    return make_completion_response('{"title": "Test"}')  # Invalid
                return make_completion_response(
                    '{"title": "Test", "summary": "Summary", "confidence": 0.95}'
                )

            mock_settings.llm.acomplete = AsyncMock(side_effect=acomplete_handler)

            result = await llm.get_structured_response(
                prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
            )

            assert result.title == "Test"

            error_logs = [r for r in caplog.records if r.levelname == "ERROR"]
            raw_response_logged = any("Raw response:" in r.message for r in error_logs)
            assert raw_response_logged, "Raw response should be logged on parse error"

    @pytest.mark.asyncio
    async def test_multiple_validation_errors_in_feedback(self, test_settings):
        """Test that validation errors are included in feedback"""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)

        with (
            patch("reflector.llm.TreeSummarize") as mock_summarize,
            patch("reflector.llm.Settings") as mock_settings,
        ):
            mock_summarizer = MagicMock()
            mock_summarize.return_value = mock_summarizer
            mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")

            call_count = {"count": 0}

            async def acomplete_handler(prompt, *args, **kwargs):
                call_count["count"] += 1
                if call_count["count"] == 1:
                    # Missing title and summary
                    return make_completion_response('{"confidence": 0.5}')
                else:
                    # Should have schema validation errors in prompt
                    assert (
                        "Schema validation errors" in prompt
                        or "error" in prompt.lower()
                    )
                    return make_completion_response(
                        '{"title": "Test", "summary": "Summary", "confidence": 0.95}'
                    )

            mock_settings.llm.acomplete = AsyncMock(side_effect=acomplete_handler)

            result = await llm.get_structured_response(
                prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
            )

            assert result.title == "Test"
            assert call_count["count"] == 2

    @pytest.mark.asyncio
    async def test_success_on_first_attempt(self, test_settings):
        """Test that no retry happens when first attempt succeeds"""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)

        with (
            patch("reflector.llm.TreeSummarize") as mock_summarize,
            patch("reflector.llm.Settings") as mock_settings,
        ):
            mock_summarizer = MagicMock()
            mock_summarize.return_value = mock_summarizer
            mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")

            mock_settings.llm.acomplete = AsyncMock(
                return_value=make_completion_response(
                    '{"title": "Test", "summary": "Summary", "confidence": 0.95}'
                )
            )

            result = await llm.get_structured_response(
                prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
            )

            assert result.title == "Test"
            assert result.summary == "Summary"
            assert result.confidence == 0.95
            assert mock_summarizer.aget_response.call_count == 1
            assert mock_settings.llm.acomplete.call_count == 1

class TestStructuredOutputWorkflow:
    """Direct tests for the StructuredOutputWorkflow"""

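    # Result contract exercised below: workflow.run(...) resolves to a dict with
    # a "success" key holding the validated model, or an "error" key once the
    # configured max_retries is exhausted.
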
    @pytest.mark.asyncio
    async def test_workflow_retries_on_validation_error(self):
        """Test workflow retries when validation fails"""
        workflow = StructuredOutputWorkflow(
            output_cls=TestResponse,
            max_retries=3,
            timeout=30,
        )

        with (
            patch("reflector.llm.TreeSummarize") as mock_summarize,
            patch("reflector.llm.Settings") as mock_settings,
        ):
            mock_summarizer = MagicMock()
            mock_summarize.return_value = mock_summarizer
            mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")

            call_count = {"count": 0}

            async def acomplete_handler(*args, **kwargs):
                call_count["count"] += 1
                if call_count["count"] < 2:
                    return make_completion_response('{"title": "Only title"}')
                return make_completion_response(
                    '{"title": "Test", "summary": "Summary", "confidence": 0.9}'
                )

            mock_settings.llm.acomplete = AsyncMock(side_effect=acomplete_handler)

            result = await workflow.run(
                prompt="Extract data",
                texts=["Some text"],
                tone_name=None,
            )

            assert "success" in result
            assert result["success"].title == "Test"
            assert call_count["count"] == 2

    @pytest.mark.asyncio
    async def test_workflow_returns_error_after_max_retries(self):
        """Test workflow returns error after exhausting retries"""
        workflow = StructuredOutputWorkflow(
            output_cls=TestResponse,
            max_retries=2,
            timeout=30,
        )

        with (
            patch("reflector.llm.TreeSummarize") as mock_summarize,
            patch("reflector.llm.Settings") as mock_settings,
        ):
            mock_summarizer = MagicMock()
            mock_summarize.return_value = mock_summarizer
            mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")

            # Always return invalid JSON
            mock_settings.llm.acomplete = AsyncMock(
                return_value=make_completion_response('{"invalid": true}')
            )

            result = await workflow.run(
                prompt="Extract data",
                texts=["Some text"],
                tone_name=None,
            )

            assert "error" in result
            # TreeSummarize called once, acomplete called max_retries times
            assert mock_summarizer.aget_response.call_count == 1
            assert mock_settings.llm.acomplete.call_count == 2

class TestNetworkErrorRetries:
    """Test that network error retries are handled by OpenAILike, not Workflow"""

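    # The max_retries=3 referenced in the docstrings below is assumed to be set
    # where reflector.llm constructs its OpenAILike client, along the lines of
    # (illustrative only): OpenAILike(model=..., api_base=..., max_retries=3).
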
    @pytest.mark.asyncio
    async def test_network_error_propagates_after_openai_retries(self, test_settings):
        """Test that network errors are retried by OpenAILike and then propagate.

        Network retries are handled by OpenAILike (max_retries=3), not by our
        StructuredOutputWorkflow. This test verifies that network errors propagate
        up after OpenAILike exhausts its retries.
        """
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)

        with (
            patch("reflector.llm.TreeSummarize") as mock_summarize,
            patch("reflector.llm.Settings") as mock_settings,
        ):
            mock_summarizer = MagicMock()
            mock_summarize.return_value = mock_summarizer
            mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")

            # Simulate network error from acomplete (after OpenAILike retries exhausted)
            network_error = ConnectionError("Connection refused")
            mock_settings.llm.acomplete = AsyncMock(side_effect=network_error)

            # Network error wrapped in WorkflowRuntimeError
            with pytest.raises(WorkflowRuntimeError, match="Connection refused"):
                await llm.get_structured_response(
                    prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
                )

            # acomplete called only once - network error propagates, not retried by Workflow
            assert mock_settings.llm.acomplete.call_count == 1

    @pytest.mark.asyncio
    async def test_network_error_not_retried_by_workflow(self, test_settings):
        """Test that Workflow does NOT retry network errors (OpenAILike handles those).

        This verifies the separation of concerns:
        - StructuredOutputWorkflow: retries parse/validation errors
        - OpenAILike: retries network errors (internally, max_retries=3)
        """
        workflow = StructuredOutputWorkflow(
            output_cls=TestResponse,
            max_retries=3,
            timeout=30,
        )

        with (
            patch("reflector.llm.TreeSummarize") as mock_summarize,
            patch("reflector.llm.Settings") as mock_settings,
        ):
            mock_summarizer = MagicMock()
            mock_summarize.return_value = mock_summarizer
            mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")

            # Network error should propagate immediately, not trigger Workflow retry
            mock_settings.llm.acomplete = AsyncMock(
                side_effect=TimeoutError("Request timed out")
            )

            # Network error wrapped in WorkflowRuntimeError
            with pytest.raises(WorkflowRuntimeError, match="Request timed out"):
                await workflow.run(
                    prompt="Extract data",
                    texts=["Some text"],
                    tone_name=None,
                )

            # Only called once - Workflow doesn't retry network errors
            assert mock_settings.llm.acomplete.call_count == 1

class TestWorkflowTimeoutRetry:
    """Test timeout retry mechanism in get_structured_response"""

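    # Assumed retry behaviour, inferred from the assertions below: workflow
    # timeouts are retried up to three attempts with roughly exponential backoff
    # (~2s before the second attempt) before reflector.utils.retry raises
    # RetryException("Retry attempts exceeded").
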
    @pytest.mark.asyncio
    async def test_timeout_retry_succeeds_on_retry(self, test_settings):
        """Test that WorkflowTimeoutError triggers retry and succeeds"""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)

        call_count = {"count": 0}

        async def workflow_run_side_effect(*args, **kwargs):
            call_count["count"] += 1
            if call_count["count"] == 1:
                raise WorkflowTimeoutError("Operation timed out after 120 seconds")
            return {
                "success": TestResponse(
                    title="Test", summary="Summary", confidence=0.95
                )
            }

        with (
            patch("reflector.llm.StructuredOutputWorkflow") as mock_workflow_class,
            patch("reflector.llm.TreeSummarize") as mock_summarize,
            patch("reflector.llm.Settings") as mock_settings,
        ):
            mock_workflow = MagicMock()
            mock_workflow.run = AsyncMock(side_effect=workflow_run_side_effect)
            mock_workflow_class.return_value = mock_workflow

            mock_summarizer = MagicMock()
            mock_summarize.return_value = mock_summarizer
            mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
            mock_settings.llm.acomplete = AsyncMock(
                return_value=make_completion_response(
                    '{"title": "Test", "summary": "Summary", "confidence": 0.95}'
                )
            )

            result = await llm.get_structured_response(
                prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
            )

            assert result.title == "Test"
            assert result.summary == "Summary"
            assert call_count["count"] == 2

    @pytest.mark.asyncio
    async def test_timeout_retry_exhausts_after_max_attempts(self, test_settings):
        """Test that timeout retry stops after max attempts"""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)

        call_count = {"count": 0}

        async def workflow_run_side_effect(*args, **kwargs):
            call_count["count"] += 1
            raise WorkflowTimeoutError("Operation timed out after 120 seconds")

        with (
            patch("reflector.llm.StructuredOutputWorkflow") as mock_workflow_class,
            patch("reflector.llm.TreeSummarize") as mock_summarize,
            patch("reflector.llm.Settings") as mock_settings,
        ):
            mock_workflow = MagicMock()
            mock_workflow.run = AsyncMock(side_effect=workflow_run_side_effect)
            mock_workflow_class.return_value = mock_workflow

            mock_summarizer = MagicMock()
            mock_summarize.return_value = mock_summarizer
            mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
            mock_settings.llm.acomplete = AsyncMock(
                return_value=make_completion_response(
                    '{"title": "Test", "summary": "Summary", "confidence": 0.95}'
                )
            )

            with pytest.raises(RetryException, match="Retry attempts exceeded"):
                await llm.get_structured_response(
                    prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
                )

            assert call_count["count"] == 3

    @pytest.mark.asyncio
    async def test_timeout_retry_with_backoff(self, test_settings):
        """Test that exponential backoff is applied between retries"""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)

        call_times = []

        async def workflow_run_side_effect(*args, **kwargs):
            call_times.append(monotonic())
            if len(call_times) < 3:
                raise WorkflowTimeoutError("Operation timed out after 120 seconds")
            return {
                "success": TestResponse(
                    title="Test", summary="Summary", confidence=0.95
                )
            }

        with (
            patch("reflector.llm.StructuredOutputWorkflow") as mock_workflow_class,
            patch("reflector.llm.TreeSummarize") as mock_summarize,
            patch("reflector.llm.Settings") as mock_settings,
        ):
            mock_workflow = MagicMock()
            mock_workflow.run = AsyncMock(side_effect=workflow_run_side_effect)
            mock_workflow_class.return_value = mock_workflow

            mock_summarizer = MagicMock()
            mock_summarize.return_value = mock_summarizer
            mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
            mock_settings.llm.acomplete = AsyncMock(
                return_value=make_completion_response(
                    '{"title": "Test", "summary": "Summary", "confidence": 0.95}'
                )
            )

            result = await llm.get_structured_response(
                prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
            )

            assert result.title == "Test"
            if len(call_times) >= 2:
                time_between_calls = call_times[1] - call_times[0]
                assert (
                    time_between_calls >= 1.5
                ), f"Expected ~2s backoff, got {time_between_calls}s"