From 5f7dfadabd3e8017406ad3720ba495a59963ee34 Mon Sep 17 00:00:00 2001
From: Sergey Mankovsky
Date: Thu, 18 Dec 2025 20:49:06 +0100
Subject: [PATCH] fix: retry on workflow timeout (#798)

---
 server/reflector/llm.py        |  46 +++++++-----
 server/tests/test_llm_retry.py | 133 ++++++++++++++++++++++++++++++++-
 2 files changed, 161 insertions(+), 18 deletions(-)

diff --git a/server/reflector/llm.py b/server/reflector/llm.py
index 0485e847..10ba9138 100644
--- a/server/reflector/llm.py
+++ b/server/reflector/llm.py
@@ -16,6 +16,9 @@ from llama_index.core.workflow import (
 )
 from llama_index.llms.openai_like import OpenAILike
 from pydantic import BaseModel, ValidationError
+from workflows.errors import WorkflowTimeoutError
+
+from reflector.utils.retry import retry
 
 T = TypeVar("T", bound=BaseModel)
 OutputT = TypeVar("OutputT", bound=BaseModel)
@@ -231,24 +234,33 @@ class LLM:
         tone_name: str | None = None,
     ) -> T:
         """Get structured output from LLM with validation retry via Workflow."""
-        workflow = StructuredOutputWorkflow(
-            output_cls=output_cls,
-            max_retries=self.settings_obj.LLM_PARSE_MAX_RETRIES + 1,
-            timeout=120,
-        )
-        result = await workflow.run(
-            prompt=prompt,
-            texts=texts,
-            tone_name=tone_name,
-        )
-
-        if "error" in result:
-            error_msg = result["error"] or "Max retries exceeded"
-            raise LLMParseError(
+        async def run_workflow():
+            workflow = StructuredOutputWorkflow(
                 output_cls=output_cls,
-                error_msg=error_msg,
-                attempts=result.get("attempts", 0),
+                max_retries=self.settings_obj.LLM_PARSE_MAX_RETRIES + 1,
+                timeout=120,
             )
-        return result["success"]
+            result = await workflow.run(
+                prompt=prompt,
+                texts=texts,
+                tone_name=tone_name,
+            )
+
+            if "error" in result:
+                error_msg = result["error"] or "Max retries exceeded"
+                raise LLMParseError(
+                    output_cls=output_cls,
+                    error_msg=error_msg,
+                    attempts=result.get("attempts", 0),
+                )
+
+            return result["success"]
+
+        return await retry(run_workflow)(
+            retry_attempts=3,
+            retry_backoff_interval=1.0,
+            retry_backoff_max=30.0,
+            retry_ignore_exc_types=(WorkflowTimeoutError,),
+        )
diff --git a/server/tests/test_llm_retry.py b/server/tests/test_llm_retry.py
index f9fe28b4..5a43c8c5 100644
--- a/server/tests/test_llm_retry.py
+++ b/server/tests/test_llm_retry.py
@@ -1,12 +1,14 @@
 """Tests for LLM parse error recovery using llama-index Workflow"""
 
+from time import monotonic
 from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
 from pydantic import BaseModel, Field
-from workflows.errors import WorkflowRuntimeError
+from workflows.errors import WorkflowRuntimeError, WorkflowTimeoutError
 
 from reflector.llm import LLM, LLMParseError, StructuredOutputWorkflow
+from reflector.utils.retry import RetryException
 
 
 class TestResponse(BaseModel):
@@ -355,3 +357,132 @@ class TestNetworkErrorRetries:
 
         # Only called once - Workflow doesn't retry network errors
         assert mock_settings.llm.acomplete.call_count == 1
+
+
+class TestWorkflowTimeoutRetry:
+    """Test timeout retry mechanism in get_structured_response"""
+
+    @pytest.mark.asyncio
+    async def test_timeout_retry_succeeds_on_retry(self, test_settings):
+        """Test that WorkflowTimeoutError triggers retry and succeeds"""
+        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
+
+        call_count = {"count": 0}
+
+        async def workflow_run_side_effect(*args, **kwargs):
+            call_count["count"] += 1
+            if call_count["count"] == 1:
+                raise WorkflowTimeoutError("Operation timed out after 120 seconds")
+            return {
+                "success": TestResponse(
+                    title="Test", summary="Summary", confidence=0.95
+                )
+            }
+
+        with (
+            patch("reflector.llm.StructuredOutputWorkflow") as mock_workflow_class,
+            patch("reflector.llm.TreeSummarize") as mock_summarize,
+            patch("reflector.llm.Settings") as mock_settings,
+        ):
+            mock_workflow = MagicMock()
+            mock_workflow.run = AsyncMock(side_effect=workflow_run_side_effect)
+            mock_workflow_class.return_value = mock_workflow
+
+            mock_summarizer = MagicMock()
+            mock_summarize.return_value = mock_summarizer
+            mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
+            mock_settings.llm.acomplete = AsyncMock(
+                return_value=make_completion_response(
+                    '{"title": "Test", "summary": "Summary", "confidence": 0.95}'
+                )
+            )
+
+            result = await llm.get_structured_response(
+                prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
+            )
+
+            assert result.title == "Test"
+            assert result.summary == "Summary"
+            assert call_count["count"] == 2
+
+    @pytest.mark.asyncio
+    async def test_timeout_retry_exhausts_after_max_attempts(self, test_settings):
+        """Test that timeout retry stops after max attempts"""
+        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
+
+        call_count = {"count": 0}
+
+        async def workflow_run_side_effect(*args, **kwargs):
+            call_count["count"] += 1
+            raise WorkflowTimeoutError("Operation timed out after 120 seconds")
+
+        with (
+            patch("reflector.llm.StructuredOutputWorkflow") as mock_workflow_class,
+            patch("reflector.llm.TreeSummarize") as mock_summarize,
+            patch("reflector.llm.Settings") as mock_settings,
+        ):
+            mock_workflow = MagicMock()
+            mock_workflow.run = AsyncMock(side_effect=workflow_run_side_effect)
+            mock_workflow_class.return_value = mock_workflow
+
+            mock_summarizer = MagicMock()
+            mock_summarize.return_value = mock_summarizer
+            mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
+            mock_settings.llm.acomplete = AsyncMock(
+                return_value=make_completion_response(
+                    '{"title": "Test", "summary": "Summary", "confidence": 0.95}'
+                )
+            )
+
+            with pytest.raises(RetryException, match="Retry attempts exceeded"):
+                await llm.get_structured_response(
+                    prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
+                )
+
+            assert call_count["count"] == 3
+
+    @pytest.mark.asyncio
+    async def test_timeout_retry_with_backoff(self, test_settings):
+        """Test that exponential backoff is applied between retries"""
+        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
+
+        call_times = []
+
+        async def workflow_run_side_effect(*args, **kwargs):
+            call_times.append(monotonic())
+            if len(call_times) < 3:
+                raise WorkflowTimeoutError("Operation timed out after 120 seconds")
+            return {
+                "success": TestResponse(
+                    title="Test", summary="Summary", confidence=0.95
+                )
+            }
+
+        with (
+            patch("reflector.llm.StructuredOutputWorkflow") as mock_workflow_class,
+            patch("reflector.llm.TreeSummarize") as mock_summarize,
+            patch("reflector.llm.Settings") as mock_settings,
+        ):
+            mock_workflow = MagicMock()
+            mock_workflow.run = AsyncMock(side_effect=workflow_run_side_effect)
+            mock_workflow_class.return_value = mock_workflow
+
+            mock_summarizer = MagicMock()
+            mock_summarize.return_value = mock_summarizer
+            mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")
+            mock_settings.llm.acomplete = AsyncMock(
+                return_value=make_completion_response(
+                    '{"title": "Test", "summary": "Summary", "confidence": 0.95}'
+                )
+            )
+
+            result = await llm.get_structured_response(
+                prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
+            )
+
+            assert result.title == "Test"
+            if len(call_times) >= 2:
+                time_between_calls = call_times[1] - call_times[0]
+                assert (
+                    time_between_calls >= 1.5
+                ), f"Expected ~2s backoff, got {time_between_calls}s"
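
Note on the retry contract assumed above: reflector.utils.retry is not part of
this patch, so the semantics of retry(), RetryException, and the retry_*
keyword arguments are inferred from the call site and the three tests, not
from the real implementation. A minimal sketch that would satisfy all three
tests (hypothetical; the actual helper may differ):

    import asyncio

    class RetryException(Exception):
        """Raised once every retry attempt has been exhausted."""

    def retry(fn):
        """Wrap an async callable; the returned runner takes retry config."""

        async def runner(
            retry_attempts=3,
            retry_backoff_interval=1.0,
            retry_backoff_max=30.0,
            retry_ignore_exc_types=(),
        ):
            for attempt in range(retry_attempts):
                try:
                    return await fn()
                except retry_ignore_exc_types:
                    # Exceptions listed in retry_ignore_exc_types (here,
                    # WorkflowTimeoutError) are swallowed and retried;
                    # anything else, e.g. LLMParseError, propagates at once.
                    if attempt + 1 == retry_attempts:
                        break
                    # Exponential backoff capped at retry_backoff_max; a
                    # first delay of ~2s (interval * 2) is what the backoff
                    # test's "Expected ~2s" assertion implies.
                    delay = min(
                        retry_backoff_interval * 2 ** (attempt + 1),
                        retry_backoff_max,
                    )
                    await asyncio.sleep(delay)
            raise RetryException("Retry attempts exceeded")

        return runner

Under this reading, get_structured_response runs the workflow at most three
times, waiting roughly 2s and then 4s between attempts, while parse errors
still fail fast.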