mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 12:19:06 +00:00
fix: retry on workflow timeout (#798)
This commit is contained in:
@@ -16,6 +16,9 @@ from llama_index.core.workflow import (
|
||||
)
|
||||
from llama_index.llms.openai_like import OpenAILike
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from workflows.errors import WorkflowTimeoutError
|
||||
|
||||
from reflector.utils.retry import retry
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
OutputT = TypeVar("OutputT", bound=BaseModel)
|
||||
@@ -231,24 +234,33 @@ class LLM:
|
||||
tone_name: str | None = None,
|
||||
) -> T:
|
||||
"""Get structured output from LLM with validation retry via Workflow."""
|
||||
workflow = StructuredOutputWorkflow(
|
||||
output_cls=output_cls,
|
||||
max_retries=self.settings_obj.LLM_PARSE_MAX_RETRIES + 1,
|
||||
timeout=120,
|
||||
)
|
||||
|
||||
result = await workflow.run(
|
||||
prompt=prompt,
|
||||
texts=texts,
|
||||
tone_name=tone_name,
|
||||
)
|
||||
|
||||
if "error" in result:
|
||||
error_msg = result["error"] or "Max retries exceeded"
|
||||
raise LLMParseError(
|
||||
async def run_workflow():
|
||||
workflow = StructuredOutputWorkflow(
|
||||
output_cls=output_cls,
|
||||
error_msg=error_msg,
|
||||
attempts=result.get("attempts", 0),
|
||||
max_retries=self.settings_obj.LLM_PARSE_MAX_RETRIES + 1,
|
||||
timeout=120,
|
||||
)
|
||||
|
||||
return result["success"]
|
||||
result = await workflow.run(
|
||||
prompt=prompt,
|
||||
texts=texts,
|
||||
tone_name=tone_name,
|
||||
)
|
||||
|
||||
if "error" in result:
|
||||
error_msg = result["error"] or "Max retries exceeded"
|
||||
raise LLMParseError(
|
||||
output_cls=output_cls,
|
||||
error_msg=error_msg,
|
||||
attempts=result.get("attempts", 0),
|
||||
)
|
||||
|
||||
return result["success"]
|
||||
|
||||
return await retry(run_workflow)(
|
||||
retry_attempts=3,
|
||||
retry_backoff_interval=1.0,
|
||||
retry_backoff_max=30.0,
|
||||
retry_ignore_exc_types=(WorkflowTimeoutError,),
|
||||
)
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
"""Tests for LLM parse error recovery using llama-index Workflow"""
|
||||
|
||||
from time import monotonic
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
from workflows.errors import WorkflowRuntimeError
|
||||
from workflows.errors import WorkflowRuntimeError, WorkflowTimeoutError
|
||||
|
||||
from reflector.llm import LLM, LLMParseError, StructuredOutputWorkflow
|
||||
from reflector.utils.retry import RetryException
|
||||
|
||||
|
||||
class TestResponse(BaseModel):
|
||||
@@ -355,3 +357,132 @@ class TestNetworkErrorRetries:
|
||||
|
||||
# Only called once - Workflow doesn't retry network errors
|
||||
assert mock_settings.llm.acomplete.call_count == 1
|
||||
|
||||
|
||||
class TestWorkflowTimeoutRetry:
|
||||
"""Test timeout retry mechanism in get_structured_response"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_retry_succeeds_on_retry(self, test_settings):
|
||||
"""Test that WorkflowTimeoutError triggers retry and succeeds"""
|
||||
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
||||
|
||||
call_count = {"count": 0}
|
||||
|
||||
async def workflow_run_side_effect(*args, **kwargs):
|
||||
call_count["count"] += 1
|
||||
if call_count["count"] == 1:
|
||||
raise WorkflowTimeoutError("Operation timed out after 120 seconds")
|
||||
return {
|
||||
"success": TestResponse(
|
||||
title="Test", summary="Summary", confidence=0.95
|
||||
)
|
||||
}
|
||||
|
||||
with (
|
||||
patch("reflector.llm.StructuredOutputWorkflow") as mock_workflow_class,
|
||||
patch("reflector.llm.TreeSummarize") as mock_summarize,
|
||||
patch("reflector.llm.Settings") as mock_settings,
|
||||
):
|
||||
mock_workflow = MagicMock()
|
||||
mock_workflow.run = AsyncMock(side_effect=workflow_run_side_effect)
|
||||
mock_workflow_class.return_value = mock_workflow
|
||||
|
||||
mock_summarizer = MagicMock()
|
||||
mock_summarize.return_value = mock_summarizer
|
||||
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
|
||||
mock_settings.llm.acomplete = AsyncMock(
|
||||
return_value=make_completion_response(
|
||||
'{"title": "Test", "summary": "Summary", "confidence": 0.95}'
|
||||
)
|
||||
)
|
||||
|
||||
result = await llm.get_structured_response(
|
||||
prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
|
||||
)
|
||||
|
||||
assert result.title == "Test"
|
||||
assert result.summary == "Summary"
|
||||
assert call_count["count"] == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_retry_exhausts_after_max_attempts(self, test_settings):
|
||||
"""Test that timeout retry stops after max attempts"""
|
||||
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
||||
|
||||
call_count = {"count": 0}
|
||||
|
||||
async def workflow_run_side_effect(*args, **kwargs):
|
||||
call_count["count"] += 1
|
||||
raise WorkflowTimeoutError("Operation timed out after 120 seconds")
|
||||
|
||||
with (
|
||||
patch("reflector.llm.StructuredOutputWorkflow") as mock_workflow_class,
|
||||
patch("reflector.llm.TreeSummarize") as mock_summarize,
|
||||
patch("reflector.llm.Settings") as mock_settings,
|
||||
):
|
||||
mock_workflow = MagicMock()
|
||||
mock_workflow.run = AsyncMock(side_effect=workflow_run_side_effect)
|
||||
mock_workflow_class.return_value = mock_workflow
|
||||
|
||||
mock_summarizer = MagicMock()
|
||||
mock_summarize.return_value = mock_summarizer
|
||||
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
|
||||
mock_settings.llm.acomplete = AsyncMock(
|
||||
return_value=make_completion_response(
|
||||
'{"title": "Test", "summary": "Summary", "confidence": 0.95}'
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(RetryException, match="Retry attempts exceeded"):
|
||||
await llm.get_structured_response(
|
||||
prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
|
||||
)
|
||||
|
||||
assert call_count["count"] == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_retry_with_backoff(self, test_settings):
|
||||
"""Test that exponential backoff is applied between retries"""
|
||||
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
|
||||
|
||||
call_times = []
|
||||
|
||||
async def workflow_run_side_effect(*args, **kwargs):
|
||||
call_times.append(monotonic())
|
||||
if len(call_times) < 3:
|
||||
raise WorkflowTimeoutError("Operation timed out after 120 seconds")
|
||||
return {
|
||||
"success": TestResponse(
|
||||
title="Test", summary="Summary", confidence=0.95
|
||||
)
|
||||
}
|
||||
|
||||
with (
|
||||
patch("reflector.llm.StructuredOutputWorkflow") as mock_workflow_class,
|
||||
patch("reflector.llm.TreeSummarize") as mock_summarize,
|
||||
patch("reflector.llm.Settings") as mock_settings,
|
||||
):
|
||||
mock_workflow = MagicMock()
|
||||
mock_workflow.run = AsyncMock(side_effect=workflow_run_side_effect)
|
||||
mock_workflow_class.return_value = mock_workflow
|
||||
|
||||
mock_summarizer = MagicMock()
|
||||
mock_summarize.return_value = mock_summarizer
|
||||
mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
|
||||
mock_settings.llm.acomplete = AsyncMock(
|
||||
return_value=make_completion_response(
|
||||
'{"title": "Test", "summary": "Summary", "confidence": 0.95}'
|
||||
)
|
||||
)
|
||||
|
||||
result = await llm.get_structured_response(
|
||||
prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
|
||||
)
|
||||
|
||||
assert result.title == "Test"
|
||||
if len(call_times) >= 2:
|
||||
time_between_calls = call_times[1] - call_times[0]
|
||||
assert (
|
||||
time_between_calls >= 1.5
|
||||
), f"Expected ~2s backoff, got {time_between_calls}s"
|
||||
|
||||
Reference in New Issue
Block a user