diff --git a/server/pyproject.toml b/server/pyproject.toml index ffa28d15..279d3386 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -126,6 +126,7 @@ markers = [ select = [ "I", # isort - import sorting "F401", # unused imports + "E402", # module level import not at top of file "PLC0415", # import-outside-top-level - detect inline imports ] diff --git a/server/reflector/asynctask.py b/server/reflector/asynctask.py index 61523a6f..50f25448 100644 --- a/server/reflector/asynctask.py +++ b/server/reflector/asynctask.py @@ -1,13 +1,19 @@ import asyncio import functools +from uuid import uuid4 + +from celery import current_task from reflector.db import get_database +from reflector.llm import llm_session_id def asynctask(f): @functools.wraps(f) def wrapper(*args, **kwargs): async def run_with_db(): + task_id = current_task.request.id if current_task else None + llm_session_id.set(task_id or f"random-{uuid4().hex}") database = get_database() await database.connect() try: diff --git a/server/reflector/llm.py b/server/reflector/llm.py index 09dab3d2..0485e847 100644 --- a/server/reflector/llm.py +++ b/server/reflector/llm.py @@ -1,14 +1,29 @@ import logging -from typing import Type, TypeVar +from contextvars import ContextVar +from typing import Generic, Type, TypeVar +from uuid import uuid4 from llama_index.core import Settings from llama_index.core.output_parsers import PydanticOutputParser -from llama_index.core.program import LLMTextCompletionProgram from llama_index.core.response_synthesizers import TreeSummarize +from llama_index.core.workflow import ( + Context, + Event, + StartEvent, + StopEvent, + Workflow, + step, +) from llama_index.llms.openai_like import OpenAILike from pydantic import BaseModel, ValidationError T = TypeVar("T", bound=BaseModel) +OutputT = TypeVar("OutputT", bound=BaseModel) + +# Session ID for LiteLLM request grouping - set per processing run +llm_session_id: ContextVar[str | None] = ContextVar("llm_session_id", default=None) + +logger = logging.getLogger(__name__) STRUCTURED_RESPONSE_PROMPT_TEMPLATE = """ Based on the following analysis, provide the information in the requested JSON format: @@ -20,6 +35,158 @@ Analysis: """ +class LLMParseError(Exception): + """Raised when LLM output cannot be parsed after retries.""" + + def __init__(self, output_cls: Type[BaseModel], error_msg: str, attempts: int): + self.output_cls = output_cls + self.error_msg = error_msg + self.attempts = attempts + super().__init__( + f"Failed to parse {output_cls.__name__} after {attempts} attempts: {error_msg}" + ) + + +class ExtractionDone(Event): + """Event emitted when LLM JSON formatting completes.""" + + output: str + + +class ValidationErrorEvent(Event): + """Event emitted when validation fails.""" + + error: str + wrong_output: str + + +class StructuredOutputWorkflow(Workflow, Generic[OutputT]): + """Workflow for structured output extraction with validation retry. + + This workflow handles parse/validation retries only. Network error retries + are handled internally by Settings.llm (OpenAILike max_retries=3). + The caller should NOT wrap this workflow in additional retry logic. 
+ """ + + def __init__( + self, + output_cls: Type[OutputT], + max_retries: int = 3, + **kwargs, + ): + super().__init__(**kwargs) + self.output_cls: Type[OutputT] = output_cls + self.max_retries = max_retries + self.output_parser = PydanticOutputParser(output_cls) + + @step + async def extract( + self, ctx: Context, ev: StartEvent | ValidationErrorEvent + ) -> StopEvent | ExtractionDone: + """Extract structured data from text using two-step LLM process. + + Step 1 (first call only): TreeSummarize generates text analysis + Step 2 (every call): Settings.llm.acomplete formats analysis as JSON + """ + current_retries = await ctx.store.get("retries", default=0) + await ctx.store.set("retries", current_retries + 1) + + if current_retries >= self.max_retries: + last_error = await ctx.store.get("last_error", default=None) + logger.error( + f"Max retries ({self.max_retries}) reached for {self.output_cls.__name__}" + ) + return StopEvent(result={"error": last_error, "attempts": current_retries}) + + if isinstance(ev, StartEvent): + # First call: run TreeSummarize to get analysis, store in context + prompt = ev.get("prompt") + texts = ev.get("texts") + tone_name = ev.get("tone_name") + if not prompt or not isinstance(texts, list): + raise ValueError( + "StartEvent must contain 'prompt' (str) and 'texts' (list)" + ) + + summarizer = TreeSummarize(verbose=False) + analysis = await summarizer.aget_response( + prompt, texts, tone_name=tone_name + ) + await ctx.store.set("analysis", str(analysis)) + reflection = "" + else: + # Retry: reuse analysis from context + analysis = await ctx.store.get("analysis") + if not analysis: + raise RuntimeError("Internal error: analysis not found in context") + + wrong_output = ev.wrong_output + if len(wrong_output) > 2000: + wrong_output = wrong_output[:2000] + "... [truncated]" + reflection = ( + f"\n\nYour previous response could not be parsed:\n{wrong_output}\n\n" + f"Error:\n{ev.error}\n\n" + "Please try again. Return ONLY valid JSON matching the schema above, " + "with no markdown formatting or extra text." 
+ ) + + # Step 2: Format analysis as JSON using LLM completion + format_instructions = self.output_parser.format( + "Please structure the above information in the following JSON format:" + ) + + json_prompt = STRUCTURED_RESPONSE_PROMPT_TEMPLATE.format( + analysis=analysis, + format_instructions=format_instructions + reflection, + ) + + # Network retries handled by OpenAILike (max_retries=3) + response = await Settings.llm.acomplete(json_prompt) + return ExtractionDone(output=response.text) + + @step + async def validate( + self, ctx: Context, ev: ExtractionDone + ) -> StopEvent | ValidationErrorEvent: + """Validate extracted output against Pydantic schema.""" + raw_output = ev.output + retries = await ctx.store.get("retries", default=0) + + try: + parsed = self.output_parser.parse(raw_output) + if retries > 1: + logger.info( + f"LLM parse succeeded on attempt {retries}/{self.max_retries} " + f"for {self.output_cls.__name__}" + ) + return StopEvent(result={"success": parsed}) + + except (ValidationError, ValueError) as e: + error_msg = self._format_error(e, raw_output) + await ctx.store.set("last_error", error_msg) + + logger.error( + f"LLM parse error (attempt {retries}/{self.max_retries}): " + f"{type(e).__name__}: {e}\nRaw response: {raw_output[:500]}" + ) + + return ValidationErrorEvent( + error=error_msg, + wrong_output=raw_output, + ) + + def _format_error(self, error: Exception, raw_output: str) -> str: + """Format error for LLM feedback.""" + if isinstance(error, ValidationError): + error_messages = [] + for err in error.errors(): + field = ".".join(str(loc) for loc in err["loc"]) + error_messages.append(f"- {err['msg']} in field '{field}'") + return "Schema validation errors:\n" + "\n".join(error_messages) + else: + return f"Parse error: {str(error)}" + + class LLM: def __init__(self, settings, temperature: float = 0.4, max_tokens: int = 2048): self.settings_obj = settings @@ -30,11 +197,12 @@ class LLM: self.temperature = temperature self.max_tokens = max_tokens - # Configure llamaindex Settings self._configure_llamaindex() def _configure_llamaindex(self): """Configure llamaindex Settings with OpenAILike LLM""" + session_id = llm_session_id.get() or f"fallback-{uuid4().hex}" + Settings.llm = OpenAILike( model=self.model_name, api_base=self.url, @@ -44,6 +212,7 @@ class LLM: is_function_calling_model=False, temperature=self.temperature, max_tokens=self.max_tokens, + additional_kwargs={"extra_body": {"litellm_session_id": session_id}}, ) async def get_response( @@ -61,43 +230,25 @@ class LLM: output_cls: Type[T], tone_name: str | None = None, ) -> T: - """Get structured output from LLM for non-function-calling models""" - logger = logging.getLogger(__name__) - - summarizer = TreeSummarize(verbose=True) - response = await summarizer.aget_response(prompt, texts, tone_name=tone_name) - - output_parser = PydanticOutputParser(output_cls) - - program = LLMTextCompletionProgram.from_defaults( - output_parser=output_parser, - prompt_template_str=STRUCTURED_RESPONSE_PROMPT_TEMPLATE, - verbose=False, + """Get structured output from LLM with validation retry via Workflow.""" + workflow = StructuredOutputWorkflow( + output_cls=output_cls, + max_retries=self.settings_obj.LLM_PARSE_MAX_RETRIES + 1, + timeout=120, ) - format_instructions = output_parser.format( - "Please structure the above information in the following JSON format:" + result = await workflow.run( + prompt=prompt, + texts=texts, + tone_name=tone_name, ) - try: - output = await program.acall( - analysis=str(response), 
format_instructions=format_instructions + if "error" in result: + error_msg = result["error"] or "Max retries exceeded" + raise LLMParseError( + output_cls=output_cls, + error_msg=error_msg, + attempts=result.get("attempts", 0), ) - except ValidationError as e: - # Extract the raw JSON from the error details - errors = e.errors() - if errors and "input" in errors[0]: - raw_json = errors[0]["input"] - logger.error( - f"JSON validation failed for {output_cls.__name__}. " - f"Full raw JSON output:\n{raw_json}\n" - f"Validation errors: {errors}" - ) - else: - logger.error( - f"JSON validation failed for {output_cls.__name__}. " - f"Validation errors: {errors}" - ) - raise - return output + return result["success"] diff --git a/server/reflector/pipelines/main_file_pipeline.py b/server/reflector/pipelines/main_file_pipeline.py index 6f8e8011..aff6e042 100644 --- a/server/reflector/pipelines/main_file_pipeline.py +++ b/server/reflector/pipelines/main_file_pipeline.py @@ -340,7 +340,6 @@ async def task_send_webhook_if_needed(*, transcript_id: str): @asynctask async def task_pipeline_file_process(*, transcript_id: str): """Celery task for file pipeline processing""" - transcript = await transcripts_controller.get_by_id(transcript_id) if not transcript: raise Exception(f"Transcript {transcript_id} not found") diff --git a/server/reflector/settings.py b/server/reflector/settings.py index 1ec46d94..12276121 100644 --- a/server/reflector/settings.py +++ b/server/reflector/settings.py @@ -74,6 +74,10 @@ class Settings(BaseSettings): LLM_API_KEY: str | None = None LLM_CONTEXT_WINDOW: int = 16000 + LLM_PARSE_MAX_RETRIES: int = ( + 3 # Max retries for JSON/validation errors (total attempts = retries + 1) + ) + # Diarization DIARIZATION_ENABLED: bool = True DIARIZATION_BACKEND: str = "modal" diff --git a/server/tests/conftest.py b/server/tests/conftest.py index 7d6c4302..2931a0c2 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -318,6 +318,14 @@ async def dummy_storage(): yield +@pytest.fixture +def test_settings(): + """Provide isolated settings for tests to avoid modifying global settings""" + from reflector.settings import Settings + + return Settings() + + @pytest.fixture(scope="session") def celery_enable_logging(): return True diff --git a/server/tests/test_llm_retry.py b/server/tests/test_llm_retry.py new file mode 100644 index 00000000..f9fe28b4 --- /dev/null +++ b/server/tests/test_llm_retry.py @@ -0,0 +1,357 @@ +"""Tests for LLM parse error recovery using llama-index Workflow""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from pydantic import BaseModel, Field +from workflows.errors import WorkflowRuntimeError + +from reflector.llm import LLM, LLMParseError, StructuredOutputWorkflow + + +class TestResponse(BaseModel): + """Test response model for structured output""" + + title: str = Field(description="A title") + summary: str = Field(description="A summary") + confidence: float = Field(description="Confidence score", ge=0, le=1) + + +def make_completion_response(text: str): + """Create a mock CompletionResponse with .text attribute""" + response = MagicMock() + response.text = text + return response + + +class TestLLMParseErrorRecovery: + """Test parse error recovery with Workflow feedback loop""" + + @pytest.mark.asyncio + async def test_parse_error_recovery_with_feedback(self, test_settings): + """Test that parse errors trigger retry with error feedback""" + llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100) + + with ( + 
patch("reflector.llm.TreeSummarize") as mock_summarize, + patch("reflector.llm.Settings") as mock_settings, + ): + mock_summarizer = MagicMock() + mock_summarize.return_value = mock_summarizer + # TreeSummarize returns plain text analysis (step 1) + mock_summarizer.aget_response = AsyncMock( + return_value="The analysis shows a test with summary and high confidence." + ) + + call_count = {"count": 0} + + async def acomplete_handler(prompt, *args, **kwargs): + call_count["count"] += 1 + if call_count["count"] == 1: + # First JSON formatting call returns invalid JSON + return make_completion_response('{"title": "Test"}') + else: + # Second call should have error feedback in prompt + assert "Your previous response could not be parsed:" in prompt + assert '{"title": "Test"}' in prompt + assert "Error:" in prompt + assert "Please try again" in prompt + return make_completion_response( + '{"title": "Test", "summary": "Summary", "confidence": 0.95}' + ) + + mock_settings.llm.acomplete = AsyncMock(side_effect=acomplete_handler) + + result = await llm.get_structured_response( + prompt="Test prompt", texts=["Test text"], output_cls=TestResponse + ) + + assert result.title == "Test" + assert result.summary == "Summary" + assert result.confidence == 0.95 + # TreeSummarize called once, Settings.llm.acomplete called twice + assert mock_summarizer.aget_response.call_count == 1 + assert call_count["count"] == 2 + + @pytest.mark.asyncio + async def test_max_parse_retry_attempts(self, test_settings): + """Test that parse error retry stops after max attempts""" + llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100) + + with ( + patch("reflector.llm.TreeSummarize") as mock_summarize, + patch("reflector.llm.Settings") as mock_settings, + ): + mock_summarizer = MagicMock() + mock_summarize.return_value = mock_summarizer + mock_summarizer.aget_response = AsyncMock(return_value="Some analysis") + + # Always return invalid JSON from acomplete + mock_settings.llm.acomplete = AsyncMock( + return_value=make_completion_response( + '{"invalid": "missing required fields"}' + ) + ) + + with pytest.raises(LLMParseError, match="Failed to parse"): + await llm.get_structured_response( + prompt="Test prompt", texts=["Test text"], output_cls=TestResponse + ) + + expected_attempts = test_settings.LLM_PARSE_MAX_RETRIES + 1 + # TreeSummarize called once, acomplete called max_retries times + assert mock_summarizer.aget_response.call_count == 1 + assert mock_settings.llm.acomplete.call_count == expected_attempts + + @pytest.mark.asyncio + async def test_raw_response_logging_on_parse_error(self, test_settings, caplog): + """Test that raw response is logged when parse error occurs""" + llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100) + + with ( + patch("reflector.llm.TreeSummarize") as mock_summarize, + patch("reflector.llm.Settings") as mock_settings, + caplog.at_level("ERROR"), + ): + mock_summarizer = MagicMock() + mock_summarize.return_value = mock_summarizer + mock_summarizer.aget_response = AsyncMock(return_value="Some analysis") + + call_count = {"count": 0} + + async def acomplete_handler(*args, **kwargs): + call_count["count"] += 1 + if call_count["count"] == 1: + return make_completion_response('{"title": "Test"}') # Invalid + return make_completion_response( + '{"title": "Test", "summary": "Summary", "confidence": 0.95}' + ) + + mock_settings.llm.acomplete = AsyncMock(side_effect=acomplete_handler) + + result = await llm.get_structured_response( + prompt="Test prompt", texts=["Test text"], 
output_cls=TestResponse + ) + + assert result.title == "Test" + + error_logs = [r for r in caplog.records if r.levelname == "ERROR"] + raw_response_logged = any("Raw response:" in r.message for r in error_logs) + assert raw_response_logged, "Raw response should be logged on parse error" + + @pytest.mark.asyncio + async def test_multiple_validation_errors_in_feedback(self, test_settings): + """Test that validation errors are included in feedback""" + llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100) + + with ( + patch("reflector.llm.TreeSummarize") as mock_summarize, + patch("reflector.llm.Settings") as mock_settings, + ): + mock_summarizer = MagicMock() + mock_summarize.return_value = mock_summarizer + mock_summarizer.aget_response = AsyncMock(return_value="Some analysis") + + call_count = {"count": 0} + + async def acomplete_handler(prompt, *args, **kwargs): + call_count["count"] += 1 + if call_count["count"] == 1: + # Missing title and summary + return make_completion_response('{"confidence": 0.5}') + else: + # Should have schema validation errors in prompt + assert ( + "Schema validation errors" in prompt + or "error" in prompt.lower() + ) + return make_completion_response( + '{"title": "Test", "summary": "Summary", "confidence": 0.95}' + ) + + mock_settings.llm.acomplete = AsyncMock(side_effect=acomplete_handler) + + result = await llm.get_structured_response( + prompt="Test prompt", texts=["Test text"], output_cls=TestResponse + ) + + assert result.title == "Test" + assert call_count["count"] == 2 + + @pytest.mark.asyncio + async def test_success_on_first_attempt(self, test_settings): + """Test that no retry happens when first attempt succeeds""" + llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100) + + with ( + patch("reflector.llm.TreeSummarize") as mock_summarize, + patch("reflector.llm.Settings") as mock_settings, + ): + mock_summarizer = MagicMock() + mock_summarize.return_value = mock_summarizer + mock_summarizer.aget_response = AsyncMock(return_value="Some analysis") + + mock_settings.llm.acomplete = AsyncMock( + return_value=make_completion_response( + '{"title": "Test", "summary": "Summary", "confidence": 0.95}' + ) + ) + + result = await llm.get_structured_response( + prompt="Test prompt", texts=["Test text"], output_cls=TestResponse + ) + + assert result.title == "Test" + assert result.summary == "Summary" + assert result.confidence == 0.95 + assert mock_summarizer.aget_response.call_count == 1 + assert mock_settings.llm.acomplete.call_count == 1 + + +class TestStructuredOutputWorkflow: + """Direct tests for the StructuredOutputWorkflow""" + + @pytest.mark.asyncio + async def test_workflow_retries_on_validation_error(self): + """Test workflow retries when validation fails""" + workflow = StructuredOutputWorkflow( + output_cls=TestResponse, + max_retries=3, + timeout=30, + ) + + with ( + patch("reflector.llm.TreeSummarize") as mock_summarize, + patch("reflector.llm.Settings") as mock_settings, + ): + mock_summarizer = MagicMock() + mock_summarize.return_value = mock_summarizer + mock_summarizer.aget_response = AsyncMock(return_value="Some analysis") + + call_count = {"count": 0} + + async def acomplete_handler(*args, **kwargs): + call_count["count"] += 1 + if call_count["count"] < 2: + return make_completion_response('{"title": "Only title"}') + return make_completion_response( + '{"title": "Test", "summary": "Summary", "confidence": 0.9}' + ) + + mock_settings.llm.acomplete = AsyncMock(side_effect=acomplete_handler) + + result = await 
workflow.run( + prompt="Extract data", + texts=["Some text"], + tone_name=None, + ) + + assert "success" in result + assert result["success"].title == "Test" + assert call_count["count"] == 2 + + @pytest.mark.asyncio + async def test_workflow_returns_error_after_max_retries(self): + """Test workflow returns error after exhausting retries""" + workflow = StructuredOutputWorkflow( + output_cls=TestResponse, + max_retries=2, + timeout=30, + ) + + with ( + patch("reflector.llm.TreeSummarize") as mock_summarize, + patch("reflector.llm.Settings") as mock_settings, + ): + mock_summarizer = MagicMock() + mock_summarize.return_value = mock_summarizer + mock_summarizer.aget_response = AsyncMock(return_value="Some analysis") + + # Always return invalid JSON + mock_settings.llm.acomplete = AsyncMock( + return_value=make_completion_response('{"invalid": true}') + ) + + result = await workflow.run( + prompt="Extract data", + texts=["Some text"], + tone_name=None, + ) + + assert "error" in result + # TreeSummarize called once, acomplete called max_retries times + assert mock_summarizer.aget_response.call_count == 1 + assert mock_settings.llm.acomplete.call_count == 2 + + +class TestNetworkErrorRetries: + """Test that network error retries are handled by OpenAILike, not Workflow""" + + @pytest.mark.asyncio + async def test_network_error_propagates_after_openai_retries(self, test_settings): + """Test that network errors are retried by OpenAILike and then propagate. + + Network retries are handled by OpenAILike (max_retries=3), not by our + StructuredOutputWorkflow. This test verifies that network errors propagate + up after OpenAILike exhausts its retries. + """ + llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100) + + with ( + patch("reflector.llm.TreeSummarize") as mock_summarize, + patch("reflector.llm.Settings") as mock_settings, + ): + mock_summarizer = MagicMock() + mock_summarize.return_value = mock_summarizer + mock_summarizer.aget_response = AsyncMock(return_value="Some analysis") + + # Simulate network error from acomplete (after OpenAILike retries exhausted) + network_error = ConnectionError("Connection refused") + mock_settings.llm.acomplete = AsyncMock(side_effect=network_error) + + # Network error wrapped in WorkflowRuntimeError + with pytest.raises(WorkflowRuntimeError, match="Connection refused"): + await llm.get_structured_response( + prompt="Test prompt", texts=["Test text"], output_cls=TestResponse + ) + + # acomplete called only once - network error propagates, not retried by Workflow + assert mock_settings.llm.acomplete.call_count == 1 + + @pytest.mark.asyncio + async def test_network_error_not_retried_by_workflow(self, test_settings): + """Test that Workflow does NOT retry network errors (OpenAILike handles those). 
+ + This verifies the separation of concerns: + - StructuredOutputWorkflow: retries parse/validation errors + - OpenAILike: retries network errors (internally, max_retries=3) + """ + workflow = StructuredOutputWorkflow( + output_cls=TestResponse, + max_retries=3, + timeout=30, + ) + + with ( + patch("reflector.llm.TreeSummarize") as mock_summarize, + patch("reflector.llm.Settings") as mock_settings, + ): + mock_summarizer = MagicMock() + mock_summarize.return_value = mock_summarizer + mock_summarizer.aget_response = AsyncMock(return_value="Some analysis") + + # Network error should propagate immediately, not trigger Workflow retry + mock_settings.llm.acomplete = AsyncMock( + side_effect=TimeoutError("Request timed out") + ) + + # Network error wrapped in WorkflowRuntimeError + with pytest.raises(WorkflowRuntimeError, match="Request timed out"): + await workflow.run( + prompt="Extract data", + texts=["Some text"], + tone_name=None, + ) + + # Only called once - Workflow doesn't retry network errors + assert mock_settings.llm.acomplete.call_count == 1
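For context, a minimal usage sketch of the structured-output path this diff introduces (not part of the patch itself): it assumes a reachable LLM backend configured through reflector.settings.Settings (LLM_URL, LLM_MODEL_NAME, LLM_API_KEY), and TopicSummary is a hypothetical output model used only for illustration. It shows the contract the tests above exercise: set llm_session_id before constructing LLM (mirroring asynctask), call get_structured_response with a Pydantic output_cls, and handle LLMParseError once validation retries are exhausted.

import asyncio
from uuid import uuid4

from pydantic import BaseModel, Field

from reflector.llm import LLM, LLMParseError, llm_session_id
from reflector.settings import Settings


class TopicSummary(BaseModel):
    # Hypothetical output schema, not part of the diff
    title: str = Field(description="Short topic title")
    summary: str = Field(description="One-paragraph summary")


async def main() -> None:
    # Group LiteLLM requests of this run under one session id,
    # as asynctask() does for Celery tasks.
    llm_session_id.set(f"manual-{uuid4().hex}")

    # Settings() picks up LLM_* values from the environment
    llm = LLM(settings=Settings(), temperature=0.4, max_tokens=1024)
    try:
        result = await llm.get_structured_response(
            prompt="Summarize the main topic of the transcript.",
            texts=["...transcript text..."],
            output_cls=TopicSummary,
        )
        print(result.title, "-", result.summary)
    except LLMParseError as exc:
        # Raised only after LLM_PARSE_MAX_RETRIES + 1 validation attempts;
        # network errors propagate separately (OpenAILike retries them internally).
        print(f"Could not parse {exc.output_cls.__name__} after {exc.attempts} attempts")


if __name__ == "__main__":
    asyncio.run(main())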