mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-04-22 05:05:18 +00:00
test: add with_error_handling decorator tests for broadcast integration
This commit is contained in:
181
server/tests/test_dag_progress_decorator.py
Normal file
181
server/tests/test_dag_progress_decorator.py
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
"""Tests for with_error_handling decorator integration with broadcast_dag_status.
|
||||||
|
|
||||||
|
The decorator wraps each pipeline task and calls broadcast_dag_status on both
|
||||||
|
success and failure paths. These tests verify that integration rather than
|
||||||
|
testing broadcast_dag_status in isolation (which test_dag_progress.py covers).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from reflector.hatchet.constants import TaskName
|
||||||
|
|
||||||
|
|
||||||
|
class TestWithErrorHandlingBroadcast:
|
||||||
|
"""Test with_error_handling decorator's integration with broadcast_dag_status."""
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _setup_hatchet_mock(self):
|
||||||
|
"""Set HatchetClientManager._instance to a mock to prevent real SDK init.
|
||||||
|
|
||||||
|
Module-level code in workflow files calls get_client() during import.
|
||||||
|
Setting _instance before import avoids ClientConfig validation.
|
||||||
|
"""
|
||||||
|
from reflector.hatchet.client import HatchetClientManager
|
||||||
|
|
||||||
|
original = HatchetClientManager._instance
|
||||||
|
HatchetClientManager._instance = MagicMock()
|
||||||
|
yield
|
||||||
|
HatchetClientManager._instance = original
|
||||||
|
|
||||||
|
def _make_input(self, transcript_id: str = "t-123") -> MagicMock:
|
||||||
|
"""Create a mock PipelineInput with transcript_id."""
|
||||||
|
inp = MagicMock()
|
||||||
|
inp.transcript_id = transcript_id
|
||||||
|
return inp
|
||||||
|
|
||||||
|
def _make_ctx(self, workflow_run_id: str = "wf-abc") -> MagicMock:
|
||||||
|
"""Create a mock Context with workflow_run_id."""
|
||||||
|
ctx = MagicMock()
|
||||||
|
ctx.workflow_run_id = workflow_run_id
|
||||||
|
return ctx
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_calls_broadcast_on_success(self):
|
||||||
|
"""Decorator calls broadcast_dag_status once when task succeeds."""
|
||||||
|
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||||
|
with_error_handling,
|
||||||
|
)
|
||||||
|
|
||||||
|
inner = AsyncMock(return_value="ok")
|
||||||
|
wrapped = with_error_handling(TaskName.GET_RECORDING)(inner)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"reflector.hatchet.dag_progress.broadcast_dag_status",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
) as mock_broadcast:
|
||||||
|
result = await wrapped(self._make_input(), self._make_ctx())
|
||||||
|
|
||||||
|
assert result == "ok"
|
||||||
|
mock_broadcast.assert_called_once_with("t-123", "wf-abc")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_calls_broadcast_on_failure(self):
|
||||||
|
"""Decorator calls broadcast_dag_status once when task raises."""
|
||||||
|
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||||
|
with_error_handling,
|
||||||
|
)
|
||||||
|
|
||||||
|
inner = AsyncMock(side_effect=RuntimeError("boom"))
|
||||||
|
wrapped = with_error_handling(TaskName.GET_RECORDING)(inner)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"reflector.hatchet.dag_progress.broadcast_dag_status",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
) as mock_broadcast,
|
||||||
|
patch(
|
||||||
|
"reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
with pytest.raises(RuntimeError, match="boom"):
|
||||||
|
await wrapped(self._make_input(), self._make_ctx())
|
||||||
|
|
||||||
|
mock_broadcast.assert_called_once_with("t-123", "wf-abc")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_swallows_broadcast_exception_on_success(self):
|
||||||
|
"""Broadcast failure does not crash the task on the success path."""
|
||||||
|
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||||
|
with_error_handling,
|
||||||
|
)
|
||||||
|
|
||||||
|
inner = AsyncMock(return_value="ok")
|
||||||
|
wrapped = with_error_handling(TaskName.GET_RECORDING)(inner)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"reflector.hatchet.dag_progress.broadcast_dag_status",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
side_effect=RuntimeError("broadcast exploded"),
|
||||||
|
):
|
||||||
|
result = await wrapped(self._make_input(), self._make_ctx())
|
||||||
|
|
||||||
|
assert result == "ok"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_swallows_broadcast_exception_on_failure(self):
|
||||||
|
"""Original task exception propagates even when broadcast also fails."""
|
||||||
|
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||||
|
with_error_handling,
|
||||||
|
)
|
||||||
|
|
||||||
|
inner = AsyncMock(side_effect=ValueError("original error"))
|
||||||
|
wrapped = with_error_handling(TaskName.GET_RECORDING)(inner)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"reflector.hatchet.dag_progress.broadcast_dag_status",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
side_effect=RuntimeError("broadcast exploded"),
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
with pytest.raises(ValueError, match="original error"):
|
||||||
|
await wrapped(self._make_input(), self._make_ctx())
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_calls_set_workflow_error_status_on_failure(self):
|
||||||
|
"""On task failure with set_error_status=True (default), calls set_workflow_error_status."""
|
||||||
|
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||||
|
with_error_handling,
|
||||||
|
)
|
||||||
|
|
||||||
|
inner = AsyncMock(side_effect=RuntimeError("boom"))
|
||||||
|
wrapped = with_error_handling(TaskName.GET_RECORDING)(inner)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"reflector.hatchet.dag_progress.broadcast_dag_status",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
) as mock_set_error,
|
||||||
|
):
|
||||||
|
with pytest.raises(RuntimeError, match="boom"):
|
||||||
|
await wrapped(self._make_input(), self._make_ctx())
|
||||||
|
|
||||||
|
mock_set_error.assert_called_once_with("t-123")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_set_workflow_error_status_when_disabled(self):
|
||||||
|
"""With set_error_status=False, set_workflow_error_status is NOT called on failure."""
|
||||||
|
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||||
|
with_error_handling,
|
||||||
|
)
|
||||||
|
|
||||||
|
inner = AsyncMock(side_effect=RuntimeError("boom"))
|
||||||
|
wrapped = with_error_handling(TaskName.GET_RECORDING, set_error_status=False)(
|
||||||
|
inner
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"reflector.hatchet.dag_progress.broadcast_dag_status",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
) as mock_set_error,
|
||||||
|
):
|
||||||
|
with pytest.raises(RuntimeError, match="boom"):
|
||||||
|
await wrapped(self._make_input(), self._make_ctx())
|
||||||
|
|
||||||
|
mock_set_error.assert_not_called()
|
||||||
Reference in New Issue
Block a user