mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-03-22 07:06:47 +00:00
960 lines
35 KiB
Python
960 lines
35 KiB
Python
"""Tests for DAG progress models and transform function.
|
|
|
|
Tests the extract_dag_tasks function that converts Hatchet V1WorkflowRunDetails
|
|
into structured DagTask list for frontend consumption.
|
|
"""
|
|
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from reflector.hatchet.constants import TaskName
|
|
from reflector.hatchet.dag_progress import (
|
|
DagStatusData,
|
|
DagTask,
|
|
DagTaskStatus,
|
|
extract_dag_tasks,
|
|
)
|
|
|
|
|
|
def _make_shape_item(
|
|
step_id: str,
|
|
task_name: str,
|
|
children_step_ids: list[str] | None = None,
|
|
) -> MagicMock:
|
|
"""Create a mock WorkflowRunShapeItemForWorkflowRunDetails."""
|
|
item = MagicMock()
|
|
item.step_id = step_id
|
|
item.task_name = task_name
|
|
item.children_step_ids = children_step_ids or []
|
|
return item
|
|
|
|
|
|
def _make_task_summary(
    step_id: str,
    status: str = "QUEUED",
    started_at: datetime | None = None,
    finished_at: datetime | None = None,
    duration: int | None = None,
    error_message: str | None = None,
    task_external_id: str | None = None,
    num_spawned_children: int | None = None,
    children: list | None = None,
) -> MagicMock:
    """Build a stand-in for a Hatchet ``V1TaskSummary``.

    ``status`` is a ``V1TaskStatus`` value (e.g. ``"QUEUED"``, ``"FAILED"``) and
    is converted to the enum. ``task_external_id`` falls back to
    ``"ext-<step_id>"`` and ``children`` to an empty list; every other field is
    stored exactly as given (``duration`` is in milliseconds in these tests).
    """
    from hatchet_sdk.clients.rest.models import V1TaskStatus

    summary = MagicMock()
    summary.step_id = step_id
    summary.status = V1TaskStatus(status)
    summary.started_at, summary.finished_at = started_at, finished_at
    summary.duration = duration
    summary.error_message = error_message
    summary.task_external_id = (
        task_external_id if task_external_id else f"ext-{step_id}"
    )
    summary.num_spawned_children = num_spawned_children
    summary.children = children if children else []
    return summary
|
|
|
|
|
|
def _make_details(
    shape: list,
    tasks: list,
    run_id: str = "test-run-id",
) -> MagicMock:
    """Build a stand-in for a Hatchet ``V1WorkflowRunDetails``.

    ``shape`` and ``tasks`` are attached as-is, ``task_events`` is empty, and
    ``run.metadata.id`` carries ``run_id``.
    """
    metadata = MagicMock()
    metadata.id = run_id

    run = MagicMock()
    run.metadata = metadata

    details = MagicMock()
    details.shape, details.tasks = shape, tasks
    details.task_events = []
    details.run = run
    return details
|
|
|
|
|
|
class TestExtractDagTasksBasic:
    """Single-task extraction: status mapping, timing fields and error text."""

    def test_empty_shape_returns_empty_list(self):
        empty_details = _make_details(shape=[], tasks=[])
        assert extract_dag_tasks(empty_details) == []

    def test_single_task_queued(self):
        details = _make_details(
            [_make_shape_item("s1", "get_recording")],
            [_make_task_summary("s1", status="QUEUED")],
        )

        dag = extract_dag_tasks(details)

        assert len(dag) == 1
        task = dag[0]
        assert task.name == "get_recording"
        assert task.status == DagTaskStatus.QUEUED
        assert task.parents == []
        # A queued task carries no timing, error or fan-out information yet.
        for field in (
            "started_at",
            "finished_at",
            "duration_seconds",
            "error",
            "children_total",
            "children_completed",
            "progress_pct",
        ):
            assert getattr(task, field) is None, field

    def test_completed_task_with_duration(self):
        stamp = datetime.now(timezone.utc)
        summary = _make_task_summary(
            "s1",
            status="COMPLETED",
            started_at=stamp,
            finished_at=stamp,
            duration=1500,  # milliseconds; expected back as 1.5 seconds
        )
        details = _make_details([_make_shape_item("s1", "get_recording")], [summary])

        task = extract_dag_tasks(details)[0]

        assert task.status == DagTaskStatus.COMPLETED
        assert task.duration_seconds == 1.5
        assert task.started_at == stamp
        assert task.finished_at == stamp

    def test_failed_task_with_error(self):
        failure = _make_task_summary(
            "s1",
            status="FAILED",
            error_message="Traceback (most recent call last):\n File something\nConnectionError: connection refused",
        )
        details = _make_details([_make_shape_item("s1", "get_recording")], [failure])

        task = extract_dag_tasks(details)[0]

        assert task.status == DagTaskStatus.FAILED
        # The traceback is reduced to its final exception line.
        assert task.error == "ConnectionError: connection refused"

    def test_running_task(self):
        started = datetime.now(timezone.utc)
        details = _make_details(
            [_make_shape_item("s1", "mixdown_tracks")],
            [
                _make_task_summary(
                    "s1", status="RUNNING", started_at=started, duration=5000
                )
            ],
        )

        task = extract_dag_tasks(details)[0]

        assert task.status == DagTaskStatus.RUNNING
        assert task.started_at == started
        assert task.duration_seconds == 5.0

    def test_cancelled_task(self):
        details = _make_details(
            [_make_shape_item("s1", "post_zulip")],
            [_make_task_summary("s1", status="CANCELLED")],
        )

        assert extract_dag_tasks(details)[0].status == DagTaskStatus.CANCELLED
|
|
|
|
|
|
class TestExtractDagTasksTopology:
    """Test topological ordering and parent extraction."""

    @staticmethod
    def _assert_parents_precede_children(result):
        """Assert the topological invariant: every parent is listed before its child.

        Raises ``AssertionError`` naming the first offending parent/child pair.
        """
        name_to_index = {t.name: i for i, t in enumerate(result)}
        for task in result:
            for parent_name in task.parents:
                assert name_to_index[parent_name] < name_to_index[task.name], (
                    f"Parent {parent_name} (idx {name_to_index[parent_name]}) "
                    f"must appear before {task.name} (idx {name_to_index[task.name]})"
                )

    def test_linear_chain_parents(self):
        """A -> B -> C should produce correct parents."""
        shape = [
            _make_shape_item("s1", "get_recording", children_step_ids=["s2"]),
            _make_shape_item("s2", "get_participants", children_step_ids=["s3"]),
            _make_shape_item("s3", "process_tracks"),
        ]
        tasks = [
            _make_task_summary("s1", status="COMPLETED"),
            _make_task_summary("s2", status="COMPLETED"),
            _make_task_summary("s3", status="QUEUED"),
        ]
        details = _make_details(shape, tasks)

        result = extract_dag_tasks(details)

        assert [t.name for t in result] == [
            "get_recording",
            "get_participants",
            "process_tracks",
        ]
        assert result[0].parents == []
        assert result[1].parents == ["get_recording"]
        assert result[2].parents == ["get_participants"]

    def test_diamond_dag(self):
        """
        A -> B, A -> C, B -> D, C -> D
        D should have parents [B, C] (or [C, B] depending on sort).
        """
        shape = [
            _make_shape_item("s1", "get_recording", children_step_ids=["s2", "s3"]),
            _make_shape_item("s2", "mixdown_tracks", children_step_ids=["s4"]),
            _make_shape_item("s3", "detect_topics", children_step_ids=["s4"]),
            _make_shape_item("s4", "finalize"),
        ]
        tasks = [
            _make_task_summary("s1", status="COMPLETED"),
            _make_task_summary("s2", status="RUNNING"),
            _make_task_summary("s3", status="RUNNING"),
            _make_task_summary("s4", status="QUEUED"),
        ]
        details = _make_details(shape, tasks)

        result = extract_dag_tasks(details)

        # Topological: s1 first, s2/s3 in some order, s4 last
        assert result[0].name == "get_recording"
        assert result[-1].name == "finalize"
        self._assert_parents_precede_children(result)
        finalize = result[-1]
        # Exactly two parents: a set comparison alone would hide duplicates.
        assert len(finalize.parents) == 2
        assert set(finalize.parents) == {"mixdown_tracks", "detect_topics"}

    def test_topological_order_is_stable(self):
        """Verify deterministic ordering (sorted queue in Kahn's)."""
        shape = [
            _make_shape_item("s_c", "task_c"),
            _make_shape_item("s_a", "task_a", children_step_ids=["s_c"]),
            _make_shape_item("s_b", "task_b", children_step_ids=["s_c"]),
        ]
        tasks = [
            _make_task_summary("s_c", status="QUEUED"),
            _make_task_summary("s_a", status="COMPLETED"),
            _make_task_summary("s_b", status="COMPLETED"),
        ]
        details = _make_details(shape, tasks)

        result = extract_dag_tasks(details)

        # s_a and s_b both roots with in-degree 0; sorted alphabetically by step_id
        names = [t.name for t in result]
        assert names[-1] == "task_c"
        # First two should be task_a, task_b (sorted by step_id: s_a < s_b)
        assert names[0] == "task_a"
        assert names[1] == "task_b"

    def test_production_dag_shape(self):
        """Test the real 15-task pipeline topology with mixed statuses.

        Simulates a mid-pipeline state where early tasks completed,
        middle tasks running, and later tasks still queued.
        """
        # Production DAG edges (parent -> children):
        # get_recording -> get_participants
        # get_participants -> process_tracks
        # process_tracks -> mixdown_tracks, detect_topics, finalize
        # mixdown_tracks -> generate_waveform
        # detect_topics -> generate_title, extract_subjects
        # extract_subjects -> process_subjects, identify_action_items
        # process_subjects -> generate_recap
        # generate_title -> finalize
        # generate_recap -> finalize
        # identify_action_items -> finalize
        # finalize -> cleanup_consent
        # cleanup_consent -> post_zulip, send_webhook
        shape = [
            _make_shape_item(
                "s_get_recording", TaskName.GET_RECORDING, ["s_get_participants"]
            ),
            _make_shape_item(
                "s_get_participants", TaskName.GET_PARTICIPANTS, ["s_process_tracks"]
            ),
            _make_shape_item(
                "s_process_tracks",
                TaskName.PROCESS_TRACKS,
                ["s_mixdown_tracks", "s_detect_topics", "s_finalize"],
            ),
            _make_shape_item(
                "s_mixdown_tracks", TaskName.MIXDOWN_TRACKS, ["s_generate_waveform"]
            ),
            _make_shape_item("s_generate_waveform", TaskName.GENERATE_WAVEFORM),
            _make_shape_item(
                "s_detect_topics",
                TaskName.DETECT_TOPICS,
                ["s_generate_title", "s_extract_subjects"],
            ),
            _make_shape_item(
                "s_generate_title", TaskName.GENERATE_TITLE, ["s_finalize"]
            ),
            _make_shape_item(
                "s_extract_subjects",
                TaskName.EXTRACT_SUBJECTS,
                ["s_process_subjects", "s_identify_action_items"],
            ),
            _make_shape_item(
                "s_process_subjects", TaskName.PROCESS_SUBJECTS, ["s_generate_recap"]
            ),
            _make_shape_item(
                "s_generate_recap", TaskName.GENERATE_RECAP, ["s_finalize"]
            ),
            _make_shape_item(
                "s_identify_action_items",
                TaskName.IDENTIFY_ACTION_ITEMS,
                ["s_finalize"],
            ),
            _make_shape_item("s_finalize", TaskName.FINALIZE, ["s_cleanup_consent"]),
            _make_shape_item(
                "s_cleanup_consent",
                TaskName.CLEANUP_CONSENT,
                ["s_post_zulip", "s_send_webhook"],
            ),
            _make_shape_item("s_post_zulip", TaskName.POST_ZULIP),
            _make_shape_item("s_send_webhook", TaskName.SEND_WEBHOOK),
        ]

        # Mid-pipeline: early tasks done, middle running, later queued
        tasks = [
            _make_task_summary("s_get_recording", status="COMPLETED"),
            _make_task_summary("s_get_participants", status="COMPLETED"),
            _make_task_summary("s_process_tracks", status="COMPLETED"),
            _make_task_summary("s_mixdown_tracks", status="RUNNING"),
            _make_task_summary("s_generate_waveform", status="QUEUED"),
            _make_task_summary("s_detect_topics", status="RUNNING"),
            _make_task_summary("s_generate_title", status="QUEUED"),
            _make_task_summary("s_extract_subjects", status="QUEUED"),
            _make_task_summary("s_process_subjects", status="QUEUED"),
            _make_task_summary("s_generate_recap", status="QUEUED"),
            _make_task_summary("s_identify_action_items", status="QUEUED"),
            _make_task_summary("s_finalize", status="QUEUED"),
            _make_task_summary("s_cleanup_consent", status="QUEUED"),
            _make_task_summary("s_post_zulip", status="QUEUED"),
            _make_task_summary("s_send_webhook", status="QUEUED"),
        ]
        details = _make_details(shape, tasks)

        result = extract_dag_tasks(details)

        # All 15 tasks present (15 results with 15 distinct names => no duplicates)
        assert len(result) == 15
        assert {t.name for t in result} == {
            TaskName.GET_RECORDING,
            TaskName.GET_PARTICIPANTS,
            TaskName.PROCESS_TRACKS,
            TaskName.MIXDOWN_TRACKS,
            TaskName.GENERATE_WAVEFORM,
            TaskName.DETECT_TOPICS,
            TaskName.GENERATE_TITLE,
            TaskName.EXTRACT_SUBJECTS,
            TaskName.PROCESS_SUBJECTS,
            TaskName.GENERATE_RECAP,
            TaskName.IDENTIFY_ACTION_ITEMS,
            TaskName.FINALIZE,
            TaskName.CLEANUP_CONSENT,
            TaskName.POST_ZULIP,
            TaskName.SEND_WEBHOOK,
        }

        # Topological order invariant: no task appears before its parents
        self._assert_parents_precede_children(result)

        by_name = {t.name: t for t in result}

        # finalize has exactly 4 parents
        finalize = by_name[TaskName.FINALIZE]
        assert len(finalize.parents) == 4
        assert set(finalize.parents) == {
            TaskName.PROCESS_TRACKS,
            TaskName.GENERATE_TITLE,
            TaskName.GENERATE_RECAP,
            TaskName.IDENTIFY_ACTION_ITEMS,
        }

        # cleanup_consent has 1 parent (finalize)
        assert by_name[TaskName.CLEANUP_CONSENT].parents == [TaskName.FINALIZE]

        # post_zulip and send_webhook both have cleanup_consent as parent
        assert by_name[TaskName.POST_ZULIP].parents == [TaskName.CLEANUP_CONSENT]
        assert by_name[TaskName.SEND_WEBHOOK].parents == [TaskName.CLEANUP_CONSENT]

        # Verify statuses propagated correctly
        assert by_name[TaskName.GET_RECORDING].status == DagTaskStatus.COMPLETED
        assert by_name[TaskName.MIXDOWN_TRACKS].status == DagTaskStatus.RUNNING
        assert by_name[TaskName.FINALIZE].status == DagTaskStatus.QUEUED

    def test_topological_sort_invariant_complex_dag(self):
        """For a complex DAG, every task's parents appear earlier in the list.

        Uses a wider branching/merging DAG than diamond to stress the invariant.
        """
        # DAG: A -> B, A -> C, A -> D, B -> E, C -> E, C -> F, D -> F, E -> G, F -> G
        shape = [
            _make_shape_item("s_a", "task_a", ["s_b", "s_c", "s_d"]),
            _make_shape_item("s_b", "task_b", ["s_e"]),
            _make_shape_item("s_c", "task_c", ["s_e", "s_f"]),
            _make_shape_item("s_d", "task_d", ["s_f"]),
            _make_shape_item("s_e", "task_e", ["s_g"]),
            _make_shape_item("s_f", "task_f", ["s_g"]),
            _make_shape_item("s_g", "task_g"),
        ]
        tasks = [
            _make_task_summary("s_a", status="COMPLETED"),
            _make_task_summary("s_b", status="COMPLETED"),
            _make_task_summary("s_c", status="RUNNING"),
            _make_task_summary("s_d", status="COMPLETED"),
            _make_task_summary("s_e", status="QUEUED"),
            _make_task_summary("s_f", status="QUEUED"),
            _make_task_summary("s_g", status="QUEUED"),
        ]
        details = _make_details(shape, tasks)

        result = extract_dag_tasks(details)

        assert len(result) == 7

        # Verify invariant: every parent appears before its child
        self._assert_parents_precede_children(result)

        by_name = {t.name: t for t in result}

        # Each merge node has exactly 2 parents (len guards against duplicates)
        for merge_node, expected_parents in (
            ("task_e", {"task_b", "task_c"}),
            ("task_f", {"task_c", "task_d"}),
            ("task_g", {"task_e", "task_f"}),
        ):
            parents = by_name[merge_node].parents
            assert len(parents) == 2, merge_node
            assert set(parents) == expected_parents, merge_node

        # task_a is root (first in topological order)
        assert result[0].name == "task_a"
        assert result[0].parents == []
|
|
|
|
|
|
class TestExtractDagTasksFanOut:
    """Fan-out tasks: totals and completed counts of spawned children."""

    @staticmethod
    def _single_task(task_name, **summary_kwargs):
        """Extract a one-step DAG (step id ``s1``) and return its only DagTask."""
        details = _make_details(
            [_make_shape_item("s1", task_name)],
            [_make_task_summary("s1", **summary_kwargs)],
        )
        return extract_dag_tasks(details)[0]

    def test_fan_out_children_counts(self):
        from hatchet_sdk.clients.rest.models import V1TaskStatus

        def child_with(status_name):
            child = MagicMock()
            child.status = V1TaskStatus(status_name)
            return child

        children = [
            child_with(name)
            for name in ("COMPLETED", "COMPLETED", "RUNNING", "QUEUED")
        ]

        fan_out = self._single_task(
            "process_tracks",
            status="RUNNING",
            num_spawned_children=4,
            children=children,
        )

        # Two of the four children above are COMPLETED.
        assert fan_out.children_total == 4
        assert fan_out.children_completed == 2

    def test_no_children_when_no_spawn(self):
        task = self._single_task(
            "get_recording", status="COMPLETED", num_spawned_children=None
        )

        assert task.children_total is None
        assert task.children_completed is None

    def test_zero_spawned_children(self):
        # Zero spawned children is reported exactly like "no fan-out".
        task = self._single_task(
            "process_tracks", status="COMPLETED", num_spawned_children=0
        )

        assert task.children_total is None
        assert task.children_completed is None
|
|
|
|
|
|
class TestExtractDagTasksErrorExtraction:
    """How task error messages are condensed into ``DagTask.error``."""

    @staticmethod
    def _extracted_error(task_name, status, error_message):
        """Run a one-step DAG through extract_dag_tasks and return its ``error``."""
        details = _make_details(
            [_make_shape_item("s1", task_name)],
            [_make_task_summary("s1", status=status, error_message=error_message)],
        )
        return extract_dag_tasks(details)[0].error

    def test_simple_error(self):
        error = self._extracted_error(
            "mixdown_tracks", "FAILED", "ValueError: no tracks"
        )
        assert error == "ValueError: no tracks"

    def test_traceback_extracts_meaningful_line(self):
        traceback_text = "".join(
            [
                "Traceback (most recent call last):\n",
                ' File "/app/something.py", line 42\n',
                "RuntimeError: out of memory",
            ]
        )

        error = self._extracted_error("mixdown_tracks", "FAILED", traceback_text)

        assert error == "RuntimeError: out of memory"

    def test_no_error_when_none(self):
        assert self._extracted_error("get_recording", "COMPLETED", None) is None

    def test_empty_error_message(self):
        # An empty message is normalized to "no error" rather than "".
        assert self._extracted_error("get_recording", "FAILED", "") is None
|
|
|
|
|
|
class TestExtractDagTasksMissingData:
    """Shapes whose task data is missing, or run details with no shape at all."""

    def test_shape_without_matching_task(self):
        """A step listed in the shape has no corresponding task summary."""
        details = _make_details([_make_shape_item("s1", "get_recording")], [])

        result = extract_dag_tasks(details)

        assert len(result) == 1
        orphan = result[0]
        assert orphan.name == "get_recording"
        # With no task summary to read from, the status falls back to QUEUED.
        assert orphan.status == DagTaskStatus.QUEUED
        assert orphan.started_at is None

    def test_none_shape_returns_empty(self):
        details = _make_details(shape=[], tasks=[])
        # Simulate run details whose shape attribute is None.
        details.shape = None

        assert extract_dag_tasks(details) == []
|
|
|
|
|
|
class TestDagStatusData:
    """JSON serialization of the DagStatusData payload."""

    def test_serialization(self):
        started = datetime(2025, 1, 1, tzinfo=timezone.utc)
        finished = datetime(2025, 1, 1, 0, 0, 1, tzinfo=timezone.utc)
        task = DagTask(
            name="get_recording",
            status=DagTaskStatus.COMPLETED,
            started_at=started,
            finished_at=finished,
            duration_seconds=1.0,
            parents=[],
            error=None,
            children_total=None,
            children_completed=None,
            progress_pct=None,
        )

        payload = DagStatusData(workflow_run_id="test-123", tasks=[task]).model_dump(
            mode="json"
        )

        assert payload["workflow_run_id"] == "test-123"
        serialized_tasks = payload["tasks"]
        assert len(serialized_tasks) == 1
        only = serialized_tasks[0]
        assert only["name"] == "get_recording"
        # The status enum serializes to its lowercase string value.
        assert only["status"] == "completed"
        assert only["duration_seconds"] == 1.0
|
|
|
|
|
|
class AsyncContextManager:
    """Async context manager that does nothing.

    Used as the return value of the patched ``fresh_db_connection``: entering
    yields ``None`` and exiting never suppresses an exception.
    """

    async def __aenter__(self):
        # Nothing to acquire; ``async with ... as x`` binds x to None.
        return None

    async def __aexit__(self, *exc_info):
        # Falsy result: any exception raised in the body propagates.
        return None
|
|
|
|
|
|
class TestBroadcastDagStatus:
    """Test broadcast_dag_status function.

    broadcast_dag_status uses deferred imports inside its function body,
    so the source modules/objects are patched before it is called.
    Importing daily_multitrack_pipeline triggers a cascade
    (subject_processing -> HatchetClientManager.get_client at module level),
    so _instance is set before the import to prevent real SDK init.
    """

    # Patch targets, kept in one place so every test patches the same paths.
    _GET_CLIENT = "reflector.hatchet.client.HatchetClientManager.get_client"
    _APPEND_AND_BROADCAST = "reflector.hatchet.broadcast.append_event_and_broadcast"
    _GET_TRANSCRIPT = "reflector.db.transcripts.transcripts_controller.get_by_id"
    _FRESH_DB = (
        "reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection"
    )

    @pytest.fixture(autouse=True)
    def _setup_hatchet_mock(self):
        """Swap HatchetClientManager._instance for a mock for the test's duration.

        Module-level code in workflow files calls get_client() during import;
        a pre-set _instance avoids ClientConfig validation. The previous value
        is restored after the test.
        """
        from reflector.hatchet.client import HatchetClientManager

        saved_instance = HatchetClientManager._instance
        HatchetClientManager._instance = MagicMock()
        yield
        HatchetClientManager._instance = saved_instance

    @staticmethod
    def _client_returning(details):
        """Mock Hatchet client whose ``runs.aio_get`` resolves to ``details``."""
        client = MagicMock()
        client.runs.aio_get = AsyncMock(return_value=details)
        return client

    @pytest.mark.asyncio
    async def test_broadcasts_dag_status(self):
        """broadcast_dag_status fetches run, transforms, and broadcasts."""
        transcript = MagicMock()
        transcript.id = "t-123"
        client = self._client_returning(
            _make_details(
                shape=[_make_shape_item("s1", "get_recording")],
                tasks=[_make_task_summary("s1", status="COMPLETED")],
                run_id="wf-abc",
            )
        )

        with (
            patch(self._GET_CLIENT, return_value=client),
            patch(self._APPEND_AND_BROADCAST, new_callable=AsyncMock) as broadcast,
            patch(
                self._GET_TRANSCRIPT,
                new_callable=AsyncMock,
                return_value=transcript,
            ),
            patch(self._FRESH_DB, return_value=AsyncContextManager()),
        ):
            from reflector.hatchet.dag_progress import broadcast_dag_status

            await broadcast_dag_status("t-123", "wf-abc")

        client.runs.aio_get.assert_called_once_with("wf-abc")
        broadcast.assert_called_once()
        transcript_id, sent_transcript, event_name, data = broadcast.call_args.args[
            :4
        ]
        assert transcript_id == "t-123"
        assert sent_transcript is transcript
        assert event_name == "DAG_STATUS"
        assert isinstance(data, DagStatusData)
        assert data.workflow_run_id == "wf-abc"
        assert len(data.tasks) == 1

    @pytest.mark.asyncio
    async def test_swallows_exceptions(self):
        """broadcast_dag_status never raises even when internals fail."""
        from reflector.hatchet.dag_progress import broadcast_dag_status

        with patch(self._FRESH_DB, side_effect=RuntimeError("db exploded")):
            # Any exception escaping this await fails the test.
            await broadcast_dag_status("t-123", "wf-abc")

    @pytest.mark.asyncio
    async def test_no_broadcast_when_transcript_not_found(self):
        """broadcast_dag_status does not broadcast if transcript is None."""
        client = self._client_returning(
            _make_details(
                shape=[_make_shape_item("s1", "get_recording")],
                tasks=[_make_task_summary("s1", status="COMPLETED")],
            )
        )

        with (
            patch(self._GET_CLIENT, return_value=client),
            patch(self._FRESH_DB, return_value=AsyncContextManager()),
            patch(self._GET_TRANSCRIPT, new_callable=AsyncMock, return_value=None),
            patch(self._APPEND_AND_BROADCAST, new_callable=AsyncMock) as broadcast,
        ):
            from reflector.hatchet.dag_progress import broadcast_dag_status

            await broadcast_dag_status("t-123", "wf-abc")

        broadcast.assert_not_called()
|
|
|
|
|
|
class TestMakeAudioProgressLoggerWithBroadcast:
    """Test make_audio_progress_logger with transcript_id for transient broadcasts."""

    _BROADCAST_EVENT = "reflector.hatchet.broadcast.broadcast_event"

    @pytest.fixture(autouse=True)
    def _setup_hatchet_mock(self):
        """Set HatchetClientManager._instance to prevent real SDK init on import."""
        from reflector.hatchet.client import HatchetClientManager

        original = HatchetClientManager._instance
        if original is None:
            HatchetClientManager._instance = MagicMock()
        yield
        HatchetClientManager._instance = original

    @staticmethod
    def _install_capturing_loop():
        """Create a fresh event loop, make it current, and record tasks it creates.

        Returns ``(loop, tasks_created, create_task_patch)``. ``create_task_patch``
        is an un-entered ``patch.object``; while it is active, ``loop.create_task``
        still delegates to the real implementation but every created task is
        appended to ``tasks_created`` so the test can await them.
        """
        import asyncio

        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        tasks_created = []
        original_create_task = loop.create_task

        def capture_create_task(coro):
            task = original_create_task(coro)
            tasks_created.append(task)
            return task

        create_task_patch = patch.object(
            loop, "create_task", side_effect=capture_create_task
        )
        return loop, tasks_created, create_task_patch

    @staticmethod
    def _drain(loop, tasks_created):
        """Run every captured task on ``loop`` to completion (no-op if none)."""
        import asyncio

        if tasks_created:
            loop.run_until_complete(asyncio.gather(*tasks_created))

    @staticmethod
    def _teardown_loop(loop):
        """Close ``loop`` and uninstall it as the thread's current event loop.

        Without the reset, the closed loop stays installed and any later
        ``asyncio.get_event_loop()`` in this thread (e.g. in another test) gets a
        closed loop. Mirrors what ``asyncio.Runner.close()`` does.
        """
        import asyncio

        try:
            loop.close()
        finally:
            asyncio.set_event_loop(None)

    def test_broadcasts_transient_progress_event(self):
        """When transcript_id provided and progress_pct not None, broadcasts event."""
        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            make_audio_progress_logger,
        )

        ctx = MagicMock()
        ctx.log = MagicMock()

        loop, tasks_created, create_task_patch = self._install_capturing_loop()
        mock_broadcast = AsyncMock()

        try:
            with patch(self._BROADCAST_EVENT, mock_broadcast), create_task_patch:
                callback = make_audio_progress_logger(
                    ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id="t-123"
                )
                callback(50.0, 100.0)

                # Run pending tasks
                self._drain(loop, tasks_created)

                mock_broadcast.assert_called_once()
                event_arg = mock_broadcast.call_args[0][1]
                assert event_arg.event == "DAG_TASK_PROGRESS"
                assert event_arg.data["task_name"] == TaskName.MIXDOWN_TRACKS
                assert event_arg.data["progress_pct"] == 50.0
        finally:
            self._teardown_loop(loop)

    def test_no_broadcast_without_transcript_id(self):
        """When transcript_id is None, no broadcast happens."""
        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            make_audio_progress_logger,
        )

        ctx = MagicMock()

        with patch(self._BROADCAST_EVENT, new_callable=AsyncMock) as mock_broadcast:
            callback = make_audio_progress_logger(
                ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id=None
            )
            callback(50.0, 100.0)
            mock_broadcast.assert_not_called()

    def test_no_broadcast_when_progress_pct_is_none(self):
        """When progress_pct is None, no broadcast happens even with transcript_id."""
        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            make_audio_progress_logger,
        )

        ctx = MagicMock()

        with patch(self._BROADCAST_EVENT, new_callable=AsyncMock) as mock_broadcast:
            callback = make_audio_progress_logger(
                ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id="t-123"
            )
            callback(None, 100.0)
            mock_broadcast.assert_not_called()

    def test_logging_throttled_by_interval(self):
        """With interval=5.0, rapid calls only log once until interval elapses.

        The throttle applies to ctx.log() calls. Broadcasts (fire-and-forget)
        are not throttled — they occur every call when transcript_id + progress_pct set.
        """
        import time as time_mod

        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            make_audio_progress_logger,
        )

        ctx = MagicMock()
        ctx.log = MagicMock()

        loop, tasks_created, create_task_patch = self._install_capturing_loop()
        mock_broadcast = AsyncMock()

        # Controlled monotonic values for the 4 calls from make_audio_progress_logger:
        # init (start_time, last_log_time), call1 (now), call2 (now), call3 (now)
        # After those, fall back to real time.monotonic() for asyncio internals.
        controlled_values = [100.0, 100.0, 101.0, 106.0]
        call_index = [0]
        real_monotonic = time_mod.monotonic

        def mock_monotonic():
            if call_index[0] < len(controlled_values):
                val = controlled_values[call_index[0]]
                call_index[0] += 1
                return val
            return real_monotonic()

        try:
            with (
                patch(
                    "reflector.hatchet.workflows.daily_multitrack_pipeline.time.monotonic",
                    side_effect=mock_monotonic,
                ),
                patch(self._BROADCAST_EVENT, mock_broadcast),
                create_task_patch,
            ):
                callback = make_audio_progress_logger(
                    ctx, TaskName.MIXDOWN_TRACKS, interval=5.0, transcript_id="t-123"
                )

                # Call 1 at t=100.0: 100.0 - 100.0 = 0.0 < 5.0 => no log
                callback(25.0, 50.0)
                assert ctx.log.call_count == 0

                # Call 2 at t=101.0: 101.0 - 100.0 = 1.0 < 5.0 => no log
                callback(50.0, 100.0)
                assert ctx.log.call_count == 0

                # Call 3 at t=106.0: 106.0 - 100.0 = 6.0 >= 5.0 => logs
                callback(75.0, 150.0)
                assert ctx.log.call_count == 1

                # Run pending broadcast tasks
                self._drain(loop, tasks_created)

                # Broadcasts happen on every call (not throttled) — 3 calls total
                assert mock_broadcast.call_count == 3
        finally:
            self._teardown_loop(loop)

    def test_uses_broadcast_event_not_append_event_and_broadcast(self):
        """Progress events use broadcast_event (transient), not append_event_and_broadcast (persisted)."""
        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            make_audio_progress_logger,
        )

        ctx = MagicMock()
        ctx.log = MagicMock()

        loop, tasks_created, create_task_patch = self._install_capturing_loop()
        mock_broadcast_event = AsyncMock()
        mock_append = AsyncMock()

        try:
            with (
                patch(self._BROADCAST_EVENT, mock_broadcast_event),
                patch(
                    "reflector.hatchet.broadcast.append_event_and_broadcast",
                    mock_append,
                ),
                create_task_patch,
            ):
                callback = make_audio_progress_logger(
                    ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id="t-123"
                )
                callback(50.0, 100.0)

                self._drain(loop, tasks_created)

                # broadcast_event (transient) IS called
                mock_broadcast_event.assert_called_once()
                # append_event_and_broadcast (persisted) is NOT called
                mock_append.assert_not_called()
        finally:
            self._teardown_loop(loop)
|