feat: llm retries (#739)

* llm retries no-mistakes

* self-review (no-mistakes)

* self-review (no-mistakes)

* bigger retry intervals by default

* tests and dry

* restore to main state

* parse retries

* json retries (no-mistakes)

* json retries (no-mistakes)

* json retries (no-mistakes)

* json retries (no-mistakes) self-review

* additional network retry test

* more lint

---------

Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
This commit is contained in:
Igor Monadical
2025-12-05 12:08:21 -05:00
committed by GitHub
parent ec17ed7b58
commit 61f0e29d4c
7 changed files with 564 additions and 38 deletions

View File

@@ -126,6 +126,7 @@ markers = [
select = [ select = [
"I", # isort - import sorting "I", # isort - import sorting
"F401", # unused imports "F401", # unused imports
"E402", # module level import not at top of file
"PLC0415", # import-outside-top-level - detect inline imports "PLC0415", # import-outside-top-level - detect inline imports
] ]

View File

@@ -1,13 +1,19 @@
import asyncio import asyncio
import functools import functools
from uuid import uuid4
from celery import current_task
from reflector.db import get_database from reflector.db import get_database
from reflector.llm import llm_session_id
def asynctask(f): def asynctask(f):
@functools.wraps(f) @functools.wraps(f)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
async def run_with_db(): async def run_with_db():
task_id = current_task.request.id if current_task else None
llm_session_id.set(task_id or f"random-{uuid4().hex}")
database = get_database() database = get_database()
await database.connect() await database.connect()
try: try:

View File

@@ -1,14 +1,29 @@
import logging import logging
from typing import Type, TypeVar from contextvars import ContextVar
from typing import Generic, Type, TypeVar
from uuid import uuid4
from llama_index.core import Settings from llama_index.core import Settings
from llama_index.core.output_parsers import PydanticOutputParser from llama_index.core.output_parsers import PydanticOutputParser
from llama_index.core.program import LLMTextCompletionProgram
from llama_index.core.response_synthesizers import TreeSummarize from llama_index.core.response_synthesizers import TreeSummarize
from llama_index.core.workflow import (
Context,
Event,
StartEvent,
StopEvent,
Workflow,
step,
)
from llama_index.llms.openai_like import OpenAILike from llama_index.llms.openai_like import OpenAILike
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
T = TypeVar("T", bound=BaseModel) T = TypeVar("T", bound=BaseModel)
OutputT = TypeVar("OutputT", bound=BaseModel)
# Session ID for LiteLLM request grouping - set per processing run
llm_session_id: ContextVar[str | None] = ContextVar("llm_session_id", default=None)
logger = logging.getLogger(__name__)
STRUCTURED_RESPONSE_PROMPT_TEMPLATE = """ STRUCTURED_RESPONSE_PROMPT_TEMPLATE = """
Based on the following analysis, provide the information in the requested JSON format: Based on the following analysis, provide the information in the requested JSON format:
@@ -20,6 +35,158 @@ Analysis:
""" """
class LLMParseError(Exception):
    """Raised when LLM output cannot be parsed after retries.

    Attributes:
        output_cls: Model class the LLM output was expected to validate against.
        error_msg: Description of the last parse/validation failure.
        attempts: Number of attempts that were made before giving up.
    """

    def __init__(self, output_cls: "Type[BaseModel]", error_msg: str, attempts: int):
        self.output_cls = output_cls
        self.error_msg = error_msg
        self.attempts = attempts
        super().__init__(
            f"Failed to parse {output_cls.__name__} after {attempts} attempts: {error_msg}"
        )

    def __reduce__(self):
        """Rebuild from the original constructor arguments.

        The default ``BaseException.__reduce__`` re-creates the instance from
        ``self.args`` (here: the single formatted message), which does not
        match this three-argument constructor and makes unpickling or copying
        the exception fail with ``TypeError``.
        """
        return (type(self), (self.output_cls, self.error_msg, self.attempts))
class ExtractionDone(Event):
    """Event emitted when LLM JSON formatting completes."""

    # Raw text of the LLM completion; it has not been parsed or validated yet.
    output: str
class ValidationErrorEvent(Event):
    """Event emitted when validation fails."""

    # Formatted description of the failure, fed back to the LLM on retry.
    error: str
    # Raw LLM output that failed to parse, echoed back in the retry prompt.
    wrong_output: str
class StructuredOutputWorkflow(Workflow, Generic[OutputT]):
    """Workflow for structured output extraction with validation retry.

    Two steps loop until the output validates or the attempt budget is spent:

    * ``extract`` asks the LLM for a JSON completion (running TreeSummarize
      once, on the first pass only).
    * ``validate`` parses that completion with ``PydanticOutputParser`` and
      either stops with the parsed model or sends the error back to
      ``extract``.

    ``run()`` resolves to a dict: ``{"success": <parsed model>}`` on success,
    or ``{"error": <last error message or None>, "attempts": <int>}`` once
    ``max_retries`` attempts have been used.

    Only parse/validation failures are retried here. Exceptions raised by the
    LLM call itself are not caught by this workflow; network-level retries are
    expected to be handled by the client configured in ``Settings.llm``.
    The caller should NOT wrap this workflow in additional retry logic.
    """

    def __init__(
        self,
        output_cls: Type[OutputT],
        max_retries: int = 3,
        **kwargs,
    ):
        """Create the workflow.

        Args:
            output_cls: Pydantic model the LLM output must validate against.
            max_retries: Total number of JSON formatting attempts allowed.
            **kwargs: Forwarded unchanged to ``Workflow.__init__``.
        """
        super().__init__(**kwargs)
        self.output_cls: Type[OutputT] = output_cls
        self.max_retries = max_retries
        self.output_parser = PydanticOutputParser(output_cls)

    @step
    async def extract(
        self, ctx: Context, ev: StartEvent | ValidationErrorEvent
    ) -> StopEvent | ExtractionDone:
        """Extract structured data from text using two-step LLM process.

        Step 1 (first call only): TreeSummarize generates text analysis
        Step 2 (every call): Settings.llm.acomplete formats analysis as JSON

        Raises:
            ValueError: If the StartEvent lacks 'prompt' or a list of 'texts'.
            RuntimeError: If a retry finds no stored analysis in the context.
        """
        # The counter is bumped before the budget check, so inside ``validate``
        # it reads as the 1-based number of the attempt being validated.
        current_retries = await ctx.store.get("retries", default=0)
        await ctx.store.set("retries", current_retries + 1)

        if current_retries >= self.max_retries:
            last_error = await ctx.store.get("last_error", default=None)
            logger.error(
                "Max retries (%s) reached for %s",
                self.max_retries,
                self.output_cls.__name__,
            )
            return StopEvent(result={"error": last_error, "attempts": current_retries})

        if isinstance(ev, StartEvent):
            # First call: run TreeSummarize to get analysis, store in context
            prompt = ev.get("prompt")
            texts = ev.get("texts")
            tone_name = ev.get("tone_name")
            if not prompt or not isinstance(texts, list):
                raise ValueError(
                    "StartEvent must contain 'prompt' (str) and 'texts' (list)"
                )
            summarizer = TreeSummarize(verbose=False)
            # Keep the analysis as plain text so the first pass and the
            # retries feed exactly the same value into the prompt.
            analysis = str(
                await summarizer.aget_response(prompt, texts, tone_name=tone_name)
            )
            await ctx.store.set("analysis", analysis)
            reflection = ""
        else:
            # Retry: reuse analysis from context
            analysis = await ctx.store.get("analysis", default=None)
            # Only a missing value is an internal error; an empty analysis is
            # a legitimate (if unhelpful) summarizer result.
            if analysis is None:
                raise RuntimeError("Internal error: analysis not found in context")

            # Cap the echoed output so the retry prompt stays bounded.
            wrong_output = ev.wrong_output
            if len(wrong_output) > 2000:
                wrong_output = wrong_output[:2000] + "... [truncated]"
            reflection = (
                f"\n\nYour previous response could not be parsed:\n{wrong_output}\n\n"
                f"Error:\n{ev.error}\n\n"
                "Please try again. Return ONLY valid JSON matching the schema above, "
                "with no markdown formatting or extra text."
            )

        # Step 2: Format analysis as JSON using LLM completion
        format_instructions = self.output_parser.format(
            "Please structure the above information in the following JSON format:"
        )
        json_prompt = STRUCTURED_RESPONSE_PROMPT_TEMPLATE.format(
            analysis=analysis,
            format_instructions=format_instructions + reflection,
        )

        # No retry around this call: network-level retries are left to the
        # LLM client itself.
        response = await Settings.llm.acomplete(json_prompt)
        return ExtractionDone(output=response.text)

    @step
    async def validate(
        self, ctx: Context, ev: ExtractionDone
    ) -> StopEvent | ValidationErrorEvent:
        """Validate extracted output against Pydantic schema.

        Returns a StopEvent carrying ``{"success": parsed}`` when parsing
        works, otherwise a ValidationErrorEvent that routes back to
        ``extract`` with the formatted error and the offending output.
        """
        raw_output = ev.output
        retries = await ctx.store.get("retries", default=0)

        try:
            parsed = self.output_parser.parse(raw_output)
        except (ValidationError, ValueError) as e:
            error_msg = self._format_error(e, raw_output)
            # Remembered so ``extract`` can report it if the budget runs out.
            await ctx.store.set("last_error", error_msg)
            logger.error(
                "LLM parse error (attempt %s/%s): %s: %s\nRaw response: %s",
                retries,
                self.max_retries,
                type(e).__name__,
                e,
                raw_output[:500],
            )
            return ValidationErrorEvent(
                error=error_msg,
                wrong_output=raw_output,
            )

        if retries > 1:
            logger.info(
                "LLM parse succeeded on attempt %s/%s for %s",
                retries,
                self.max_retries,
                self.output_cls.__name__,
            )
        return StopEvent(result={"success": parsed})

    def _format_error(self, error: Exception, raw_output: str) -> str:
        """Format error for LLM feedback.

        Pydantic validation errors are rendered one line per failing field;
        anything else is reported as a generic parse error. ``raw_output`` is
        accepted for signature stability but is not used.
        """
        if not isinstance(error, ValidationError):
            return f"Parse error: {str(error)}"

        error_messages = []
        for err in error.errors():
            field = ".".join(str(loc) for loc in err["loc"])
            error_messages.append(f"- {err['msg']} in field '{field}'")
        return "Schema validation errors:\n" + "\n".join(error_messages)
class LLM: class LLM:
def __init__(self, settings, temperature: float = 0.4, max_tokens: int = 2048): def __init__(self, settings, temperature: float = 0.4, max_tokens: int = 2048):
self.settings_obj = settings self.settings_obj = settings
@@ -30,11 +197,12 @@ class LLM:
self.temperature = temperature self.temperature = temperature
self.max_tokens = max_tokens self.max_tokens = max_tokens
# Configure llamaindex Settings
self._configure_llamaindex() self._configure_llamaindex()
def _configure_llamaindex(self): def _configure_llamaindex(self):
"""Configure llamaindex Settings with OpenAILike LLM""" """Configure llamaindex Settings with OpenAILike LLM"""
session_id = llm_session_id.get() or f"fallback-{uuid4().hex}"
Settings.llm = OpenAILike( Settings.llm = OpenAILike(
model=self.model_name, model=self.model_name,
api_base=self.url, api_base=self.url,
@@ -44,6 +212,7 @@ class LLM:
is_function_calling_model=False, is_function_calling_model=False,
temperature=self.temperature, temperature=self.temperature,
max_tokens=self.max_tokens, max_tokens=self.max_tokens,
additional_kwargs={"extra_body": {"litellm_session_id": session_id}},
) )
async def get_response( async def get_response(
@@ -61,43 +230,25 @@ class LLM:
output_cls: Type[T], output_cls: Type[T],
tone_name: str | None = None, tone_name: str | None = None,
) -> T: ) -> T:
"""Get structured output from LLM for non-function-calling models""" """Get structured output from LLM with validation retry via Workflow."""
logger = logging.getLogger(__name__) workflow = StructuredOutputWorkflow(
output_cls=output_cls,
summarizer = TreeSummarize(verbose=True) max_retries=self.settings_obj.LLM_PARSE_MAX_RETRIES + 1,
response = await summarizer.aget_response(prompt, texts, tone_name=tone_name) timeout=120,
output_parser = PydanticOutputParser(output_cls)
program = LLMTextCompletionProgram.from_defaults(
output_parser=output_parser,
prompt_template_str=STRUCTURED_RESPONSE_PROMPT_TEMPLATE,
verbose=False,
) )
format_instructions = output_parser.format( result = await workflow.run(
"Please structure the above information in the following JSON format:" prompt=prompt,
texts=texts,
tone_name=tone_name,
) )
try: if "error" in result:
output = await program.acall( error_msg = result["error"] or "Max retries exceeded"
analysis=str(response), format_instructions=format_instructions raise LLMParseError(
output_cls=output_cls,
error_msg=error_msg,
attempts=result.get("attempts", 0),
) )
except ValidationError as e:
# Extract the raw JSON from the error details
errors = e.errors()
if errors and "input" in errors[0]:
raw_json = errors[0]["input"]
logger.error(
f"JSON validation failed for {output_cls.__name__}. "
f"Full raw JSON output:\n{raw_json}\n"
f"Validation errors: {errors}"
)
else:
logger.error(
f"JSON validation failed for {output_cls.__name__}. "
f"Validation errors: {errors}"
)
raise
return output return result["success"]

View File

@@ -340,7 +340,6 @@ async def task_send_webhook_if_needed(*, transcript_id: str):
@asynctask @asynctask
async def task_pipeline_file_process(*, transcript_id: str): async def task_pipeline_file_process(*, transcript_id: str):
"""Celery task for file pipeline processing""" """Celery task for file pipeline processing"""
transcript = await transcripts_controller.get_by_id(transcript_id) transcript = await transcripts_controller.get_by_id(transcript_id)
if not transcript: if not transcript:
raise Exception(f"Transcript {transcript_id} not found") raise Exception(f"Transcript {transcript_id} not found")

View File

@@ -74,6 +74,10 @@ class Settings(BaseSettings):
LLM_API_KEY: str | None = None LLM_API_KEY: str | None = None
LLM_CONTEXT_WINDOW: int = 16000 LLM_CONTEXT_WINDOW: int = 16000
LLM_PARSE_MAX_RETRIES: int = (
3 # Max retries for JSON/validation errors (total attempts = retries + 1)
)
# Diarization # Diarization
DIARIZATION_ENABLED: bool = True DIARIZATION_ENABLED: bool = True
DIARIZATION_BACKEND: str = "modal" DIARIZATION_BACKEND: str = "modal"

View File

@@ -318,6 +318,14 @@ async def dummy_storage():
yield yield
@pytest.fixture
def test_settings():
    """Provide isolated settings for tests to avoid modifying global settings"""
    # Imported inside the fixture; each use builds a fresh Settings instance.
    from reflector.settings import Settings

    return Settings()
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def celery_enable_logging(): def celery_enable_logging():
return True return True

View File

@@ -0,0 +1,357 @@
"""Tests for LLM parse error recovery using llama-index Workflow"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import BaseModel, Field
from workflows.errors import WorkflowRuntimeError
from reflector.llm import LLM, LLMParseError, StructuredOutputWorkflow
class TestResponse(BaseModel):
    """Test response model for structured output"""

    # The class name starts with "Test", so pytest would try to collect this
    # model as a test class and emit a collection warning; opt out explicitly.
    __test__ = False

    title: str = Field(description="A title")
    summary: str = Field(description="A summary")
    # Constrained to [0, 1] so out-of-range values fail validation.
    confidence: float = Field(description="Confidence score", ge=0, le=1)
def make_completion_response(text: str):
    """Build a stand-in for an LLM completion that exposes ``.text``."""
    # MagicMock keyword arguments are applied as attributes on the mock.
    return MagicMock(text=text)
class TestLLMParseErrorRecovery:
    """Test parse error recovery with Workflow feedback loop"""

    # Completion text that satisfies every TestResponse field constraint.
    VALID_JSON = '{"title": "Test", "summary": "Summary", "confidence": 0.95}'

    @pytest.mark.asyncio
    async def test_parse_error_recovery_with_feedback(self, test_settings):
        """Test that parse errors trigger retry with error feedback"""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)

        with (
            patch("reflector.llm.TreeSummarize") as mock_summarize,
            patch("reflector.llm.Settings") as mock_settings,
        ):
            mock_summarizer = MagicMock()
            mock_summarize.return_value = mock_summarizer
            # TreeSummarize returns plain text analysis (step 1)
            mock_summarizer.aget_response = AsyncMock(
                return_value="The analysis shows a test with summary and high confidence."
            )

            # Prompts are recorded and checked after the run: an assertion
            # raised inside the mocked LLM call would only surface wrapped in
            # a workflow runtime error.
            prompts = []

            async def acomplete_handler(prompt, *args, **kwargs):
                prompts.append(prompt)
                if len(prompts) == 1:
                    # First JSON formatting call returns invalid JSON
                    return make_completion_response('{"title": "Test"}')
                return make_completion_response(self.VALID_JSON)

            mock_settings.llm.acomplete = AsyncMock(side_effect=acomplete_handler)

            result = await llm.get_structured_response(
                prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
            )

            assert result.title == "Test"
            assert result.summary == "Summary"
            assert result.confidence == 0.95

            # TreeSummarize called once, Settings.llm.acomplete called twice
            assert mock_summarizer.aget_response.call_count == 1
            assert len(prompts) == 2

            # Second call should have error feedback in prompt
            retry_prompt = prompts[1]
            assert "Your previous response could not be parsed:" in retry_prompt
            assert '{"title": "Test"}' in retry_prompt
            assert "Error:" in retry_prompt
            assert "Please try again" in retry_prompt

    @pytest.mark.asyncio
    async def test_max_parse_retry_attempts(self, test_settings):
        """Test that parse error retry stops after max attempts"""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)

        with (
            patch("reflector.llm.TreeSummarize") as mock_summarize,
            patch("reflector.llm.Settings") as mock_settings,
        ):
            mock_summarizer = MagicMock()
            mock_summarize.return_value = mock_summarizer
            mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")

            # Always return invalid JSON from acomplete
            mock_settings.llm.acomplete = AsyncMock(
                return_value=make_completion_response(
                    '{"invalid": "missing required fields"}'
                )
            )

            with pytest.raises(LLMParseError, match="Failed to parse"):
                await llm.get_structured_response(
                    prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
                )

            expected_attempts = test_settings.LLM_PARSE_MAX_RETRIES + 1
            # TreeSummarize called once; acomplete called once per attempt
            # (the initial try plus LLM_PARSE_MAX_RETRIES retries)
            assert mock_summarizer.aget_response.call_count == 1
            assert mock_settings.llm.acomplete.call_count == expected_attempts

    @pytest.mark.asyncio
    async def test_raw_response_logging_on_parse_error(self, test_settings, caplog):
        """Test that raw response is logged when parse error occurs"""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)

        with (
            patch("reflector.llm.TreeSummarize") as mock_summarize,
            patch("reflector.llm.Settings") as mock_settings,
            caplog.at_level("ERROR"),
        ):
            mock_summarizer = MagicMock()
            mock_summarize.return_value = mock_summarizer
            mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")

            call_count = {"count": 0}

            async def acomplete_handler(*args, **kwargs):
                call_count["count"] += 1
                if call_count["count"] == 1:
                    return make_completion_response('{"title": "Test"}')  # Invalid
                return make_completion_response(self.VALID_JSON)

            mock_settings.llm.acomplete = AsyncMock(side_effect=acomplete_handler)

            result = await llm.get_structured_response(
                prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
            )

            assert result.title == "Test"

            error_logs = [r for r in caplog.records if r.levelname == "ERROR"]
            raw_response_logged = any("Raw response:" in r.message for r in error_logs)
            assert raw_response_logged, "Raw response should be logged on parse error"

    @pytest.mark.asyncio
    async def test_multiple_validation_errors_in_feedback(self, test_settings):
        """Test that validation errors are included in feedback"""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)

        with (
            patch("reflector.llm.TreeSummarize") as mock_summarize,
            patch("reflector.llm.Settings") as mock_settings,
        ):
            mock_summarizer = MagicMock()
            mock_summarize.return_value = mock_summarizer
            mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")

            prompts = []

            async def acomplete_handler(prompt, *args, **kwargs):
                prompts.append(prompt)
                if len(prompts) == 1:
                    # Missing title and summary
                    return make_completion_response('{"confidence": 0.5}')
                return make_completion_response(self.VALID_JSON)

            mock_settings.llm.acomplete = AsyncMock(side_effect=acomplete_handler)

            result = await llm.get_structured_response(
                prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
            )

            assert result.title == "Test"
            assert len(prompts) == 2

            # Both missing fields must be reported. Checking for the word
            # "error" alone would always pass, because the retry prompt
            # contains an "Error:" header regardless of the failure.
            retry_prompt = prompts[1]
            assert "Schema validation errors" in retry_prompt
            assert "in field 'title'" in retry_prompt
            assert "in field 'summary'" in retry_prompt

    @pytest.mark.asyncio
    async def test_success_on_first_attempt(self, test_settings):
        """Test that no retry happens when first attempt succeeds"""
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)

        with (
            patch("reflector.llm.TreeSummarize") as mock_summarize,
            patch("reflector.llm.Settings") as mock_settings,
        ):
            mock_summarizer = MagicMock()
            mock_summarize.return_value = mock_summarizer
            mock_summarizer.aget_response = AsyncMock(return_value="Some analysis")

            mock_settings.llm.acomplete = AsyncMock(
                return_value=make_completion_response(self.VALID_JSON)
            )

            result = await llm.get_structured_response(
                prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
            )

            assert result.title == "Test"
            assert result.summary == "Summary"
            assert result.confidence == 0.95
            assert mock_summarizer.aget_response.call_count == 1
            assert mock_settings.llm.acomplete.call_count == 1
class TestStructuredOutputWorkflow:
    """Exercise StructuredOutputWorkflow directly, without the LLM wrapper."""

    @pytest.mark.asyncio
    async def test_workflow_retries_on_validation_error(self):
        """A failed validation is followed by a second, successful attempt."""
        with (
            patch("reflector.llm.TreeSummarize") as summarize_cls,
            patch("reflector.llm.Settings") as settings_mock,
        ):
            summarizer = summarize_cls.return_value = MagicMock()
            summarizer.aget_response = AsyncMock(return_value="Some analysis")

            # One incomplete payload first, then a payload that validates.
            acomplete = AsyncMock(
                side_effect=[
                    make_completion_response('{"title": "Only title"}'),
                    make_completion_response(
                        '{"title": "Test", "summary": "Summary", "confidence": 0.9}'
                    ),
                ]
            )
            settings_mock.llm.acomplete = acomplete

            workflow = StructuredOutputWorkflow(
                output_cls=TestResponse,
                max_retries=3,
                timeout=30,
            )
            outcome = await workflow.run(
                prompt="Extract data",
                texts=["Some text"],
                tone_name=None,
            )

            assert "success" in outcome
            assert outcome["success"].title == "Test"
            assert acomplete.call_count == 2

    @pytest.mark.asyncio
    async def test_workflow_returns_error_after_max_retries(self):
        """The result carries an "error" key once every attempt has failed."""
        attempt_budget = 2

        with (
            patch("reflector.llm.TreeSummarize") as summarize_cls,
            patch("reflector.llm.Settings") as settings_mock,
        ):
            summarizer = summarize_cls.return_value = MagicMock()
            summarizer.aget_response = AsyncMock(return_value="Some analysis")

            # Every completion lacks the required fields.
            bad_payload = make_completion_response('{"invalid": true}')
            settings_mock.llm.acomplete = AsyncMock(return_value=bad_payload)

            workflow = StructuredOutputWorkflow(
                output_cls=TestResponse,
                max_retries=attempt_budget,
                timeout=30,
            )
            outcome = await workflow.run(
                prompt="Extract data",
                texts=["Some text"],
                tone_name=None,
            )

            assert "error" in outcome
            # The summarizer runs once; acomplete runs once per allowed attempt.
            assert summarizer.aget_response.call_count == 1
            assert settings_mock.llm.acomplete.call_count == 2
class TestNetworkErrorRetries:
    """Check that the Workflow itself never retries network-level failures."""

    @pytest.mark.asyncio
    async def test_network_error_propagates_after_openai_retries(self, test_settings):
        """A network error raised by acomplete surfaces to the caller.

        The mocked acomplete stands in for an LLM client that has already
        given up; the workflow is expected to make a single call and let the
        failure propagate, wrapped in WorkflowRuntimeError.
        """
        llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)

        with (
            patch("reflector.llm.TreeSummarize") as summarize_cls,
            patch("reflector.llm.Settings") as settings_mock,
        ):
            summarizer = summarize_cls.return_value = MagicMock()
            summarizer.aget_response = AsyncMock(return_value="Some analysis")

            # acomplete fails every time with a connection-level error.
            failing_acomplete = AsyncMock(
                side_effect=ConnectionError("Connection refused")
            )
            settings_mock.llm.acomplete = failing_acomplete

            with pytest.raises(WorkflowRuntimeError, match="Connection refused"):
                await llm.get_structured_response(
                    prompt="Test prompt", texts=["Test text"], output_cls=TestResponse
                )

            # Exactly one call: the workflow did not retry the network error.
            assert failing_acomplete.call_count == 1

    @pytest.mark.asyncio
    async def test_network_error_not_retried_by_workflow(self, test_settings):
        """The Workflow retries parse/validation errors only.

        A timeout from acomplete must end the run after a single call rather
        than being routed into the validation retry loop.
        """
        with (
            patch("reflector.llm.TreeSummarize") as summarize_cls,
            patch("reflector.llm.Settings") as settings_mock,
        ):
            summarizer = summarize_cls.return_value = MagicMock()
            summarizer.aget_response = AsyncMock(return_value="Some analysis")

            failing_acomplete = AsyncMock(side_effect=TimeoutError("Request timed out"))
            settings_mock.llm.acomplete = failing_acomplete

            workflow = StructuredOutputWorkflow(
                output_cls=TestResponse,
                max_retries=3,
                timeout=30,
            )

            # The timeout reaches the caller wrapped in WorkflowRuntimeError.
            with pytest.raises(WorkflowRuntimeError, match="Request timed out"):
                await workflow.run(
                    prompt="Extract data",
                    texts=["Some text"],
                    tone_name=None,
                )

            # Exactly one call: no Workflow-level retry for network errors.
            assert failing_acomplete.call_count == 1