Files
reflector/server/tests/test_llm_retry.py
Mathieu Virbel 5d547586ef fix: switch structured output to tool-call with reflection retry (#879)
* fix: switch structured output to tool-call with reflection retry

Replace the two-pass StructuredOutputWorkflow (TreeSummarize → acomplete)
with astructured_predict + reflection retry loop for structured LLM output.

- Enable function-calling mode (is_function_calling_model=True)
- Use astructured_predict with PromptTemplate for first attempt
- On ValidationError/parse failure, retry with reflection feedback
- Add min_length=10 to TopicResponse title/summary fields
- Remove dead StructuredOutputWorkflow class and its event types
- Rewrite tests to match new astructured_predict approach

* fix: include texts parameter in astructured_predict prompt

The switch to astructured_predict dropped the texts parameter entirely,
causing summary prompts (participants, subjects, action items) to be
sent without the transcript content. Combine texts with the prompt
before calling astructured_predict, mirroring what TreeSummarize did.

* fix: reduce TopicResponse min_length from 10 to 8 for title and summary

* ci: try fixing spawning job in github

* ci: fix for new arm64 builder
2026-02-24 18:28:11 -06:00

398 lines
16 KiB
Python

"""Tests for LLM structured output with astructured_predict + reflection retry"""
from unittest.mock import AsyncMock, patch
import pytest
from pydantic import BaseModel, Field, ValidationError
from reflector.llm import LLM, LLMParseError
from reflector.utils.retry import RetryException
class TestResponse(BaseModel):
    """Minimal response model used to exercise structured-output parsing."""

    # Prevent pytest from attempting to collect this pydantic model as a
    # test class: its name starts with "Test", which otherwise raises a
    # PytestCollectionWarning. Pydantic ignores dunder class attributes,
    # so this does not affect the generated schema.
    __test__ = False

    title: str = Field(description="A title")
    summary: str = Field(description="A summary")
    # Confidence is constrained to the closed interval [0, 1].
    confidence: float = Field(description="Confidence score", ge=0, le=1)
class TestLLMParseErrorRecovery:
    """Exercise the reflection retry loop wrapped around astructured_predict."""

    @pytest.mark.asyncio
    async def test_parse_error_recovery_with_feedback(self, test_settings):
        """A ValidationError on the first call must trigger a reflection retry."""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
        calls = 0

        async def fake_predict(output_cls, prompt_tmpl, **kwargs):
            nonlocal calls
            calls += 1
            if calls > 1:
                # Retry attempt: reflection feedback must be threaded through.
                assert "reflection" in kwargs
                assert "could not be parsed" in kwargs["reflection"]
                assert "Error:" in kwargs["reflection"]
                return TestResponse(title="Test", summary="Summary", confidence=0.95)
            # First attempt: simulate a missing-field validation failure.
            raise ValidationError.from_exception_data(
                title="TestResponse",
                line_errors=[
                    {
                        "type": "missing",
                        "loc": ("summary",),
                        "msg": "Field required",
                        "input": {"title": "Test"},
                    }
                ],
            )

        with patch("reflector.llm.Settings") as mock_settings:
            mock_settings.llm.astructured_predict = AsyncMock(side_effect=fake_predict)
            result = await llm.get_structured_response(
                prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
            )

        assert result.title == "Test"
        assert result.summary == "Summary"
        assert result.confidence == 0.95
        assert calls == 2

    @pytest.mark.asyncio
    async def test_max_parse_retry_attempts(self, test_settings):
        """Reflection retries must stop once LLM_PARSE_MAX_RETRIES is exhausted."""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)

        async def always_invalid(output_cls, prompt_tmpl, **kwargs):
            # Every attempt fails validation the same way.
            raise ValidationError.from_exception_data(
                title="TestResponse",
                line_errors=[
                    {
                        "type": "missing",
                        "loc": ("summary",),
                        "msg": "Field required",
                        "input": {},
                    }
                ],
            )

        with patch("reflector.llm.Settings") as mock_settings:
            mock_settings.llm.astructured_predict = AsyncMock(
                side_effect=always_invalid
            )
            with pytest.raises(LLMParseError, match="Failed to parse"):
                await llm.get_structured_response(
                    prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
                )
            # One initial attempt plus the configured number of retries.
            assert (
                mock_settings.llm.astructured_predict.call_count
                == test_settings.LLM_PARSE_MAX_RETRIES + 1
            )

    @pytest.mark.asyncio
    async def test_raw_response_logging_on_parse_error(self, test_settings, caplog):
        """A parse failure must log the raw response at ERROR level."""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
        calls = 0

        async def fail_once(output_cls, prompt_tmpl, **kwargs):
            nonlocal calls
            calls += 1
            if calls == 1:
                raise ValidationError.from_exception_data(
                    title="TestResponse",
                    line_errors=[
                        {
                            "type": "missing",
                            "loc": ("summary",),
                            "msg": "Field required",
                            "input": {"title": "Test"},
                        }
                    ],
                )
            return TestResponse(title="Test", summary="Summary", confidence=0.95)

        with (
            patch("reflector.llm.Settings") as mock_settings,
            caplog.at_level("ERROR"),
        ):
            mock_settings.llm.astructured_predict = AsyncMock(side_effect=fail_once)
            result = await llm.get_structured_response(
                prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
            )
            assert result.title == "Test"

        errors = [rec for rec in caplog.records if rec.levelname == "ERROR"]
        assert any(
            "Raw response:" in rec.message for rec in errors
        ), "Raw response should be logged on parse error"

    @pytest.mark.asyncio
    async def test_multiple_validation_errors_in_feedback(self, test_settings):
        """All schema validation errors should surface in the reflection text."""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
        calls = 0

        async def fake_predict(output_cls, prompt_tmpl, **kwargs):
            nonlocal calls
            calls += 1
            if calls == 1:
                # First attempt: both required string fields are missing.
                raise ValidationError.from_exception_data(
                    title="TestResponse",
                    line_errors=[
                        {
                            "type": "missing",
                            "loc": ("title",),
                            "msg": "Field required",
                            "input": {},
                        },
                        {
                            "type": "missing",
                            "loc": ("summary",),
                            "msg": "Field required",
                            "input": {},
                        },
                    ],
                )
            # Retry: the reflection text must mention the validation problems.
            assert "reflection" in kwargs
            assert (
                "Schema validation errors" in kwargs["reflection"]
                or "error" in kwargs["reflection"].lower()
            )
            return TestResponse(title="Test", summary="Summary", confidence=0.95)

        with patch("reflector.llm.Settings") as mock_settings:
            mock_settings.llm.astructured_predict = AsyncMock(side_effect=fake_predict)
            result = await llm.get_structured_response(
                prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
            )

        assert result.title == "Test"
        assert calls == 2

    @pytest.mark.asyncio
    async def test_success_on_first_attempt(self, test_settings):
        """A clean first response must not trigger any retry."""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
        good = TestResponse(title="Test", summary="Summary", confidence=0.95)

        with patch("reflector.llm.Settings") as mock_settings:
            mock_settings.llm.astructured_predict = AsyncMock(return_value=good)
            result = await llm.get_structured_response(
                prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
            )
            assert result.title == "Test"
            assert result.summary == "Summary"
            assert result.confidence == 0.95
            assert mock_settings.llm.astructured_predict.call_count == 1
class TestNetworkErrorRetries:
    """Network failures are handled by the outer retry() wrapper, not reflection."""

    @pytest.mark.asyncio
    async def test_network_error_retried_by_outer_wrapper(self, test_settings):
        """A transient connection error should be retried and then succeed."""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
        calls = 0

        async def flaky(output_cls, prompt_tmpl, **kwargs):
            nonlocal calls
            calls += 1
            if calls == 1:
                # Simulate a one-off network failure.
                raise ConnectionError("Connection refused")
            return TestResponse(title="Test", summary="Summary", confidence=0.95)

        with patch("reflector.llm.Settings") as mock_settings:
            mock_settings.llm.astructured_predict = AsyncMock(side_effect=flaky)
            result = await llm.get_structured_response(
                prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
            )

        assert result.title == "Test"
        assert calls == 2

    @pytest.mark.asyncio
    async def test_network_error_exhausts_retries(self, test_settings):
        """A persistent connection error must exhaust all retry attempts."""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)

        with patch("reflector.llm.Settings") as mock_settings:
            mock_settings.llm.astructured_predict = AsyncMock(
                side_effect=ConnectionError("Connection refused")
            )
            with pytest.raises(RetryException, match="Retry attempts exceeded"):
                await llm.get_structured_response(
                    prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
                )
            # The outer wrapper makes 3 attempts in total.
            assert mock_settings.llm.astructured_predict.call_count == 3
class TestTextsInclusion:
    """The texts argument must be merged into the prompt for astructured_predict."""

    @pytest.mark.asyncio
    async def test_texts_included_in_prompt(self, test_settings):
        """Transcript texts should appear alongside the instruction prompt."""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
        seen = []

        async def record_prompt(output_cls, prompt_tmpl, **kwargs):
            seen.append(kwargs.get("user_prompt", ""))
            return TestResponse(title="Test", summary="Summary", confidence=0.95)

        with patch("reflector.llm.Settings") as mock_settings:
            mock_settings.llm.astructured_predict = AsyncMock(
                side_effect=record_prompt
            )
            await llm.get_structured_response(
                prompt="Identify all participants",
                texts=["Alice: Hello everyone", "Bob: Hi Alice"],
                output_cls=TestResponse,
            )

        assert len(seen) == 1
        for fragment in (
            "Identify all participants",
            "Alice: Hello everyone",
            "Bob: Hi Alice",
        ):
            assert fragment in seen[0]

    @pytest.mark.asyncio
    async def test_empty_texts_uses_prompt_only(self, test_settings):
        """With no texts, the prompt is sent unmodified."""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
        seen = []

        async def record_prompt(output_cls, prompt_tmpl, **kwargs):
            seen.append(kwargs.get("user_prompt", ""))
            return TestResponse(title="Test", summary="Summary", confidence=0.95)

        with patch("reflector.llm.Settings") as mock_settings:
            mock_settings.llm.astructured_predict = AsyncMock(
                side_effect=record_prompt
            )
            await llm.get_structured_response(
                prompt="Identify all participants",
                texts=[],
                output_cls=TestResponse,
            )

        # Exactly one call, carrying the bare prompt and nothing else.
        assert seen == ["Identify all participants"]

    @pytest.mark.asyncio
    async def test_texts_included_in_reflection_retry(self, test_settings):
        """Reflection retries must resend the texts, not just the bare prompt."""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
        seen = []

        async def record_then_fail_once(output_cls, prompt_tmpl, **kwargs):
            seen.append(kwargs.get("user_prompt", ""))
            if len(seen) == 1:
                raise ValidationError.from_exception_data(
                    title="TestResponse",
                    line_errors=[
                        {
                            "type": "missing",
                            "loc": ("summary",),
                            "msg": "Field required",
                            "input": {},
                        }
                    ],
                )
            return TestResponse(title="Test", summary="Summary", confidence=0.95)

        with patch("reflector.llm.Settings") as mock_settings:
            mock_settings.llm.astructured_predict = AsyncMock(
                side_effect=record_then_fail_once
            )
            await llm.get_structured_response(
                prompt="Summarize this",
                texts=["The meeting covered project updates"],
                output_cls=TestResponse,
            )

        # Both the first attempt and the reflection retry carry the texts.
        assert len(seen) == 2
        for sent in seen:
            assert "Summarize this" in sent
            assert "The meeting covered project updates" in sent
class TestReflectionRetryBackoff:
    """Test the reflection retry timing behavior."""

    @pytest.mark.asyncio
    async def test_value_error_triggers_reflection(self, test_settings):
        """Test that ValueError (parse failure) also triggers reflection retry."""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
        calls = 0

        async def fake_predict(output_cls, prompt_tmpl, **kwargs):
            nonlocal calls
            calls += 1
            if calls == 1:
                # Non-pydantic parse failure: must still be reflected and retried.
                raise ValueError("Could not parse output")
            assert "reflection" in kwargs
            return TestResponse(title="Test", summary="Summary", confidence=0.95)

        with patch("reflector.llm.Settings") as mock_settings:
            mock_settings.llm.astructured_predict = AsyncMock(side_effect=fake_predict)
            result = await llm.get_structured_response(
                prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
            )

        assert result.title == "Test"
        assert calls == 2

    @pytest.mark.asyncio
    async def test_format_validation_error_method(self, test_settings):
        """Test _format_validation_error produces correct feedback."""
        # ValidationError case. The previous bare try/except silently skipped
        # the assertions whenever construction unexpectedly succeeded (the
        # except body never ran, so the test passed vacuously). pytest.raises
        # makes the expected failure explicit and fails loudly otherwise.
        with pytest.raises(ValidationError) as exc_info:
            TestResponse(title="x", summary="y", confidence=5.0)  # confidence > 1
        result = LLM._format_validation_error(exc_info.value)
        assert "Schema validation errors" in result
        assert "confidence" in result

        # ValueError case: plain parse errors get a distinct prefix.
        result = LLM._format_validation_error(ValueError("bad input"))
        assert "Parse error:" in result
        assert "bad input" in result