Files
reflector/server/tests/test_dag_progress.py
Igor Loskutov 2410688559 fix: pass DagStatusData model instead of dict to append_event_and_broadcast
add_event() calls .model_dump() on data, so it needs a Pydantic model not a dict.
2026-02-09 14:00:43 -05:00

960 lines
35 KiB
Python

"""Tests for DAG progress models and transform function.
Tests the extract_dag_tasks function that converts Hatchet V1WorkflowRunDetails
into structured DagTask list for frontend consumption.
"""
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from reflector.hatchet.constants import TaskName
from reflector.hatchet.dag_progress import (
DagStatusData,
DagTask,
DagTaskStatus,
extract_dag_tasks,
)
def _make_shape_item(
step_id: str,
task_name: str,
children_step_ids: list[str] | None = None,
) -> MagicMock:
"""Create a mock WorkflowRunShapeItemForWorkflowRunDetails."""
item = MagicMock()
item.step_id = step_id
item.task_name = task_name
item.children_step_ids = children_step_ids or []
return item
def _make_task_summary(
    step_id: str,
    status: str = "QUEUED",
    started_at: datetime | None = None,
    finished_at: datetime | None = None,
    duration: int | None = None,
    error_message: str | None = None,
    task_external_id: str | None = None,
    num_spawned_children: int | None = None,
    children: list | None = None,
) -> MagicMock:
    """Build a stand-in for a Hatchet V1TaskSummary.

    ``status`` is passed through ``V1TaskStatus(status)``, so it must be a
    valid status value such as ``"COMPLETED"``. A falsy ``task_external_id``
    becomes ``"ext-<step_id>"`` and a falsy ``children`` becomes ``[]``; every
    other argument is stored on the mock unchanged.
    """
    from hatchet_sdk.clients.rest.models import V1TaskStatus

    external_id = task_external_id if task_external_id else f"ext-{step_id}"

    summary = MagicMock()
    summary.step_id = step_id
    summary.status = V1TaskStatus(status)
    summary.started_at = started_at
    summary.finished_at = finished_at
    summary.duration = duration
    summary.error_message = error_message
    summary.task_external_id = external_id
    summary.num_spawned_children = num_spawned_children
    summary.children = children if children else []
    return summary
def _make_details(
shape: list,
tasks: list,
run_id: str = "test-run-id",
) -> MagicMock:
"""Create a mock V1WorkflowRunDetails."""
details = MagicMock()
details.shape = shape
details.tasks = tasks
details.task_events = []
details.run = MagicMock()
details.run.metadata = MagicMock()
details.run.metadata.id = run_id
return details
class TestExtractDagTasksBasic:
    """Single-step runs: status mapping, timestamps, duration and error text."""

    @staticmethod
    def _run_single(task_name: str, **summary_kwargs):
        """Extract tasks from a one-step run whose only step id is ``"s1"``.

        ``summary_kwargs`` are forwarded to ``_make_task_summary``.
        """
        shape = [_make_shape_item("s1", task_name)]
        tasks = [_make_task_summary("s1", **summary_kwargs)]
        return extract_dag_tasks(_make_details(shape, tasks))

    def test_empty_shape_returns_empty_list(self):
        details = _make_details(shape=[], tasks=[])
        assert extract_dag_tasks(details) == []

    def test_single_task_queued(self):
        result = self._run_single("get_recording", status="QUEUED")

        assert len(result) == 1
        task = result[0]
        assert task.name == "get_recording"
        assert task.status == DagTaskStatus.QUEUED
        assert task.parents == []
        # A queued step has no timing, error, or fan-out/progress data yet.
        assert task.started_at is None
        assert task.finished_at is None
        assert task.duration_seconds is None
        assert task.error is None
        assert task.children_total is None
        assert task.children_completed is None
        assert task.progress_pct is None

    def test_completed_task_with_duration(self):
        now = datetime.now(timezone.utc)
        task = self._run_single(
            "get_recording",
            status="COMPLETED",
            started_at=now,
            finished_at=now,
            duration=1500,  # milliseconds
        )[0]

        assert task.status == DagTaskStatus.COMPLETED
        assert task.duration_seconds == 1.5
        assert task.started_at == now
        assert task.finished_at == now

    def test_failed_task_with_error(self):
        traceback_text = (
            "Traceback (most recent call last):\n"
            " File something\n"
            "ConnectionError: connection refused"
        )
        task = self._run_single(
            "get_recording", status="FAILED", error_message=traceback_text
        )[0]

        assert task.status == DagTaskStatus.FAILED
        # Only the final exception line survives.
        assert task.error == "ConnectionError: connection refused"

    def test_running_task(self):
        now = datetime.now(timezone.utc)
        task = self._run_single(
            "mixdown_tracks", status="RUNNING", started_at=now, duration=5000
        )[0]

        assert task.status == DagTaskStatus.RUNNING
        assert task.started_at == now
        assert task.duration_seconds == 5.0

    def test_cancelled_task(self):
        task = self._run_single("post_zulip", status="CANCELLED")[0]
        assert task.status == DagTaskStatus.CANCELLED
class TestExtractDagTasksTopology:
    """Parent extraction and topological ordering of multi-step DAGs."""

    @staticmethod
    def _find(result, name):
        """Return the first extracted task called *name*."""
        return next(t for t in result if t.name == name)

    @staticmethod
    def _assert_parents_precede_children(result) -> None:
        """Assert that every task's parents appear earlier in *result*."""
        position = {task.name: idx for idx, task in enumerate(result)}
        for task in result:
            for parent_name in task.parents:
                assert position[parent_name] < position[task.name], (
                    f"Parent {parent_name} (idx {position[parent_name]}) "
                    f"must appear before {task.name} (idx {position[task.name]})"
                )

    def test_linear_chain_parents(self):
        """Chain A -> B -> C: each step lists only its predecessor as parent."""
        shape = [
            _make_shape_item("s1", "get_recording", children_step_ids=["s2"]),
            _make_shape_item("s2", "get_participants", children_step_ids=["s3"]),
            _make_shape_item("s3", "process_tracks"),
        ]
        tasks = [
            _make_task_summary(step_id, status=status)
            for step_id, status in (
                ("s1", "COMPLETED"),
                ("s2", "COMPLETED"),
                ("s3", "QUEUED"),
            )
        ]

        result = extract_dag_tasks(_make_details(shape, tasks))

        assert [t.name for t in result] == [
            "get_recording",
            "get_participants",
            "process_tracks",
        ]
        assert [t.parents for t in result] == [
            [],
            ["get_recording"],
            ["get_participants"],
        ]

    def test_diamond_dag(self):
        """Diamond A -> B, A -> C, B -> D, C -> D.

        D must report both B and C as parents; their relative order is not
        asserted.
        """
        shape = [
            _make_shape_item("s1", "get_recording", children_step_ids=["s2", "s3"]),
            _make_shape_item("s2", "mixdown_tracks", children_step_ids=["s4"]),
            _make_shape_item("s3", "detect_topics", children_step_ids=["s4"]),
            _make_shape_item("s4", "finalize"),
        ]
        tasks = [
            _make_task_summary(step_id, status=status)
            for step_id, status in (
                ("s1", "COMPLETED"),
                ("s2", "RUNNING"),
                ("s3", "RUNNING"),
                ("s4", "QUEUED"),
            )
        ]

        result = extract_dag_tasks(_make_details(shape, tasks))

        # Root first, sink last; the two middle steps may come in either order.
        first, last = result[0], result[-1]
        assert first.name == "get_recording"
        assert last.name == "finalize"
        assert set(last.parents) == {"mixdown_tracks", "detect_topics"}

    def test_topological_order_is_stable(self):
        """Ties between ready steps are broken deterministically by step_id."""
        shape = [
            _make_shape_item("s_c", "task_c"),
            _make_shape_item("s_a", "task_a", children_step_ids=["s_c"]),
            _make_shape_item("s_b", "task_b", children_step_ids=["s_c"]),
        ]
        tasks = [
            _make_task_summary("s_c", status="QUEUED"),
            _make_task_summary("s_a", status="COMPLETED"),
            _make_task_summary("s_b", status="COMPLETED"),
        ]

        names = [t.name for t in extract_dag_tasks(_make_details(shape, tasks))]

        # The shared child comes last...
        assert names[-1] == "task_c"
        # ...and the two roots follow step_id order (s_a < s_b), even though
        # the shape lists the child first.
        assert names[0] == "task_a"
        assert names[1] == "task_b"

    def test_production_dag_shape(self):
        """Real 15-task pipeline topology, captured mid-run.

        Early steps are completed, two middle steps are running and the
        remaining steps are queued.
        """
        # (step_id, task name, child step ids) -- i.e. the production edges:
        #   get_recording -> get_participants
        #   get_participants -> process_tracks
        #   process_tracks -> mixdown_tracks, detect_topics, finalize
        #   mixdown_tracks -> generate_waveform
        #   detect_topics -> generate_title, extract_subjects
        #   extract_subjects -> process_subjects, identify_action_items
        #   process_subjects -> generate_recap
        #   generate_title / generate_recap / identify_action_items -> finalize
        #   finalize -> cleanup_consent
        #   cleanup_consent -> post_zulip, send_webhook
        topology = [
            ("s_get_recording", TaskName.GET_RECORDING, ["s_get_participants"]),
            ("s_get_participants", TaskName.GET_PARTICIPANTS, ["s_process_tracks"]),
            (
                "s_process_tracks",
                TaskName.PROCESS_TRACKS,
                ["s_mixdown_tracks", "s_detect_topics", "s_finalize"],
            ),
            ("s_mixdown_tracks", TaskName.MIXDOWN_TRACKS, ["s_generate_waveform"]),
            ("s_generate_waveform", TaskName.GENERATE_WAVEFORM, None),
            (
                "s_detect_topics",
                TaskName.DETECT_TOPICS,
                ["s_generate_title", "s_extract_subjects"],
            ),
            ("s_generate_title", TaskName.GENERATE_TITLE, ["s_finalize"]),
            (
                "s_extract_subjects",
                TaskName.EXTRACT_SUBJECTS,
                ["s_process_subjects", "s_identify_action_items"],
            ),
            ("s_process_subjects", TaskName.PROCESS_SUBJECTS, ["s_generate_recap"]),
            ("s_generate_recap", TaskName.GENERATE_RECAP, ["s_finalize"]),
            (
                "s_identify_action_items",
                TaskName.IDENTIFY_ACTION_ITEMS,
                ["s_finalize"],
            ),
            ("s_finalize", TaskName.FINALIZE, ["s_cleanup_consent"]),
            (
                "s_cleanup_consent",
                TaskName.CLEANUP_CONSENT,
                ["s_post_zulip", "s_send_webhook"],
            ),
            ("s_post_zulip", TaskName.POST_ZULIP, None),
            ("s_send_webhook", TaskName.SEND_WEBHOOK, None),
        ]
        shape = [
            _make_shape_item(step_id, name, children)
            for step_id, name, children in topology
        ]

        # Mid-pipeline snapshot: early steps done, middle running, rest queued.
        status_by_step = {
            "s_get_recording": "COMPLETED",
            "s_get_participants": "COMPLETED",
            "s_process_tracks": "COMPLETED",
            "s_mixdown_tracks": "RUNNING",
            "s_generate_waveform": "QUEUED",
            "s_detect_topics": "RUNNING",
            "s_generate_title": "QUEUED",
            "s_extract_subjects": "QUEUED",
            "s_process_subjects": "QUEUED",
            "s_generate_recap": "QUEUED",
            "s_identify_action_items": "QUEUED",
            "s_finalize": "QUEUED",
            "s_cleanup_consent": "QUEUED",
            "s_post_zulip": "QUEUED",
            "s_send_webhook": "QUEUED",
        }
        tasks = [
            _make_task_summary(step_id, status=status)
            for step_id, status in status_by_step.items()
        ]

        result = extract_dag_tasks(_make_details(shape, tasks))

        # All 15 steps of the shape come back.
        assert len(result) == 15
        assert {t.name for t in result} == {name for _, name, _ in topology}

        self._assert_parents_precede_children(result)

        # finalize joins four branches.
        assert set(self._find(result, TaskName.FINALIZE).parents) == {
            TaskName.PROCESS_TRACKS,
            TaskName.GENERATE_TITLE,
            TaskName.GENERATE_RECAP,
            TaskName.IDENTIFY_ACTION_ITEMS,
        }
        # cleanup_consent depends on finalize alone; both notifiers on it.
        cleanup = self._find(result, TaskName.CLEANUP_CONSENT)
        assert cleanup.parents == [TaskName.FINALIZE]
        post_zulip = self._find(result, TaskName.POST_ZULIP)
        send_webhook = self._find(result, TaskName.SEND_WEBHOOK)
        assert post_zulip.parents == [TaskName.CLEANUP_CONSENT]
        assert send_webhook.parents == [TaskName.CLEANUP_CONSENT]

        # Statuses are carried over from the task summaries.
        recording = self._find(result, TaskName.GET_RECORDING)
        assert recording.status == DagTaskStatus.COMPLETED
        mixdown = self._find(result, TaskName.MIXDOWN_TRACKS)
        assert mixdown.status == DagTaskStatus.RUNNING
        finalize = self._find(result, TaskName.FINALIZE)
        assert finalize.status == DagTaskStatus.QUEUED

    def test_topological_sort_invariant_complex_dag(self):
        """Wider fan-out/fan-in than the diamond; parents still precede children.

        DAG: A -> B, A -> C, A -> D, B -> E, C -> E, C -> F, D -> F, E -> G, F -> G
        """
        edges = {
            "a": ["b", "c", "d"],
            "b": ["e"],
            "c": ["e", "f"],
            "d": ["f"],
            "e": ["g"],
            "f": ["g"],
            "g": [],
        }
        shape = [
            _make_shape_item(
                f"s_{node}", f"task_{node}", [f"s_{child}" for child in children]
            )
            for node, children in edges.items()
        ]
        statuses = {
            "a": "COMPLETED",
            "b": "COMPLETED",
            "c": "RUNNING",
            "d": "COMPLETED",
            "e": "QUEUED",
            "f": "QUEUED",
            "g": "QUEUED",
        }
        tasks = [
            _make_task_summary(f"s_{node}", status=status)
            for node, status in statuses.items()
        ]

        result = extract_dag_tasks(_make_details(shape, tasks))

        assert len(result) == 7
        self._assert_parents_precede_children(result)

        # task_g and task_e are join points with two parents each.
        assert set(self._find(result, "task_g").parents) == {"task_e", "task_f"}
        assert set(self._find(result, "task_e").parents) == {"task_b", "task_c"}

        # task_a is the only root, so it leads the ordering.
        head = result[0]
        assert head.name == "task_a"
        assert head.parents == []
class TestExtractDagTasksFanOut:
    """Fan-out steps: counting spawned children and how many have completed."""

    def test_fan_out_children_counts(self):
        from hatchet_sdk.clients.rest.models import V1TaskStatus

        children = [
            MagicMock(status=V1TaskStatus(child_status))
            for child_status in ("COMPLETED", "COMPLETED", "RUNNING", "QUEUED")
        ]
        details = _make_details(
            [_make_shape_item("s1", "process_tracks")],
            [
                _make_task_summary(
                    "s1",
                    status="RUNNING",
                    num_spawned_children=4,
                    children=children,
                )
            ],
        )

        task = extract_dag_tasks(details)[0]

        # Two of the four children are COMPLETED.
        assert task.children_total == 4
        assert task.children_completed == 2

    def test_no_children_when_no_spawn(self):
        details = _make_details(
            [_make_shape_item("s1", "get_recording")],
            [_make_task_summary("s1", status="COMPLETED", num_spawned_children=None)],
        )

        task = extract_dag_tasks(details)[0]

        assert task.children_total is None
        assert task.children_completed is None

    def test_zero_spawned_children(self):
        details = _make_details(
            [_make_shape_item("s1", "process_tracks")],
            [_make_task_summary("s1", status="COMPLETED", num_spawned_children=0)],
        )

        task = extract_dag_tasks(details)[0]

        # Zero spawned children is reported the same way as "no fan-out".
        assert task.children_total is None
        assert task.children_completed is None
class TestExtractDagTasksErrorExtraction:
    """How a task's raw ``error_message`` is reduced to the ``error`` field."""

    @staticmethod
    def _extracted_error(task_name, status, error_message):
        """Return ``error`` of the single task extracted from a one-step run."""
        details = _make_details(
            [_make_shape_item("s1", task_name)],
            [_make_task_summary("s1", status=status, error_message=error_message)],
        )
        return extract_dag_tasks(details)[0].error

    def test_simple_error(self):
        error = self._extracted_error(
            "mixdown_tracks", "FAILED", "ValueError: no tracks"
        )
        assert error == "ValueError: no tracks"

    def test_traceback_extracts_meaningful_line(self):
        traceback_text = "\n".join(
            [
                "Traceback (most recent call last):",
                ' File "/app/something.py", line 42',
                "RuntimeError: out of memory",
            ]
        )
        error = self._extracted_error("mixdown_tracks", "FAILED", traceback_text)
        assert error == "RuntimeError: out of memory"

    def test_no_error_when_none(self):
        assert self._extracted_error("get_recording", "COMPLETED", None) is None

    def test_empty_error_message(self):
        # An empty message is reported as "no error" rather than "".
        assert self._extracted_error("get_recording", "FAILED", "") is None
class TestExtractDagTasksMissingData:
    """Shapes and task lists that are incomplete or absent."""

    def test_shape_without_matching_task(self):
        """A shape step with no task summary is still reported."""
        details = _make_details([_make_shape_item("s1", "get_recording")], [])

        result = extract_dag_tasks(details)

        assert len(result) == 1
        orphan = result[0]
        assert orphan.name == "get_recording"
        # Without task data the step defaults to queued and not started.
        assert orphan.status == DagTaskStatus.QUEUED
        assert orphan.started_at is None

    def test_none_shape_returns_empty(self):
        details = _make_details(shape=[], tasks=[])
        details.shape = None  # simulate details that carry no shape at all
        assert extract_dag_tasks(details) == []
class TestDagStatusData:
    """JSON-mode serialization of the DagStatusData model."""

    def test_serialization(self):
        started = datetime(2025, 1, 1, tzinfo=timezone.utc)
        finished = datetime(2025, 1, 1, 0, 0, 1, tzinfo=timezone.utc)
        task = DagTask(
            name="get_recording",
            status=DagTaskStatus.COMPLETED,
            started_at=started,
            finished_at=finished,
            duration_seconds=1.0,
            parents=[],
            error=None,
            children_total=None,
            children_completed=None,
            progress_pct=None,
        )
        status_data = DagStatusData(workflow_run_id="test-123", tasks=[task])

        payload = status_data.model_dump(mode="json")

        assert payload["workflow_run_id"] == "test-123"
        assert len(payload["tasks"]) == 1
        first = payload["tasks"][0]
        assert first["name"] == "get_recording"
        # DagTaskStatus.COMPLETED is dumped as the string "completed".
        assert first["status"] == "completed"
        assert first["duration_seconds"] == 1.0
class AsyncContextManager:
    """Do-nothing async context manager used in place of fresh_db_connection.

    ``async with`` binds ``None``. ``__aexit__`` returns ``None``, so any
    exception raised inside the block propagates unchanged.
    """

    async def __aenter__(self):
        return None

    async def __aexit__(self, *exc_info):
        return None
class TestBroadcastDagStatus:
    """Tests for ``broadcast_dag_status``.

    ``broadcast_dag_status`` imports its collaborators inside its function
    body, so each test patches them at their source modules before calling
    it. Importing daily_multitrack_pipeline cascades into subject_processing,
    which calls HatchetClientManager.get_client at module level. The autouse
    fixture therefore installs a mock ``_instance`` before any such import,
    so the real SDK is never initialised.
    """

    @pytest.fixture(autouse=True)
    def _setup_hatchet_mock(self):
        """Swap HatchetClientManager._instance for a MagicMock during each test.

        Workflow modules call get_client() at import time. A pre-set
        ``_instance`` sidesteps ClientConfig validation. The previous value is
        restored after the test.
        """
        from reflector.hatchet.client import HatchetClientManager

        saved_instance = HatchetClientManager._instance
        HatchetClientManager._instance = MagicMock()
        yield
        HatchetClientManager._instance = saved_instance

    @staticmethod
    def _client_for(details) -> MagicMock:
        """Return a mock Hatchet client whose ``runs.aio_get`` resolves to *details*."""
        client = MagicMock()
        client.runs.aio_get = AsyncMock(return_value=details)
        return client

    @pytest.mark.asyncio
    async def test_broadcasts_dag_status(self):
        """broadcast_dag_status fetches run, transforms, and broadcasts."""
        transcript = MagicMock()
        transcript.id = "t-123"
        client = self._client_for(
            _make_details(
                shape=[_make_shape_item("s1", "get_recording")],
                tasks=[_make_task_summary("s1", status="COMPLETED")],
                run_id="wf-abc",
            )
        )

        with (
            patch(
                "reflector.hatchet.client.HatchetClientManager.get_client",
                return_value=client,
            ),
            patch(
                "reflector.hatchet.broadcast.append_event_and_broadcast",
                new_callable=AsyncMock,
            ) as broadcast_mock,
            patch(
                "reflector.db.transcripts.transcripts_controller.get_by_id",
                new_callable=AsyncMock,
                return_value=transcript,
            ),
            patch(
                "reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection",
                return_value=AsyncContextManager(),
            ),
        ):
            from reflector.hatchet.dag_progress import broadcast_dag_status

            await broadcast_dag_status("t-123", "wf-abc")

        client.runs.aio_get.assert_called_once_with("wf-abc")
        broadcast_mock.assert_called_once()
        args = broadcast_mock.call_args.args
        assert args[0] == "t-123"  # transcript_id
        assert args[1] is transcript  # transcript
        assert args[2] == "DAG_STATUS"  # event_name
        # The payload must be the Pydantic model, not a plain dict.
        payload = args[3]
        assert isinstance(payload, DagStatusData)
        assert payload.workflow_run_id == "wf-abc"
        assert len(payload.tasks) == 1

    @pytest.mark.asyncio
    async def test_swallows_exceptions(self):
        """broadcast_dag_status never raises even when internals fail."""
        from reflector.hatchet.dag_progress import broadcast_dag_status

        exploding_db = patch(
            "reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection",
            side_effect=RuntimeError("db exploded"),
        )
        with exploding_db:
            # Must return normally instead of propagating the RuntimeError.
            await broadcast_dag_status("t-123", "wf-abc")

    @pytest.mark.asyncio
    async def test_no_broadcast_when_transcript_not_found(self):
        """broadcast_dag_status does not broadcast if transcript is None."""
        client = self._client_for(
            _make_details(
                shape=[_make_shape_item("s1", "get_recording")],
                tasks=[_make_task_summary("s1", status="COMPLETED")],
            )
        )

        with (
            patch(
                "reflector.hatchet.client.HatchetClientManager.get_client",
                return_value=client,
            ),
            patch(
                "reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection",
                return_value=AsyncContextManager(),
            ),
            patch(
                "reflector.db.transcripts.transcripts_controller.get_by_id",
                new_callable=AsyncMock,
                return_value=None,
            ),
            patch(
                "reflector.hatchet.broadcast.append_event_and_broadcast",
                new_callable=AsyncMock,
            ) as broadcast_mock,
        ):
            from reflector.hatchet.dag_progress import broadcast_dag_status

            await broadcast_dag_status("t-123", "wf-abc")

        broadcast_mock.assert_not_called()
class TestMakeAudioProgressLoggerWithBroadcast:
    """Test make_audio_progress_logger with transcript_id for transient broadcasts."""

    @pytest.fixture(autouse=True)
    def _setup_hatchet_mock(self):
        """Set HatchetClientManager._instance to prevent real SDK init on import.

        A mock is installed only when no instance exists yet. The previous
        value is restored after the test.
        """
        from reflector.hatchet.client import HatchetClientManager

        original = HatchetClientManager._instance
        if original is None:
            HatchetClientManager._instance = MagicMock()
        yield
        HatchetClientManager._instance = original

    @staticmethod
    def _install_capturing_loop():
        """Create a fresh event loop, make it current, and record created tasks.

        Returns ``(loop, tasks_created, capture_create_task)``. Pass
        ``capture_create_task`` as the ``side_effect`` of a patched
        ``loop.create_task``: it delegates to the real ``create_task`` (bound
        before patching) and appends each task to ``tasks_created``, so
        fire-and-forget work can be awaited afterwards.
        """
        import asyncio

        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        tasks_created = []
        original_create_task = loop.create_task

        def capture_create_task(coro):
            task = original_create_task(coro)
            tasks_created.append(task)
            return task

        return loop, tasks_created, capture_create_task

    @staticmethod
    def _run_pending(loop, tasks_created) -> None:
        """Run every captured task to completion on *loop* (no-op when none)."""
        import asyncio

        if tasks_created:
            loop.run_until_complete(asyncio.gather(*tasks_created))

    @staticmethod
    def _close_loop(loop) -> None:
        """Close *loop* and unset it as this thread's current event loop.

        Only closing the loop would leave it installed via set_event_loop, so
        later code in the same thread that looks up the current loop would
        get a closed one. The loop is therefore also reset to ``None``.
        """
        import asyncio

        try:
            loop.close()
        finally:
            asyncio.set_event_loop(None)

    def test_broadcasts_transient_progress_event(self):
        """When transcript_id provided and progress_pct not None, broadcasts event."""
        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            make_audio_progress_logger,
        )

        ctx = MagicMock()
        ctx.log = MagicMock()
        loop, tasks_created, capture_create_task = self._install_capturing_loop()
        mock_broadcast = AsyncMock()
        try:
            with (
                patch(
                    "reflector.hatchet.broadcast.broadcast_event",
                    mock_broadcast,
                ),
                patch.object(loop, "create_task", side_effect=capture_create_task),
            ):
                callback = make_audio_progress_logger(
                    ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id="t-123"
                )
                callback(50.0, 100.0)
                # Let the fire-and-forget broadcast task actually run.
                self._run_pending(loop, tasks_created)

            mock_broadcast.assert_called_once()
            event_arg = mock_broadcast.call_args[0][1]
            assert event_arg.event == "DAG_TASK_PROGRESS"
            assert event_arg.data["task_name"] == TaskName.MIXDOWN_TRACKS
            assert event_arg.data["progress_pct"] == 50.0
        finally:
            self._close_loop(loop)

    def test_no_broadcast_without_transcript_id(self):
        """When transcript_id is None, no broadcast happens."""
        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            make_audio_progress_logger,
        )

        ctx = MagicMock()
        with patch(
            "reflector.hatchet.broadcast.broadcast_event",
            new_callable=AsyncMock,
        ) as mock_broadcast:
            callback = make_audio_progress_logger(
                ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id=None
            )
            callback(50.0, 100.0)
            mock_broadcast.assert_not_called()

    def test_no_broadcast_when_progress_pct_is_none(self):
        """When progress_pct is None, no broadcast happens even with transcript_id."""
        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            make_audio_progress_logger,
        )

        ctx = MagicMock()
        with patch(
            "reflector.hatchet.broadcast.broadcast_event",
            new_callable=AsyncMock,
        ) as mock_broadcast:
            callback = make_audio_progress_logger(
                ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id="t-123"
            )
            callback(None, 100.0)
            mock_broadcast.assert_not_called()

    def test_logging_throttled_by_interval(self):
        """With interval=5.0, rapid calls only log once until interval elapses.

        The throttle applies to ctx.log() calls. Broadcasts (fire-and-forget)
        are not throttled — they occur every call when transcript_id + progress_pct set.
        """
        import time as time_mod

        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            make_audio_progress_logger,
        )

        ctx = MagicMock()
        ctx.log = MagicMock()
        loop, tasks_created, capture_create_task = self._install_capturing_loop()
        mock_broadcast = AsyncMock()

        # Scripted clock for the 4 monotonic reads made by
        # make_audio_progress_logger: init (start_time, last_log_time), then
        # one "now" per callback. Later reads (asyncio internals) fall back to
        # the real clock.
        scripted_times = iter([100.0, 100.0, 101.0, 106.0])
        real_monotonic = time_mod.monotonic

        def mock_monotonic():
            value = next(scripted_times, None)
            return real_monotonic() if value is None else value

        try:
            with (
                patch(
                    "reflector.hatchet.workflows.daily_multitrack_pipeline.time.monotonic",
                    side_effect=mock_monotonic,
                ),
                patch(
                    "reflector.hatchet.broadcast.broadcast_event",
                    mock_broadcast,
                ),
                patch.object(loop, "create_task", side_effect=capture_create_task),
            ):
                callback = make_audio_progress_logger(
                    ctx, TaskName.MIXDOWN_TRACKS, interval=5.0, transcript_id="t-123"
                )
                # Call 1 at t=100.0: 100.0 - 100.0 = 0.0 < 5.0 => no log
                callback(25.0, 50.0)
                assert ctx.log.call_count == 0
                # Call 2 at t=101.0: 101.0 - 100.0 = 1.0 < 5.0 => no log
                callback(50.0, 100.0)
                assert ctx.log.call_count == 0
                # Call 3 at t=106.0: 106.0 - 100.0 = 6.0 >= 5.0 => logs
                callback(75.0, 150.0)
                assert ctx.log.call_count == 1
                # Run pending broadcast tasks
                self._run_pending(loop, tasks_created)

            # Broadcasts happen on every call (not throttled) — 3 calls total
            assert mock_broadcast.call_count == 3
        finally:
            self._close_loop(loop)

    def test_uses_broadcast_event_not_append_event_and_broadcast(self):
        """Progress events use broadcast_event (transient), not append_event_and_broadcast (persisted)."""
        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            make_audio_progress_logger,
        )

        ctx = MagicMock()
        ctx.log = MagicMock()
        loop, tasks_created, capture_create_task = self._install_capturing_loop()
        mock_broadcast_event = AsyncMock()
        mock_append = AsyncMock()
        try:
            with (
                patch(
                    "reflector.hatchet.broadcast.broadcast_event",
                    mock_broadcast_event,
                ),
                patch(
                    "reflector.hatchet.broadcast.append_event_and_broadcast",
                    mock_append,
                ),
                patch.object(loop, "create_task", side_effect=capture_create_task),
            ):
                callback = make_audio_progress_logger(
                    ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id="t-123"
                )
                callback(50.0, 100.0)
                self._run_pending(loop, tasks_created)

            # broadcast_event (transient) IS called
            mock_broadcast_event.assert_called_once()
            # append_event_and_broadcast (persisted) is NOT called
            mock_append.assert_not_called()
        finally:
            self._close_loop(loop)