mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-03-25 08:26:48 +00:00
Foundation for DAG progress reporting to frontend. Ported topo sort and task extraction from render_hatchet_run.py (Zulip worktree) to produce structured Pydantic models instead of markdown.
390 lines
13 KiB
Python
390 lines
13 KiB
Python
"""Tests for DAG progress models and transform function.
|
|
|
|
Tests the extract_dag_tasks function that converts Hatchet V1WorkflowRunDetails
|
|
into structured DagTask list for frontend consumption.
|
|
"""
|
|
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import MagicMock
|
|
|
|
from reflector.hatchet.dag_progress import (
|
|
DagStatusData,
|
|
DagTask,
|
|
DagTaskStatus,
|
|
extract_dag_tasks,
|
|
)
|
|
|
|
|
|
def _make_shape_item(
|
|
step_id: str,
|
|
task_name: str,
|
|
children_step_ids: list[str] | None = None,
|
|
) -> MagicMock:
|
|
"""Create a mock WorkflowRunShapeItemForWorkflowRunDetails."""
|
|
item = MagicMock()
|
|
item.step_id = step_id
|
|
item.task_name = task_name
|
|
item.children_step_ids = children_step_ids or []
|
|
return item
|
|
|
|
|
|
def _make_task_summary(
|
|
step_id: str,
|
|
status: str = "QUEUED",
|
|
started_at: datetime | None = None,
|
|
finished_at: datetime | None = None,
|
|
duration: int | None = None,
|
|
error_message: str | None = None,
|
|
task_external_id: str | None = None,
|
|
num_spawned_children: int | None = None,
|
|
children: list | None = None,
|
|
) -> MagicMock:
|
|
"""Create a mock V1TaskSummary."""
|
|
from hatchet_sdk.clients.rest.models import V1TaskStatus
|
|
|
|
task = MagicMock()
|
|
task.step_id = step_id
|
|
task.status = V1TaskStatus(status)
|
|
task.started_at = started_at
|
|
task.finished_at = finished_at
|
|
task.duration = duration
|
|
task.error_message = error_message
|
|
task.task_external_id = task_external_id or f"ext-{step_id}"
|
|
task.num_spawned_children = num_spawned_children
|
|
task.children = children or []
|
|
return task
|
|
|
|
|
|
def _make_details(
|
|
shape: list,
|
|
tasks: list,
|
|
run_id: str = "test-run-id",
|
|
) -> MagicMock:
|
|
"""Create a mock V1WorkflowRunDetails."""
|
|
details = MagicMock()
|
|
details.shape = shape
|
|
details.tasks = tasks
|
|
details.task_events = []
|
|
details.run = MagicMock()
|
|
details.run.metadata = MagicMock()
|
|
details.run.metadata.id = run_id
|
|
return details
|
|
|
|
|
|
class TestExtractDagTasksBasic:
|
|
"""Test basic extraction of DAG tasks from workflow run details."""
|
|
|
|
def test_empty_shape_returns_empty_list(self):
|
|
details = _make_details(shape=[], tasks=[])
|
|
result = extract_dag_tasks(details)
|
|
assert result == []
|
|
|
|
def test_single_task_queued(self):
|
|
shape = [_make_shape_item("s1", "get_recording")]
|
|
tasks = [_make_task_summary("s1", status="QUEUED")]
|
|
details = _make_details(shape, tasks)
|
|
|
|
result = extract_dag_tasks(details)
|
|
|
|
assert len(result) == 1
|
|
assert result[0].name == "get_recording"
|
|
assert result[0].status == DagTaskStatus.QUEUED
|
|
assert result[0].parents == []
|
|
assert result[0].started_at is None
|
|
assert result[0].finished_at is None
|
|
assert result[0].duration_seconds is None
|
|
assert result[0].error is None
|
|
assert result[0].children_total is None
|
|
assert result[0].children_completed is None
|
|
assert result[0].progress_pct is None
|
|
|
|
def test_completed_task_with_duration(self):
|
|
now = datetime.now(timezone.utc)
|
|
shape = [_make_shape_item("s1", "get_recording")]
|
|
tasks = [
|
|
_make_task_summary(
|
|
"s1",
|
|
status="COMPLETED",
|
|
started_at=now,
|
|
finished_at=now,
|
|
duration=1500, # milliseconds
|
|
)
|
|
]
|
|
details = _make_details(shape, tasks)
|
|
|
|
result = extract_dag_tasks(details)
|
|
|
|
assert result[0].status == DagTaskStatus.COMPLETED
|
|
assert result[0].duration_seconds == 1.5
|
|
assert result[0].started_at == now
|
|
assert result[0].finished_at == now
|
|
|
|
def test_failed_task_with_error(self):
|
|
shape = [_make_shape_item("s1", "get_recording")]
|
|
tasks = [
|
|
_make_task_summary(
|
|
"s1",
|
|
status="FAILED",
|
|
error_message="Traceback (most recent call last):\n File something\nConnectionError: connection refused",
|
|
)
|
|
]
|
|
details = _make_details(shape, tasks)
|
|
|
|
result = extract_dag_tasks(details)
|
|
|
|
assert result[0].status == DagTaskStatus.FAILED
|
|
assert result[0].error == "ConnectionError: connection refused"
|
|
|
|
def test_running_task(self):
|
|
now = datetime.now(timezone.utc)
|
|
shape = [_make_shape_item("s1", "mixdown_tracks")]
|
|
tasks = [
|
|
_make_task_summary(
|
|
"s1",
|
|
status="RUNNING",
|
|
started_at=now,
|
|
duration=5000,
|
|
)
|
|
]
|
|
details = _make_details(shape, tasks)
|
|
|
|
result = extract_dag_tasks(details)
|
|
|
|
assert result[0].status == DagTaskStatus.RUNNING
|
|
assert result[0].started_at == now
|
|
assert result[0].duration_seconds == 5.0
|
|
|
|
def test_cancelled_task(self):
|
|
shape = [_make_shape_item("s1", "post_zulip")]
|
|
tasks = [_make_task_summary("s1", status="CANCELLED")]
|
|
details = _make_details(shape, tasks)
|
|
|
|
result = extract_dag_tasks(details)
|
|
|
|
assert result[0].status == DagTaskStatus.CANCELLED
|
|
|
|
|
|
class TestExtractDagTasksTopology:
|
|
"""Test topological ordering and parent extraction."""
|
|
|
|
def test_linear_chain_parents(self):
|
|
"""A -> B -> C should produce correct parents."""
|
|
shape = [
|
|
_make_shape_item("s1", "get_recording", children_step_ids=["s2"]),
|
|
_make_shape_item("s2", "get_participants", children_step_ids=["s3"]),
|
|
_make_shape_item("s3", "process_tracks"),
|
|
]
|
|
tasks = [
|
|
_make_task_summary("s1", status="COMPLETED"),
|
|
_make_task_summary("s2", status="COMPLETED"),
|
|
_make_task_summary("s3", status="QUEUED"),
|
|
]
|
|
details = _make_details(shape, tasks)
|
|
|
|
result = extract_dag_tasks(details)
|
|
|
|
assert [t.name for t in result] == [
|
|
"get_recording",
|
|
"get_participants",
|
|
"process_tracks",
|
|
]
|
|
assert result[0].parents == []
|
|
assert result[1].parents == ["get_recording"]
|
|
assert result[2].parents == ["get_participants"]
|
|
|
|
def test_diamond_dag(self):
|
|
"""
|
|
A -> B, A -> C, B -> D, C -> D
|
|
D should have parents [B, C] (or [C, B] depending on sort).
|
|
"""
|
|
shape = [
|
|
_make_shape_item("s1", "get_recording", children_step_ids=["s2", "s3"]),
|
|
_make_shape_item("s2", "mixdown_tracks", children_step_ids=["s4"]),
|
|
_make_shape_item("s3", "detect_topics", children_step_ids=["s4"]),
|
|
_make_shape_item("s4", "finalize"),
|
|
]
|
|
tasks = [
|
|
_make_task_summary("s1", status="COMPLETED"),
|
|
_make_task_summary("s2", status="RUNNING"),
|
|
_make_task_summary("s3", status="RUNNING"),
|
|
_make_task_summary("s4", status="QUEUED"),
|
|
]
|
|
details = _make_details(shape, tasks)
|
|
|
|
result = extract_dag_tasks(details)
|
|
|
|
# Topological: s1 first, s2/s3 in some order, s4 last
|
|
assert result[0].name == "get_recording"
|
|
assert result[-1].name == "finalize"
|
|
finalize = result[-1]
|
|
assert set(finalize.parents) == {"mixdown_tracks", "detect_topics"}
|
|
|
|
def test_topological_order_is_stable(self):
|
|
"""Verify deterministic ordering (sorted queue in Kahn's)."""
|
|
shape = [
|
|
_make_shape_item("s_c", "task_c"),
|
|
_make_shape_item("s_a", "task_a", children_step_ids=["s_c"]),
|
|
_make_shape_item("s_b", "task_b", children_step_ids=["s_c"]),
|
|
]
|
|
tasks = [
|
|
_make_task_summary("s_c", status="QUEUED"),
|
|
_make_task_summary("s_a", status="COMPLETED"),
|
|
_make_task_summary("s_b", status="COMPLETED"),
|
|
]
|
|
details = _make_details(shape, tasks)
|
|
|
|
result = extract_dag_tasks(details)
|
|
|
|
# s_a and s_b both roots with in-degree 0; sorted alphabetically by step_id
|
|
names = [t.name for t in result]
|
|
assert names[-1] == "task_c"
|
|
# First two should be task_a, task_b (sorted by step_id: s_a < s_b)
|
|
assert names[0] == "task_a"
|
|
assert names[1] == "task_b"
|
|
|
|
|
|
class TestExtractDagTasksFanOut:
|
|
"""Test fan-out tasks with spawned children."""
|
|
|
|
def test_fan_out_children_counts(self):
|
|
from hatchet_sdk.clients.rest.models import V1TaskStatus
|
|
|
|
child_mocks = []
|
|
for status in ["COMPLETED", "COMPLETED", "RUNNING", "QUEUED"]:
|
|
child = MagicMock()
|
|
child.status = V1TaskStatus(status)
|
|
child_mocks.append(child)
|
|
|
|
shape = [_make_shape_item("s1", "process_tracks")]
|
|
tasks = [
|
|
_make_task_summary(
|
|
"s1",
|
|
status="RUNNING",
|
|
num_spawned_children=4,
|
|
children=child_mocks,
|
|
)
|
|
]
|
|
details = _make_details(shape, tasks)
|
|
|
|
result = extract_dag_tasks(details)
|
|
|
|
assert result[0].children_total == 4
|
|
assert result[0].children_completed == 2
|
|
|
|
def test_no_children_when_no_spawn(self):
|
|
shape = [_make_shape_item("s1", "get_recording")]
|
|
tasks = [
|
|
_make_task_summary("s1", status="COMPLETED", num_spawned_children=None)
|
|
]
|
|
details = _make_details(shape, tasks)
|
|
|
|
result = extract_dag_tasks(details)
|
|
|
|
assert result[0].children_total is None
|
|
assert result[0].children_completed is None
|
|
|
|
def test_zero_spawned_children(self):
|
|
shape = [_make_shape_item("s1", "process_tracks")]
|
|
tasks = [_make_task_summary("s1", status="COMPLETED", num_spawned_children=0)]
|
|
details = _make_details(shape, tasks)
|
|
|
|
result = extract_dag_tasks(details)
|
|
|
|
assert result[0].children_total is None
|
|
assert result[0].children_completed is None
|
|
|
|
|
|
class TestExtractDagTasksErrorExtraction:
|
|
"""Test error message extraction logic."""
|
|
|
|
def test_simple_error(self):
|
|
shape = [_make_shape_item("s1", "mixdown_tracks")]
|
|
tasks = [
|
|
_make_task_summary(
|
|
"s1", status="FAILED", error_message="ValueError: no tracks"
|
|
)
|
|
]
|
|
details = _make_details(shape, tasks)
|
|
|
|
result = extract_dag_tasks(details)
|
|
assert result[0].error == "ValueError: no tracks"
|
|
|
|
def test_traceback_extracts_meaningful_line(self):
|
|
error = (
|
|
"Traceback (most recent call last):\n"
|
|
' File "/app/something.py", line 42\n'
|
|
"RuntimeError: out of memory"
|
|
)
|
|
shape = [_make_shape_item("s1", "mixdown_tracks")]
|
|
tasks = [_make_task_summary("s1", status="FAILED", error_message=error)]
|
|
details = _make_details(shape, tasks)
|
|
|
|
result = extract_dag_tasks(details)
|
|
assert result[0].error == "RuntimeError: out of memory"
|
|
|
|
def test_no_error_when_none(self):
|
|
shape = [_make_shape_item("s1", "get_recording")]
|
|
tasks = [_make_task_summary("s1", status="COMPLETED", error_message=None)]
|
|
details = _make_details(shape, tasks)
|
|
|
|
result = extract_dag_tasks(details)
|
|
assert result[0].error is None
|
|
|
|
def test_empty_error_message(self):
|
|
shape = [_make_shape_item("s1", "get_recording")]
|
|
tasks = [_make_task_summary("s1", status="FAILED", error_message="")]
|
|
details = _make_details(shape, tasks)
|
|
|
|
result = extract_dag_tasks(details)
|
|
assert result[0].error is None
|
|
|
|
|
|
class TestExtractDagTasksMissingData:
|
|
"""Test edge cases with missing task data."""
|
|
|
|
def test_shape_without_matching_task(self):
|
|
"""Shape has a step but tasks list doesn't contain it."""
|
|
shape = [_make_shape_item("s1", "get_recording")]
|
|
tasks = [] # No matching task
|
|
details = _make_details(shape, tasks)
|
|
|
|
result = extract_dag_tasks(details)
|
|
|
|
assert len(result) == 1
|
|
assert result[0].name == "get_recording"
|
|
assert result[0].status == DagTaskStatus.QUEUED # default when no task data
|
|
assert result[0].started_at is None
|
|
|
|
def test_none_shape_returns_empty(self):
|
|
details = _make_details(shape=[], tasks=[])
|
|
details.shape = None
|
|
|
|
result = extract_dag_tasks(details)
|
|
assert result == []
|
|
|
|
|
|
class TestDagStatusData:
|
|
"""Test DagStatusData model serialization."""
|
|
|
|
def test_serialization(self):
|
|
task = DagTask(
|
|
name="get_recording",
|
|
status=DagTaskStatus.COMPLETED,
|
|
started_at=datetime(2025, 1, 1, tzinfo=timezone.utc),
|
|
finished_at=datetime(2025, 1, 1, 0, 0, 1, tzinfo=timezone.utc),
|
|
duration_seconds=1.0,
|
|
parents=[],
|
|
error=None,
|
|
children_total=None,
|
|
children_completed=None,
|
|
progress_pct=None,
|
|
)
|
|
data = DagStatusData(workflow_run_id="test-123", tasks=[task])
|
|
dumped = data.model_dump(mode="json")
|
|
|
|
assert dumped["workflow_run_id"] == "test-123"
|
|
assert len(dumped["tasks"]) == 1
|
|
assert dumped["tasks"][0]["name"] == "get_recording"
|
|
assert dumped["tasks"][0]["status"] == "completed"
|
|
assert dumped["tasks"][0]["duration_seconds"] == 1.0
|