mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-03-25 16:36:46 +00:00
feat: add DagTask models and extract_dag_tasks transform
Foundation for DAG progress reporting to frontend. Ported topo sort and task extraction from render_hatchet_run.py (Zulip worktree) to produce structured Pydantic models instead of markdown.
This commit is contained in:
189
server/reflector/hatchet/dag_progress.py
Normal file
189
server/reflector/hatchet/dag_progress.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
DAG Progress Reporting — models and transform.
|
||||
|
||||
Converts Hatchet V1WorkflowRunDetails into structured DagTask list
|
||||
for frontend WebSocket/REST consumption.
|
||||
|
||||
Ported from render_hatchet_run.py (feat-dag-zulip) which renders markdown;
|
||||
this module produces structured Pydantic models instead.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
|
||||
from hatchet_sdk.clients.rest.models import (
|
||||
V1TaskStatus,
|
||||
V1WorkflowRunDetails,
|
||||
WorkflowRunShapeItemForWorkflowRunDetails,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class DagTaskStatus(StrEnum):
|
||||
QUEUED = "queued"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
# Translation table from Hatchet's task status enum to the frontend-facing
# DagTaskStatus. The mapping is spelled out member by member (rather than
# derived from the enum names) so that a status added on either side has to
# be handled here deliberately. extract_dag_tasks falls back to QUEUED for
# any Hatchet status that is not listed.
_HATCHET_TO_DAG_STATUS: dict[V1TaskStatus, DagTaskStatus] = {
    V1TaskStatus.QUEUED: DagTaskStatus.QUEUED,
    V1TaskStatus.RUNNING: DagTaskStatus.RUNNING,
    V1TaskStatus.COMPLETED: DagTaskStatus.COMPLETED,
    V1TaskStatus.FAILED: DagTaskStatus.FAILED,
    V1TaskStatus.CANCELLED: DagTaskStatus.CANCELLED,
}
|
||||
|
||||
|
||||
class DagTask(BaseModel):
    """One node of the workflow DAG, in the form sent to the frontend.

    Every field is required; fields typed ``... | None`` must be passed
    explicitly as ``None`` when there is no value.
    """

    # Task name taken from the workflow shape item (not the step id).
    name: str
    status: DagTaskStatus

    # Timestamps copied from the Hatchet task summary; None when the step has
    # no task summary or the summary carries no timestamp.
    started_at: datetime | None
    finished_at: datetime | None

    # Hatchet's reported duration divided by 1000; None when not reported.
    duration_seconds: float | None

    # Names (not step ids) of the tasks that list this task as a child.
    parents: list[str]

    # One-line summary of the task's error message, or None when there is none.
    error: str | None

    # Fan-out bookkeeping: how many children the task spawned and how many of
    # those are completed. Both are None when no children were spawned.
    children_total: int | None
    children_completed: int | None

    # NOTE(review): extract_dag_tasks always sets this to None; presumably it
    # is filled in by another component — confirm.
    progress_pct: float | None
|
||||
|
||||
|
||||
class DagStatusData(BaseModel):
    """Payload pairing a workflow run id with its list of DAG tasks."""

    # Identifier of the workflow run the tasks belong to.
    workflow_run_id: str
    # Tasks of that run; extract_dag_tasks produces them in topological order.
    tasks: list[DagTask]
|
||||
|
||||
|
||||
def _topo_sort(
|
||||
shape: list[WorkflowRunShapeItemForWorkflowRunDetails],
|
||||
) -> list[str]:
|
||||
"""Topological sort of step_ids from shape DAG (Kahn's algorithm).
|
||||
|
||||
Ported from render_hatchet_run.py.
|
||||
"""
|
||||
step_ids = {s.step_id for s in shape}
|
||||
children_map: dict[str, list[str]] = {}
|
||||
in_degree: dict[str, int] = {sid: 0 for sid in step_ids}
|
||||
|
||||
for s in shape:
|
||||
children = [c for c in (s.children_step_ids or []) if c in step_ids]
|
||||
children_map[s.step_id] = children
|
||||
for c in children:
|
||||
in_degree[c] += 1
|
||||
|
||||
queue = sorted(sid for sid, deg in in_degree.items() if deg == 0)
|
||||
result: list[str] = []
|
||||
while queue:
|
||||
node = queue.pop(0)
|
||||
result.append(node)
|
||||
for c in children_map.get(node, []):
|
||||
in_degree[c] -= 1
|
||||
if in_degree[c] == 0:
|
||||
queue.append(c)
|
||||
queue.sort()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _extract_error_summary(error_message: str | None) -> str | None:
|
||||
"""Extract first meaningful line from error message, skipping traceback frames."""
|
||||
if not error_message or not error_message.strip():
|
||||
return None
|
||||
|
||||
err_lines = error_message.strip().split("\n")
|
||||
err_summary = err_lines[0]
|
||||
for line in err_lines:
|
||||
stripped = line.strip()
|
||||
if stripped and not stripped.startswith(("Traceback", "File ", "{", ")")):
|
||||
err_summary = stripped
|
||||
return err_summary
|
||||
|
||||
|
||||
def extract_dag_tasks(details: V1WorkflowRunDetails) -> list[DagTask]:
    """Extract a structured DagTask list from Hatchet workflow run details.

    The DAG structure comes from ``details.shape``; per-task runtime data
    comes from ``details.tasks`` and is joined to the shape by ``step_id``.

    Args:
        details: Workflow run details. ``shape`` and ``tasks`` may be ``None``
            or empty.

    Returns:
        One DagTask per step that ``_topo_sort`` yields, in that (topological)
        order. Returns an empty list when there is no shape. A step without a
        matching task summary is reported as QUEUED with no timestamps,
        duration, error or children counts. A Hatchet status without an entry
        in ``_HATCHET_TO_DAG_STATUS`` is also reported as QUEUED.
        ``progress_pct`` is always ``None``.
    """
    shape = details.shape or []
    tasks = details.tasks or []

    if not shape:
        return []

    # Imported here only for the local annotation below; the module-level
    # import block does not provide this name.
    from hatchet_sdk.clients.rest.models import V1TaskSummary

    step_to_name: dict[str, str] = {s.step_id: s.task_name for s in shape}

    # Reverse the shape's parent -> children edges into child -> parent names.
    # Children that are not part of the shape are ignored.
    parents_by_step: dict[str, list[str]] = {s.step_id: [] for s in shape}
    for s in shape:
        for child_id in s.children_step_ids or []:
            if child_id in parents_by_step:
                parents_by_step[child_id].append(step_to_name[s.step_id])

    # Join task summaries to the shape by step_id. If several summaries share
    # a step_id, the one appearing last in the list is used.
    task_by_step: dict[str, V1TaskSummary] = {}
    for t in tasks:
        if t.step_id and t.step_id in step_to_name:
            task_by_step[t.step_id] = t

    result: list[DagTask] = []
    for step_id in _topo_sort(shape):
        name = step_to_name[step_id]
        parents = parents_by_step.get(step_id, [])
        t = task_by_step.get(step_id)

        if not t:
            # The step is known from the shape but has no task summary.
            result.append(
                DagTask(
                    name=name,
                    status=DagTaskStatus.QUEUED,
                    started_at=None,
                    finished_at=None,
                    duration_seconds=None,
                    parents=parents,
                    error=None,
                    children_total=None,
                    children_completed=None,
                    progress_pct=None,
                )
            )
            continue

        # The duration is divided by 1000 to obtain seconds.
        duration_seconds: float | None = None
        if t.duration is not None:
            duration_seconds = t.duration / 1000.0

        # Fan-out children: only reported when at least one child was spawned
        # (None and 0 both leave the counters unset).
        children_total: int | None = None
        children_completed: int | None = None
        if (t.num_spawned_children or 0) > 0:
            children_total = t.num_spawned_children
            children_completed = sum(
                1 for c in (t.children or []) if c.status == V1TaskStatus.COMPLETED
            )

        result.append(
            DagTask(
                name=name,
                status=_HATCHET_TO_DAG_STATUS.get(t.status, DagTaskStatus.QUEUED),
                started_at=t.started_at,
                finished_at=t.finished_at,
                duration_seconds=duration_seconds,
                parents=parents,
                error=_extract_error_summary(t.error_message),
                children_total=children_total,
                children_completed=children_completed,
                progress_pct=None,
            )
        )

    return result
|
||||
389
server/tests/test_dag_progress.py
Normal file
389
server/tests/test_dag_progress.py
Normal file
@@ -0,0 +1,389 @@
|
||||
"""Tests for DAG progress models and transform function.
|
||||
|
||||
Tests the extract_dag_tasks function that converts Hatchet V1WorkflowRunDetails
|
||||
into structured DagTask list for frontend consumption.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from reflector.hatchet.dag_progress import (
|
||||
DagStatusData,
|
||||
DagTask,
|
||||
DagTaskStatus,
|
||||
extract_dag_tasks,
|
||||
)
|
||||
|
||||
|
||||
def _make_shape_item(
|
||||
step_id: str,
|
||||
task_name: str,
|
||||
children_step_ids: list[str] | None = None,
|
||||
) -> MagicMock:
|
||||
"""Create a mock WorkflowRunShapeItemForWorkflowRunDetails."""
|
||||
item = MagicMock()
|
||||
item.step_id = step_id
|
||||
item.task_name = task_name
|
||||
item.children_step_ids = children_step_ids or []
|
||||
return item
|
||||
|
||||
|
||||
def _make_task_summary(
    step_id: str,
    status: str = "QUEUED",
    started_at: datetime | None = None,
    finished_at: datetime | None = None,
    duration: int | None = None,
    error_message: str | None = None,
    task_external_id: str | None = None,
    num_spawned_children: int | None = None,
    children: list | None = None,
) -> MagicMock:
    """Build a MagicMock standing in for a V1TaskSummary.

    ``status`` is converted to a V1TaskStatus member by value.
    ``task_external_id`` defaults to ``"ext-<step_id>"`` and ``children``
    defaults to an empty list.
    """
    from hatchet_sdk.clients.rest.models import V1TaskStatus

    # MagicMock turns keyword arguments into attributes on the mock.
    return MagicMock(
        step_id=step_id,
        status=V1TaskStatus(status),
        started_at=started_at,
        finished_at=finished_at,
        duration=duration,
        error_message=error_message,
        task_external_id=task_external_id or f"ext-{step_id}",
        num_spawned_children=num_spawned_children,
        children=children or [],
    )
|
||||
|
||||
|
||||
def _make_details(
|
||||
shape: list,
|
||||
tasks: list,
|
||||
run_id: str = "test-run-id",
|
||||
) -> MagicMock:
|
||||
"""Create a mock V1WorkflowRunDetails."""
|
||||
details = MagicMock()
|
||||
details.shape = shape
|
||||
details.tasks = tasks
|
||||
details.task_events = []
|
||||
details.run = MagicMock()
|
||||
details.run.metadata = MagicMock()
|
||||
details.run.metadata.id = run_id
|
||||
return details
|
||||
|
||||
|
||||
class TestExtractDagTasksBasic:
    """Single-step extraction: status mapping, timestamps, duration and error."""

    def test_empty_shape_returns_empty_list(self):
        assert extract_dag_tasks(_make_details(shape=[], tasks=[])) == []

    def test_single_task_queued(self):
        details = _make_details(
            [_make_shape_item("s1", "get_recording")],
            [_make_task_summary("s1", status="QUEUED")],
        )

        result = extract_dag_tasks(details)

        assert len(result) == 1
        task = result[0]
        assert task.name == "get_recording"
        assert task.status == DagTaskStatus.QUEUED
        assert task.parents == []
        assert task.started_at is None
        assert task.finished_at is None
        assert task.duration_seconds is None
        assert task.error is None
        assert task.children_total is None
        assert task.children_completed is None
        assert task.progress_pct is None

    def test_completed_task_with_duration(self):
        now = datetime.now(timezone.utc)
        summary = _make_task_summary(
            "s1",
            status="COMPLETED",
            started_at=now,
            finished_at=now,
            duration=1500,  # milliseconds
        )
        details = _make_details([_make_shape_item("s1", "get_recording")], [summary])

        task = extract_dag_tasks(details)[0]

        assert task.status == DagTaskStatus.COMPLETED
        assert task.duration_seconds == 1.5
        assert task.started_at == now
        assert task.finished_at == now

    def test_failed_task_with_error(self):
        summary = _make_task_summary(
            "s1",
            status="FAILED",
            error_message="Traceback (most recent call last):\n File something\nConnectionError: connection refused",
        )
        details = _make_details([_make_shape_item("s1", "get_recording")], [summary])

        task = extract_dag_tasks(details)[0]

        assert task.status == DagTaskStatus.FAILED
        assert task.error == "ConnectionError: connection refused"

    def test_running_task(self):
        now = datetime.now(timezone.utc)
        summary = _make_task_summary(
            "s1",
            status="RUNNING",
            started_at=now,
            duration=5000,
        )
        details = _make_details([_make_shape_item("s1", "mixdown_tracks")], [summary])

        task = extract_dag_tasks(details)[0]

        assert task.status == DagTaskStatus.RUNNING
        assert task.started_at == now
        assert task.duration_seconds == 5.0

    def test_cancelled_task(self):
        details = _make_details(
            [_make_shape_item("s1", "post_zulip")],
            [_make_task_summary("s1", status="CANCELLED")],
        )

        task = extract_dag_tasks(details)[0]

        assert task.status == DagTaskStatus.CANCELLED
|
||||
|
||||
|
||||
class TestExtractDagTasksTopology:
    """Ordering of the extracted tasks and derivation of parent names."""

    def test_linear_chain_parents(self):
        """In a chain A -> B -> C each task's only parent is its predecessor."""
        shape = [
            _make_shape_item("s1", "get_recording", children_step_ids=["s2"]),
            _make_shape_item("s2", "get_participants", children_step_ids=["s3"]),
            _make_shape_item("s3", "process_tracks"),
        ]
        tasks = [
            _make_task_summary("s1", status="COMPLETED"),
            _make_task_summary("s2", status="COMPLETED"),
            _make_task_summary("s3", status="QUEUED"),
        ]

        result = extract_dag_tasks(_make_details(shape, tasks))

        assert [t.name for t in result] == [
            "get_recording",
            "get_participants",
            "process_tracks",
        ]
        first, second, third = result[0], result[1], result[2]
        assert first.parents == []
        assert second.parents == ["get_recording"]
        assert third.parents == ["get_participants"]

    def test_diamond_dag(self):
        """Diamond A -> B, A -> C, B -> D, C -> D.

        D must list both B and C as parents; their relative order is not
        asserted.
        """
        shape = [
            _make_shape_item("s1", "get_recording", children_step_ids=["s2", "s3"]),
            _make_shape_item("s2", "mixdown_tracks", children_step_ids=["s4"]),
            _make_shape_item("s3", "detect_topics", children_step_ids=["s4"]),
            _make_shape_item("s4", "finalize"),
        ]
        tasks = [
            _make_task_summary("s1", status="COMPLETED"),
            _make_task_summary("s2", status="RUNNING"),
            _make_task_summary("s3", status="RUNNING"),
            _make_task_summary("s4", status="QUEUED"),
        ]

        result = extract_dag_tasks(_make_details(shape, tasks))

        # Topological: s1 first, s2/s3 in some order, s4 last
        assert result[0].name == "get_recording"
        finalize = result[-1]
        assert finalize.name == "finalize"
        assert set(finalize.parents) == {"mixdown_tracks", "detect_topics"}

    def test_topological_order_is_stable(self):
        """Steps that are ready at the same time come out sorted by step_id."""
        shape = [
            _make_shape_item("s_c", "task_c"),
            _make_shape_item("s_a", "task_a", children_step_ids=["s_c"]),
            _make_shape_item("s_b", "task_b", children_step_ids=["s_c"]),
        ]
        tasks = [
            _make_task_summary("s_c", status="QUEUED"),
            _make_task_summary("s_a", status="COMPLETED"),
            _make_task_summary("s_b", status="COMPLETED"),
        ]

        names = [t.name for t in extract_dag_tasks(_make_details(shape, tasks))]

        # s_a and s_b are both roots (in-degree 0) and s_a < s_b, so task_a
        # precedes task_b; task_c depends on both and therefore comes last.
        assert names[-1] == "task_c"
        assert names[0] == "task_a"
        assert names[1] == "task_b"
|
||||
|
||||
|
||||
class TestExtractDagTasksFanOut:
    """Children counters reported for tasks that spawn child tasks."""

    def test_fan_out_children_counts(self):
        from hatchet_sdk.clients.rest.models import V1TaskStatus

        # Four children, two of which are completed.
        child_mocks = []
        for child_status in ("COMPLETED", "COMPLETED", "RUNNING", "QUEUED"):
            child = MagicMock()
            child.status = V1TaskStatus(child_status)
            child_mocks.append(child)

        summary = _make_task_summary(
            "s1",
            status="RUNNING",
            num_spawned_children=4,
            children=child_mocks,
        )
        details = _make_details([_make_shape_item("s1", "process_tracks")], [summary])

        task = extract_dag_tasks(details)[0]

        assert task.children_total == 4
        assert task.children_completed == 2

    def test_no_children_when_no_spawn(self):
        summary = _make_task_summary(
            "s1", status="COMPLETED", num_spawned_children=None
        )
        details = _make_details([_make_shape_item("s1", "get_recording")], [summary])

        task = extract_dag_tasks(details)[0]

        assert task.children_total is None
        assert task.children_completed is None

    def test_zero_spawned_children(self):
        summary = _make_task_summary("s1", status="COMPLETED", num_spawned_children=0)
        details = _make_details([_make_shape_item("s1", "process_tracks")], [summary])

        task = extract_dag_tasks(details)[0]

        assert task.children_total is None
        assert task.children_completed is None
|
||||
|
||||
|
||||
class TestExtractDagTasksErrorExtraction:
    """How a task's raw error message is condensed into DagTask.error."""

    def test_simple_error(self):
        summary = _make_task_summary(
            "s1", status="FAILED", error_message="ValueError: no tracks"
        )
        details = _make_details([_make_shape_item("s1", "mixdown_tracks")], [summary])

        assert extract_dag_tasks(details)[0].error == "ValueError: no tracks"

    def test_traceback_extracts_meaningful_line(self):
        error = (
            "Traceback (most recent call last):\n"
            ' File "/app/something.py", line 42\n'
            "RuntimeError: out of memory"
        )
        summary = _make_task_summary("s1", status="FAILED", error_message=error)
        details = _make_details([_make_shape_item("s1", "mixdown_tracks")], [summary])

        assert extract_dag_tasks(details)[0].error == "RuntimeError: out of memory"

    def test_no_error_when_none(self):
        summary = _make_task_summary("s1", status="COMPLETED", error_message=None)
        details = _make_details([_make_shape_item("s1", "get_recording")], [summary])

        assert extract_dag_tasks(details)[0].error is None

    def test_empty_error_message(self):
        summary = _make_task_summary("s1", status="FAILED", error_message="")
        details = _make_details([_make_shape_item("s1", "get_recording")], [summary])

        assert extract_dag_tasks(details)[0].error is None
|
||||
|
||||
|
||||
class TestExtractDagTasksMissingData:
    """Behavior when the shape or the task list is incomplete."""

    def test_shape_without_matching_task(self):
        """A step present in the shape but absent from tasks is still reported."""
        details = _make_details(
            [_make_shape_item("s1", "get_recording")],
            [],  # No matching task
        )

        result = extract_dag_tasks(details)

        assert len(result) == 1
        placeholder = result[0]
        assert placeholder.name == "get_recording"
        # default when no task data
        assert placeholder.status == DagTaskStatus.QUEUED
        assert placeholder.started_at is None

    def test_none_shape_returns_empty(self):
        details = _make_details(shape=[], tasks=[])
        details.shape = None

        assert extract_dag_tasks(details) == []
|
||||
|
||||
|
||||
class TestDagStatusData:
    """JSON-mode serialization of DagStatusData."""

    def test_serialization(self):
        started = datetime(2025, 1, 1, tzinfo=timezone.utc)
        finished = datetime(2025, 1, 1, 0, 0, 1, tzinfo=timezone.utc)
        task = DagTask(
            name="get_recording",
            status=DagTaskStatus.COMPLETED,
            started_at=started,
            finished_at=finished,
            duration_seconds=1.0,
            parents=[],
            error=None,
            children_total=None,
            children_completed=None,
            progress_pct=None,
        )

        dumped = DagStatusData(workflow_run_id="test-123", tasks=[task]).model_dump(
            mode="json"
        )

        assert dumped["workflow_run_id"] == "test-123"
        assert len(dumped["tasks"]) == 1
        dumped_task = dumped["tasks"][0]
        assert dumped_task["name"] == "get_recording"
        # The enum is serialized as its lowercase string value.
        assert dumped_task["status"] == "completed"
        assert dumped_task["duration_seconds"] == 1.0
|
||||
Reference in New Issue
Block a user