fix: improve hatchet workflow reliability (#900)

* Increase max connections

* Classify hard and transient hatchet errors

* Fan out partial success

* Force reprocessing of error transcripts

* Stop retrying on 402 payment required

* Avoid httpx/hatchet timeout race

* Add retry wrapper to get_response for for transient errors

* Add retry backoff

* Return falsy results so get_response won't retry on empty string

* Skip error status in on_workflow_failure when transcript already ended

* Fix precommit issues

* Fail step on first fan-out failure instead of skipping
This commit is contained in:
Sergey Mankovsky
2026-03-06 17:07:26 +01:00
committed by GitHub
parent a682846645
commit c155f66982
17 changed files with 717 additions and 38 deletions

View File

@@ -0,0 +1,303 @@
"""
Tests for Hatchet error handling: NonRetryable classification and error status.
These tests encode the desired behavior from the Hatchet Workflow Analysis doc:
- Transient exceptions: do NOT set error status (let Hatchet retry; user stays on "processing").
- Hard-fail exceptions: set error status and re-raise as NonRetryableException (stop retries).
- on_failure_task: sets error status when workflow is truly dead.
Run before the fix: some tests fail (reproducing the issues).
Run after the fix: all tests pass.
"""
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from hatchet_sdk import NonRetryableException
from reflector.hatchet.error_classification import is_non_retryable
from reflector.llm import LLMParseError
# --- Tests for is_non_retryable() (pass once error_classification exists) ---
def test_is_non_retryable_returns_true_for_value_error():
"""ValueError (e.g. missing config) should stop retries."""
assert is_non_retryable(ValueError("DAILY_API_KEY must be set")) is True
def test_is_non_retryable_returns_true_for_type_error():
"""TypeError (bad input) should stop retries."""
assert is_non_retryable(TypeError("expected str")) is True
def test_is_non_retryable_returns_true_for_http_401():
"""HTTP 401 auth error should stop retries."""
resp = MagicMock()
resp.status_code = 401
err = httpx.HTTPStatusError("Unauthorized", request=MagicMock(), response=resp)
assert is_non_retryable(err) is True
def test_is_non_retryable_returns_true_for_http_402():
"""HTTP 402 (no credits) should stop retries."""
resp = MagicMock()
resp.status_code = 402
err = httpx.HTTPStatusError("Payment Required", request=MagicMock(), response=resp)
assert is_non_retryable(err) is True
def test_is_non_retryable_returns_true_for_http_404():
"""HTTP 404 should stop retries."""
resp = MagicMock()
resp.status_code = 404
err = httpx.HTTPStatusError("Not Found", request=MagicMock(), response=resp)
assert is_non_retryable(err) is True
def test_is_non_retryable_returns_false_for_http_503():
"""HTTP 503 is transient; retries are useful."""
resp = MagicMock()
resp.status_code = 503
err = httpx.HTTPStatusError(
"Service Unavailable", request=MagicMock(), response=resp
)
assert is_non_retryable(err) is False
def test_is_non_retryable_returns_false_for_timeout():
"""Timeout is transient."""
assert is_non_retryable(httpx.TimeoutException("timed out")) is False
def test_is_non_retryable_returns_true_for_llm_parse_error():
"""LLMParseError after internal retries should stop."""
from pydantic import BaseModel
class _Dummy(BaseModel):
pass
assert is_non_retryable(LLMParseError(_Dummy, "Failed to parse", 3)) is True
def test_is_non_retryable_returns_true_for_non_retryable_exception():
"""Already-wrapped NonRetryableException should stay non-retryable."""
assert is_non_retryable(NonRetryableException("custom")) is True
# --- Tests for with_error_handling (need pipeline module with patch) ---
@pytest.fixture(scope="module")
def pipeline_module():
"""Import daily_multitrack_pipeline with Hatchet client mocked."""
with patch("reflector.hatchet.client.settings") as s:
s.HATCHET_CLIENT_TOKEN = "test-token"
s.HATCHET_DEBUG = False
mock_client = MagicMock()
mock_client.workflow.return_value = MagicMock()
with patch(
"reflector.hatchet.client.HatchetClientManager.get_client",
return_value=mock_client,
):
from reflector.hatchet.workflows import daily_multitrack_pipeline
return daily_multitrack_pipeline
@pytest.fixture
def mock_input():
"""Minimal PipelineInput for decorator tests."""
from reflector.hatchet.workflows.daily_multitrack_pipeline import PipelineInput
return PipelineInput(
recording_id="rec-1",
tracks=[],
bucket_name="bucket",
transcript_id="ts-123",
room_id=None,
)
@pytest.fixture
def mock_ctx():
"""Minimal Context-like object."""
ctx = MagicMock()
ctx.log = MagicMock()
return ctx
@pytest.mark.asyncio
async def test_with_error_handling_transient_does_not_set_error_status(
pipeline_module, mock_input, mock_ctx
):
"""Transient exception must NOT set error status (so user stays on 'processing' during retries).
Before fix: set_workflow_error_status is called on every exception → FAIL.
After fix: not called for transient → PASS.
"""
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
TaskName,
with_error_handling,
)
async def failing_task(input, ctx):
raise httpx.TimeoutException("timed out")
wrapped = with_error_handling(TaskName.GET_RECORDING)(failing_task)
with patch(
"reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
new_callable=AsyncMock,
) as mock_set_error:
with pytest.raises(httpx.TimeoutException):
await wrapped(mock_input, mock_ctx)
# Desired: do NOT set error status for transient (Hatchet will retry)
mock_set_error.assert_not_called()
@pytest.mark.asyncio
async def test_with_error_handling_hard_fail_raises_non_retryable_and_sets_status(
pipeline_module, mock_input, mock_ctx
):
"""Hard-fail (e.g. ValueError) must set error status and re-raise NonRetryableException.
Before fix: raises ValueError, set_workflow_error_status called → test would need to expect ValueError.
After fix: raises NonRetryableException, set_workflow_error_status called → PASS.
"""
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
TaskName,
with_error_handling,
)
async def failing_task(input, ctx):
raise ValueError("PADDING_URL must be set")
wrapped = with_error_handling(TaskName.GET_RECORDING)(failing_task)
with patch(
"reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
new_callable=AsyncMock,
) as mock_set_error:
with pytest.raises(NonRetryableException) as exc_info:
await wrapped(mock_input, mock_ctx)
assert "PADDING_URL" in str(exc_info.value)
mock_set_error.assert_called_once_with("ts-123")
@pytest.mark.asyncio
async def test_with_error_handling_set_error_status_false_never_sets_status(
pipeline_module, mock_input, mock_ctx
):
"""When set_error_status=False, we must never set error status (e.g. cleanup_consent)."""
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
TaskName,
with_error_handling,
)
async def failing_task(input, ctx):
raise ValueError("something went wrong")
wrapped = with_error_handling(TaskName.CLEANUP_CONSENT, set_error_status=False)(
failing_task
)
with patch(
"reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
new_callable=AsyncMock,
) as mock_set_error:
with pytest.raises((ValueError, NonRetryableException)):
await wrapped(mock_input, mock_ctx)
mock_set_error.assert_not_called()
@asynccontextmanager
async def _noop_db_context():
"""Async context manager that yields without touching the DB (for unit tests)."""
yield None
@pytest.mark.asyncio
async def test_on_failure_task_sets_error_status(pipeline_module, mock_input, mock_ctx):
"""When workflow fails and transcript is not yet 'ended', on_failure sets status to 'error'."""
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
on_workflow_failure,
)
transcript_processing = MagicMock()
transcript_processing.status = "processing"
with patch(
"reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection",
_noop_db_context,
):
with patch(
"reflector.db.transcripts.transcripts_controller.get_by_id",
new_callable=AsyncMock,
return_value=transcript_processing,
):
with patch(
"reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
new_callable=AsyncMock,
) as mock_set_error:
await on_workflow_failure(mock_input, mock_ctx)
mock_set_error.assert_called_once_with(mock_input.transcript_id)
@pytest.mark.asyncio
async def test_on_failure_task_does_not_overwrite_ended(
pipeline_module, mock_input, mock_ctx
):
"""When workflow fails after finalize (e.g. post_zulip), do not overwrite 'ended' with 'error'.
cleanup_consent, post_zulip, send_webhook use set_error_status=False; if one fails,
on_workflow_failure must not set status to 'error' when transcript is already 'ended'.
"""
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
on_workflow_failure,
)
transcript_ended = MagicMock()
transcript_ended.status = "ended"
with patch(
"reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection",
_noop_db_context,
):
with patch(
"reflector.db.transcripts.transcripts_controller.get_by_id",
new_callable=AsyncMock,
return_value=transcript_ended,
):
with patch(
"reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
new_callable=AsyncMock,
) as mock_set_error:
await on_workflow_failure(mock_input, mock_ctx)
mock_set_error.assert_not_called()
# --- Tests for fan-out helper (_successful_run_results) ---
def test_successful_run_results_filters_exceptions():
"""_successful_run_results returns only non-exception items from aio_run_many(return_exceptions=True)."""
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
_successful_run_results,
)
results = [
{"key": "ok1"},
ValueError("child failed"),
{"key": "ok2"},
RuntimeError("another"),
]
successful = _successful_run_results(results)
assert len(successful) == 2
assert successful[0] == {"key": "ok1"}
assert successful[1] == {"key": "ok2"}

View File

@@ -1,6 +1,6 @@
"""Tests for LLM structured output with astructured_predict + reflection retry"""
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import BaseModel, Field, ValidationError
@@ -252,6 +252,63 @@ class TestNetworkErrorRetries:
assert mock_settings.llm.astructured_predict.call_count == 3
class TestGetResponseRetries:
"""Test that get_response() uses the same retry() wrapper for transient errors."""
@pytest.mark.asyncio
async def test_get_response_retries_on_connection_error(self, test_settings):
"""Test that get_response retries on ConnectionError and returns on success."""
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
mock_instance = MagicMock()
mock_instance.aget_response = AsyncMock(
side_effect=[
ConnectionError("Connection refused"),
" Summary text ",
]
)
with patch("reflector.llm.TreeSummarize", return_value=mock_instance):
result = await llm.get_response("Prompt", ["text"])
assert result == "Summary text"
assert mock_instance.aget_response.call_count == 2
@pytest.mark.asyncio
async def test_get_response_exhausts_retries(self, test_settings):
"""Test that get_response raises RetryException after retry attempts exceeded."""
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
mock_instance = MagicMock()
mock_instance.aget_response = AsyncMock(
side_effect=ConnectionError("Connection refused")
)
with patch("reflector.llm.TreeSummarize", return_value=mock_instance):
with pytest.raises(RetryException, match="Retry attempts exceeded"):
await llm.get_response("Prompt", ["text"])
assert mock_instance.aget_response.call_count == 3
@pytest.mark.asyncio
async def test_get_response_returns_empty_string_without_retry(self, test_settings):
"""Empty or whitespace-only LLM response must return '' and not raise RetryException.
retry() must return falsy results (e.g. '' from get_response) instead of
treating them as 'no result' and retrying until RetryException.
"""
llm = LLM(settings=test_settings, temperature=0.4, max_tokens=100)
mock_instance = MagicMock()
mock_instance.aget_response = AsyncMock(return_value=" \n ") # strip() -> ""
with patch("reflector.llm.TreeSummarize", return_value=mock_instance):
result = await llm.get_response("Prompt", ["text"])
assert result == ""
assert mock_instance.aget_response.call_count == 1
class TestTextsInclusion:
"""Test that texts parameter is included in the prompt sent to astructured_predict"""

View File

@@ -49,6 +49,15 @@ async def test_retry_httpx(httpx_mock):
)
@pytest.mark.asyncio
async def test_retry_402_stops_by_default(httpx_mock):
"""402 (payment required / no credits) is in default retry_httpx_status_stop — do not retry."""
httpx_mock.add_response(status_code=402, json={"error": "insufficient_credits"})
async with httpx.AsyncClient() as client:
with pytest.raises(RetryHTTPException):
await retry(client.get)("https://test_url", retry_timeout=5)
@pytest.mark.asyncio
async def test_retry_normal():
left = 3

View File

@@ -231,3 +231,81 @@ async def test_dailyco_recording_uses_multitrack_pipeline(client):
{"s3_key": k} for k in track_keys
]
mock_file_pipeline.delay.assert_not_called()
@pytest.mark.usefixtures("setup_database")
@pytest.mark.asyncio
async def test_reprocess_error_transcript_passes_force(client):
"""When transcript status is 'error', reprocess passes force=True to start fresh workflow."""
from datetime import datetime, timezone
from reflector.db.recordings import Recording, recordings_controller
from reflector.db.rooms import rooms_controller
from reflector.db.transcripts import transcripts_controller
room = await rooms_controller.add(
name="test-room",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
)
transcript = await transcripts_controller.add(
"",
source_kind="room",
source_language="en",
target_language="en",
user_id="test-user",
share_mode="public",
room_id=room.id,
)
track_keys = ["recordings/test-room/track1.webm"]
recording = await recordings_controller.create(
Recording(
bucket_name="daily-bucket",
object_key="recordings/test-room",
meeting_id="test-meeting",
track_keys=track_keys,
recorded_at=datetime.now(timezone.utc),
)
)
await transcripts_controller.update(
transcript,
{
"recording_id": recording.id,
"status": "error",
"workflow_run_id": "old-failed-run",
},
)
with (
patch(
"reflector.services.transcript_process.task_is_scheduled_or_active"
) as mock_celery,
patch(
"reflector.services.transcript_process.HatchetClientManager"
) as mock_hatchet,
patch(
"reflector.views.transcripts_process.dispatch_transcript_processing",
new_callable=AsyncMock,
) as mock_dispatch,
):
mock_celery.return_value = False
from hatchet_sdk.clients.rest.models import V1TaskStatus
mock_hatchet.get_workflow_run_status = AsyncMock(
return_value=V1TaskStatus.FAILED
)
response = await client.post(f"/transcripts/{transcript.id}/process")
assert response.status_code == 200
mock_dispatch.assert_called_once()
assert mock_dispatch.call_args.kwargs["force"] is True