mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-03-21 22:56:47 +00:00
* Increase max connections * Classify hard and transient hatchet errors * Fan out partial success * Force reprocessing of error transcripts * Stop retrying on 402 payment required * Avoid httpx/hatchet timeout race * Add retry wrapper to get_response for transient errors * Add retry backoff * Return falsy results so get_response won't retry on empty string * Skip error status in on_workflow_failure when transcript already ended * Fix precommit issues * Fail step on first fan-out failure instead of skipping
455 lines
18 KiB
Python
455 lines
18 KiB
Python
"""Tests for LLM structured output with astructured_predict + reflection retry"""
|
|
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
from pydantic import BaseModel, Field, ValidationError
|
|
|
|
from reflector.llm import LLM, LLMParseError
|
|
from reflector.utils.retry import RetryException
|
|
|
|
|
|
class TestResponse(BaseModel):
    """Pydantic schema used as the structured-output target in these tests.

    Fields:
        title: free-form title string.
        summary: free-form summary string.
        confidence: score constrained to the closed interval [0, 1].
    """

    # The ``Test`` prefix matches pytest's default class-collection pattern,
    # yet this is a data model with an ``__init__``, not a test case. Opting
    # out explicitly keeps pytest from trying to collect it (and warning).
    __test__ = False

    title: str = Field(description="A title")
    summary: str = Field(description="A summary")
    confidence: float = Field(description="Confidence score", ge=0, le=1)
|
|
|
|
|
|
class TestLLMParseErrorRecovery:
    """Parse-failure handling of ``LLM.get_structured_response``.

    In every test ``reflector.llm.Settings`` is patched and its
    ``llm.astructured_predict`` attribute is replaced by a scripted
    ``AsyncMock``.
    """

    @staticmethod
    def _missing_fields_error(*fields, received):
        """Build a ``ValidationError`` that reports each of *fields* as missing.

        ``received`` is the payload recorded as the offending input of every
        line error.
        """
        return ValidationError.from_exception_data(
            title="TestResponse",
            line_errors=[
                {
                    "type": "missing",
                    "loc": (field,),
                    "msg": "Field required",
                    "input": dict(received),
                }
                for field in fields
            ],
        )

    @pytest.mark.asyncio
    async def test_parse_error_recovery_with_feedback(self, test_settings):
        """A parse failure is retried with the error passed back as reflection."""
        attempts = 0

        async def scripted_predict(output_cls, prompt_tmpl, **kwargs):
            nonlocal attempts
            attempts += 1
            if attempts == 1:
                # Opening call: simulate output lacking a required field.
                raise self._missing_fields_error(
                    "summary", received={"title": "Test"}
                )
            # Any later call must carry the reflection feedback.
            assert "reflection" in kwargs
            feedback = kwargs["reflection"]
            assert "could not be parsed" in feedback
            assert "Error:" in feedback
            return TestResponse(title="Test", summary="Summary", confidence=0.95)

        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
        with patch("reflector.llm.Settings") as mock_settings:
            mock_settings.llm.astructured_predict = AsyncMock(
                side_effect=scripted_predict
            )
            result = await llm.get_structured_response(
                prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
            )

        assert (result.title, result.summary, result.confidence) == (
            "Test",
            "Summary",
            0.95,
        )
        assert attempts == 2

    @pytest.mark.asyncio
    async def test_max_parse_retry_attempts(self, test_settings):
        """Persistent parse failures end in ``LLMParseError`` after the retry cap."""

        async def always_fail(output_cls, prompt_tmpl, **kwargs):
            # Never succeeds: each call reports the same missing field.
            raise self._missing_fields_error("summary", received={})

        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
        predict_mock = AsyncMock(side_effect=always_fail)

        with patch("reflector.llm.Settings") as mock_settings:
            mock_settings.llm.astructured_predict = predict_mock

            with pytest.raises(LLMParseError, match="Failed to parse"):
                await llm.get_structured_response(
                    prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
                )

        # One initial attempt plus the configured number of retries.
        assert predict_mock.call_count == test_settings.LLM_PARSE_MAX_RETRIES + 1

    @pytest.mark.asyncio
    async def test_raw_response_logging_on_parse_error(self, test_settings, caplog):
        """An ERROR record containing ``Raw response:`` is logged on parse failure."""
        attempts = 0

        async def scripted_predict(output_cls, prompt_tmpl, **kwargs):
            nonlocal attempts
            attempts += 1
            if attempts == 1:
                raise self._missing_fields_error(
                    "summary", received={"title": "Test"}
                )
            return TestResponse(title="Test", summary="Summary", confidence=0.95)

        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
        with (
            patch("reflector.llm.Settings") as mock_settings,
            caplog.at_level("ERROR"),
        ):
            mock_settings.llm.astructured_predict = AsyncMock(
                side_effect=scripted_predict
            )
            result = await llm.get_structured_response(
                prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
            )

        assert result.title == "Test"

        # Only ERROR-level records count towards the expectation.
        raw_response_logged = any(
            "Raw response:" in record.message
            for record in caplog.records
            if record.levelname == "ERROR"
        )
        assert raw_response_logged, "Raw response should be logged on parse error"

    @pytest.mark.asyncio
    async def test_multiple_validation_errors_in_feedback(self, test_settings):
        """Several validation errors at once still yield error feedback on retry."""
        attempts = 0

        async def scripted_predict(output_cls, prompt_tmpl, **kwargs):
            nonlocal attempts
            attempts += 1
            if attempts == 1:
                # Both title and summary are reported missing.
                raise self._missing_fields_error("title", "summary", received={})
            # The retry must mention the schema validation errors.
            assert "reflection" in kwargs
            feedback = kwargs["reflection"]
            assert (
                "Schema validation errors" in feedback
                or "error" in feedback.lower()
            )
            return TestResponse(title="Test", summary="Summary", confidence=0.95)

        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
        with patch("reflector.llm.Settings") as mock_settings:
            mock_settings.llm.astructured_predict = AsyncMock(
                side_effect=scripted_predict
            )
            result = await llm.get_structured_response(
                prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
            )

        assert result.title == "Test"
        assert attempts == 2

    @pytest.mark.asyncio
    async def test_success_on_first_attempt(self, test_settings):
        """A response that parses immediately results in exactly one call."""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
        canned = TestResponse(title="Test", summary="Summary", confidence=0.95)
        predict_mock = AsyncMock(return_value=canned)

        with patch("reflector.llm.Settings") as mock_settings:
            mock_settings.llm.astructured_predict = predict_mock
            result = await llm.get_structured_response(
                prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
            )

        assert (result.title, result.summary, result.confidence) == (
            "Test",
            "Summary",
            0.95,
        )
        assert predict_mock.call_count == 1
|
|
|
|
|
|
class TestNetworkErrorRetries:
    """Network failures raised by ``astructured_predict`` go through ``retry()``."""

    @pytest.mark.asyncio
    async def test_network_error_retried_by_outer_wrapper(self, test_settings):
        """A single ``ConnectionError`` is retried and the next call succeeds."""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)

        # Scripted outcomes: the first call raises, the second call returns.
        predict_mock = AsyncMock(
            side_effect=[
                ConnectionError("Connection refused"),
                TestResponse(title="Test", summary="Summary", confidence=0.95),
            ]
        )

        with patch("reflector.llm.Settings") as mock_settings:
            mock_settings.llm.astructured_predict = predict_mock
            result = await llm.get_structured_response(
                prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
            )

        assert result.title == "Test"
        assert predict_mock.call_count == 2

    @pytest.mark.asyncio
    async def test_network_error_exhausts_retries(self, test_settings):
        """A ``ConnectionError`` on every call ends in ``RetryException``."""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
        predict_mock = AsyncMock(side_effect=ConnectionError("Connection refused"))

        with patch("reflector.llm.Settings") as mock_settings:
            mock_settings.llm.astructured_predict = predict_mock

            with pytest.raises(RetryException, match="Retry attempts exceeded"):
                await llm.get_structured_response(
                    prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
                )

        # 3 retry attempts
        assert predict_mock.call_count == 3
|
|
|
|
|
|
class TestGetResponseRetries:
    """``LLM.get_response`` retry behaviour, with ``TreeSummarize`` patched out."""

    @pytest.mark.asyncio
    async def test_get_response_retries_on_connection_error(self, test_settings):
        """One ``ConnectionError`` is retried; the next reply comes back stripped."""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)

        # The first call raises; the second yields text padded with whitespace.
        aget_response = AsyncMock(
            side_effect=[ConnectionError("Connection refused"), " Summary text "]
        )
        summarizer = MagicMock(aget_response=aget_response)

        with patch("reflector.llm.TreeSummarize", return_value=summarizer):
            result = await llm.get_response("Prompt", ["text"])

        assert result == "Summary text"
        assert aget_response.call_count == 2

    @pytest.mark.asyncio
    async def test_get_response_exhausts_retries(self, test_settings):
        """A ``ConnectionError`` on every call ends in ``RetryException``."""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)

        aget_response = AsyncMock(side_effect=ConnectionError("Connection refused"))
        summarizer = MagicMock(aget_response=aget_response)

        with patch("reflector.llm.TreeSummarize", return_value=summarizer):
            with pytest.raises(RetryException, match="Retry attempts exceeded"):
                await llm.get_response("Prompt", ["text"])

        assert aget_response.call_count == 3

    @pytest.mark.asyncio
    async def test_get_response_returns_empty_string_without_retry(self, test_settings):
        """A whitespace-only reply yields ``''`` after exactly one call.

        A falsy result such as ``''`` is a legitimate return value: it must
        not be mistaken for "no result" and retried until ``RetryException``.
        """
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)

        aget_response = AsyncMock(return_value=" \n ")  # strip() -> ""
        summarizer = MagicMock(aget_response=aget_response)

        with patch("reflector.llm.TreeSummarize", return_value=summarizer):
            result = await llm.get_response("Prompt", ["text"])

        assert result == ""
        assert aget_response.call_count == 1
|
|
|
|
|
|
class TestTextsInclusion:
    """The ``texts`` argument must reach ``astructured_predict`` via ``user_prompt``."""

    @pytest.mark.asyncio
    async def test_texts_included_in_prompt(self, test_settings):
        """The prompt and every text entry appear in the ``user_prompt`` sent."""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
        seen = []

        async def record_prompt(output_cls, prompt_tmpl, **kwargs):
            seen.append(kwargs.get("user_prompt", ""))
            return TestResponse(title="Test", summary="Summary", confidence=0.95)

        with patch("reflector.llm.Settings") as mock_settings:
            mock_settings.llm.astructured_predict = AsyncMock(
                side_effect=record_prompt
            )
            await llm.get_structured_response(
                prompt="Identify all participants",
                texts=["Alice: Hello everyone", "Bob: Hi Alice"],
                output_cls=TestResponse,
            )

        assert len(seen) == 1
        sent = seen[0]
        for fragment in (
            "Identify all participants",
            "Alice: Hello everyone",
            "Bob: Hi Alice",
        ):
            assert fragment in sent

    @pytest.mark.asyncio
    async def test_empty_texts_uses_prompt_only(self, test_settings):
        """With an empty ``texts`` list the ``user_prompt`` is the bare prompt."""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
        seen = []

        async def record_prompt(output_cls, prompt_tmpl, **kwargs):
            seen.append(kwargs.get("user_prompt", ""))
            return TestResponse(title="Test", summary="Summary", confidence=0.95)

        with patch("reflector.llm.Settings") as mock_settings:
            mock_settings.llm.astructured_predict = AsyncMock(
                side_effect=record_prompt
            )
            await llm.get_structured_response(
                prompt="Identify all participants",
                texts=[],
                output_cls=TestResponse,
            )

        assert len(seen) == 1
        assert seen[0] == "Identify all participants"

    @pytest.mark.asyncio
    async def test_texts_included_in_reflection_retry(self, test_settings):
        """The texts are present in the first attempt and in the reflection retry."""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
        seen = []

        async def record_then_fail_once(output_cls, prompt_tmpl, **kwargs):
            seen.append(kwargs.get("user_prompt", ""))
            # The number of recorded prompts doubles as the call counter.
            if len(seen) == 1:
                raise ValidationError.from_exception_data(
                    title="TestResponse",
                    line_errors=[
                        {
                            "type": "missing",
                            "loc": ("summary",),
                            "msg": "Field required",
                            "input": {},
                        }
                    ],
                )
            return TestResponse(title="Test", summary="Summary", confidence=0.95)

        with patch("reflector.llm.Settings") as mock_settings:
            mock_settings.llm.astructured_predict = AsyncMock(
                side_effect=record_then_fail_once
            )
            await llm.get_structured_response(
                prompt="Summarize this",
                texts=["The meeting covered project updates"],
                output_cls=TestResponse,
            )

        # Both first attempt and reflection retry should include the texts
        assert len(seen) == 2
        for sent in seen:
            assert "Summarize this" in sent
            assert "The meeting covered project updates" in sent
|
|
|
|
|
|
class TestReflectionRetryBackoff:
    """Reflection retry triggers and the feedback text they produce.

    Covers two things: a plain ``ValueError`` from ``astructured_predict``
    leads to a retry carrying ``reflection``, and
    ``LLM._format_validation_error`` renders both error kinds as expected.
    No timing or backoff delay is measured here.
    """

    @pytest.mark.asyncio
    async def test_value_error_triggers_reflection(self, test_settings):
        """Test that ValueError (parse failure) also triggers reflection retry"""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)

        call_count = {"count": 0}

        async def astructured_predict_handler(output_cls, prompt_tmpl, **kwargs):
            call_count["count"] += 1
            if call_count["count"] == 1:
                # First call: simulate output that could not be parsed at all.
                raise ValueError("Could not parse output")
            # Retry: the caller must pass feedback via the reflection kwarg.
            assert "reflection" in kwargs
            return TestResponse(title="Test", summary="Summary", confidence=0.95)

        with patch("reflector.llm.Settings") as mock_settings:
            mock_settings.llm.astructured_predict = AsyncMock(
                side_effect=astructured_predict_handler
            )

            result = await llm.get_structured_response(
                prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
            )

            assert result.title == "Test"
            assert call_count["count"] == 2

    @pytest.mark.asyncio
    async def test_format_validation_error_method(self, test_settings):
        """Test _format_validation_error produces correct feedback"""
        # ValidationError: confidence=5.0 violates the field's ``le=1`` bound.
        # ``pytest.raises`` makes the test fail if construction unexpectedly
        # succeeds; a bare try/except would skip the assertions silently.
        with pytest.raises(ValidationError) as exc_info:
            TestResponse(title="x", summary="y", confidence=5.0)  # confidence > 1

        schema_feedback = LLM._format_validation_error(exc_info.value)
        assert "Schema validation errors" in schema_feedback
        assert "confidence" in schema_feedback

        # ValueError
        parse_feedback = LLM._format_validation_error(ValueError("bad input"))
        assert "Parse error:" in parse_feedback
        assert "bad input" in parse_feedback
|