feat: add broadcast_dag_status, decorator integration, and mixdown progress

- Add broadcast_dag_status() to dag_progress.py: fetches Hatchet run
  details, transforms to DagStatusData, and broadcasts DAG_STATUS event
  via WebSocket. Fire-and-forget with exception swallowing.
- Modify with_error_handling decorator to call broadcast_dag_status on
  both task success and failure.
- Add DAG_STATUS to USER_ROOM_EVENTS (broadcast.py) and reconnect
  filter (transcripts_websocket.py) to avoid replaying stale DAG state.
- Add initial DAG broadcast at workflow dispatch (transcript_process.py).
- Extend make_audio_progress_logger with optional transcript_id param
  for transient DAG_TASK_PROGRESS events during mixdown.
- All imports are deferred for fork-safety; all broadcasts are fire-and-forget.
This commit is contained in:
Igor Loskutov
2026-02-09 13:09:26 -05:00
parent a359c845ff
commit 4b79b0c989
6 changed files with 332 additions and 6 deletions

View File

@@ -5,8 +5,11 @@ into structured DagTask list for frontend consumption.
"""
from datetime import datetime, timezone
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from reflector.hatchet.constants import TaskName
from reflector.hatchet.dag_progress import (
DagStatusData,
DagTask,
@@ -387,3 +390,234 @@ class TestDagStatusData:
assert dumped["tasks"][0]["name"] == "get_recording"
assert dumped["tasks"][0]["status"] == "completed"
assert dumped["tasks"][0]["duration_seconds"] == 1.0
class AsyncContextManager:
    """Do-nothing async context manager used to mock fresh_db_connection."""

    async def __aenter__(self):
        """Enter without acquiring anything; the ``as`` target is None."""

    async def __aexit__(self, *exc_info):
        """Exit without cleanup; the None result does not suppress exceptions."""
class TestBroadcastDagStatus:
    """Tests for broadcast_dag_status.

    The function under test performs its imports inside its own body, so every
    test patches the source modules/objects before invoking it.

    Importing daily_multitrack_pipeline triggers a cascade
    (subject_processing -> HatchetClientManager.get_client at module level);
    the autouse fixture sets _instance beforehand to prevent real SDK init.
    """

    # Patch targets shared by the tests below.
    _GET_CLIENT = "reflector.hatchet.client.HatchetClientManager.get_client"
    _BROADCAST = "reflector.hatchet.broadcast.append_event_and_broadcast"
    _GET_BY_ID = "reflector.db.transcripts.transcripts_controller.get_by_id"
    _FRESH_DB = (
        "reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection"
    )

    @staticmethod
    def _single_task_details(**extra):
        """Build run details holding one completed get_recording task."""
        return _make_details(
            shape=[_make_shape_item("s1", "get_recording")],
            tasks=[_make_task_summary("s1", status="COMPLETED")],
            **extra,
        )

    @pytest.fixture(autouse=True)
    def _setup_hatchet_mock(self):
        """Replace HatchetClientManager._instance with a mock around each test.

        Workflow modules call get_client() at import time; having _instance
        set before the import avoids ClientConfig validation. The previous
        value is put back once the test has finished.
        """
        from reflector.hatchet.client import HatchetClientManager

        saved_instance = HatchetClientManager._instance
        HatchetClientManager._instance = MagicMock()
        yield
        HatchetClientManager._instance = saved_instance

    @pytest.mark.asyncio
    async def test_broadcasts_dag_status(self):
        """Run details are fetched, transformed, and broadcast as DAG_STATUS."""
        transcript = MagicMock()
        transcript.id = "t-123"
        hatchet = MagicMock()
        hatchet.runs.aio_get = AsyncMock(
            return_value=self._single_task_details(run_id="wf-abc")
        )

        with (
            patch(self._GET_CLIENT, return_value=hatchet),
            patch(self._BROADCAST, new_callable=AsyncMock) as broadcast,
            patch(self._GET_BY_ID, new_callable=AsyncMock, return_value=transcript),
            patch(self._FRESH_DB, return_value=AsyncContextManager()),
        ):
            from reflector.hatchet.dag_progress import broadcast_dag_status

            await broadcast_dag_status("t-123", "wf-abc")

        hatchet.runs.aio_get.assert_called_once_with("wf-abc")
        broadcast.assert_called_once()

        positional = broadcast.call_args[0]
        assert positional[0] == "t-123"  # transcript_id
        assert positional[1] is transcript  # transcript
        assert positional[2] == "DAG_STATUS"  # event_name
        payload = positional[3]
        assert payload["workflow_run_id"] == "wf-abc"
        assert len(payload["tasks"]) == 1

    @pytest.mark.asyncio
    async def test_swallows_exceptions(self):
        """A failure inside broadcast_dag_status must not reach the caller."""
        from reflector.hatchet.dag_progress import broadcast_dag_status

        db_failure = patch(self._FRESH_DB, side_effect=RuntimeError("db exploded"))
        with db_failure:
            # Completing without an exception is the whole assertion.
            await broadcast_dag_status("t-123", "wf-abc")

    @pytest.mark.asyncio
    async def test_no_broadcast_when_transcript_not_found(self):
        """Nothing is broadcast when the transcript lookup returns None."""
        hatchet = MagicMock()
        hatchet.runs.aio_get = AsyncMock(return_value=self._single_task_details())

        with (
            patch(self._GET_CLIENT, return_value=hatchet),
            patch(self._FRESH_DB, return_value=AsyncContextManager()),
            patch(self._GET_BY_ID, new_callable=AsyncMock, return_value=None),
            patch(self._BROADCAST, new_callable=AsyncMock) as broadcast,
        ):
            from reflector.hatchet.dag_progress import broadcast_dag_status

            await broadcast_dag_status("t-123", "wf-abc")

        broadcast.assert_not_called()
class TestMakeAudioProgressLoggerWithBroadcast:
    """Test make_audio_progress_logger with transcript_id for transient broadcasts."""

    @pytest.fixture(autouse=True)
    def _setup_hatchet_mock(self):
        """Set HatchetClientManager._instance to prevent real SDK init on import.

        A mock is installed only when no instance is set yet; whatever value
        was present before the test is put back afterwards.
        """
        from reflector.hatchet.client import HatchetClientManager

        original = HatchetClientManager._instance
        if original is None:
            HatchetClientManager._instance = MagicMock()
        yield
        HatchetClientManager._instance = original

    def test_broadcasts_transient_progress_event(self):
        """When transcript_id provided and progress_pct not None, broadcasts event."""
        import asyncio

        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            make_audio_progress_logger,
        )

        ctx = MagicMock()
        ctx.log = MagicMock()

        # The callback is synchronous, so the test owns a private loop on which
        # it can run whatever tasks the callback schedules.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        mock_broadcast = AsyncMock()
        tasks_created = []
        original_create_task = loop.create_task

        def capture_create_task(coro):
            # Schedule for real, but remember the task so it can be awaited.
            task = original_create_task(coro)
            tasks_created.append(task)
            return task

        try:
            with (
                patch(
                    "reflector.hatchet.broadcast.broadcast_event",
                    mock_broadcast,
                ),
                patch.object(loop, "create_task", side_effect=capture_create_task),
            ):
                callback = make_audio_progress_logger(
                    ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id="t-123"
                )
                callback(50.0, 100.0)

                # Run pending tasks while broadcast_event is still patched.
                if tasks_created:
                    loop.run_until_complete(asyncio.gather(*tasks_created))

            mock_broadcast.assert_called_once()
            event_arg = mock_broadcast.call_args[0][1]
            assert event_arg.event == "DAG_TASK_PROGRESS"
            assert event_arg.data["task_name"] == TaskName.MIXDOWN_TRACKS
            assert event_arg.data["progress_pct"] == 50.0
        finally:
            # Unregister the loop before closing it; otherwise a closed loop
            # stays installed as the current event loop and leaks into
            # whichever test runs next.
            asyncio.set_event_loop(None)
            loop.close()

    def test_no_broadcast_without_transcript_id(self):
        """When transcript_id is None, no broadcast happens."""
        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            make_audio_progress_logger,
        )

        ctx = MagicMock()

        with patch(
            "reflector.hatchet.broadcast.broadcast_event",
            new_callable=AsyncMock,
        ) as mock_broadcast:
            callback = make_audio_progress_logger(
                ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id=None
            )
            callback(50.0, 100.0)

        mock_broadcast.assert_not_called()

    def test_no_broadcast_when_progress_pct_is_none(self):
        """When progress_pct is None, no broadcast happens even with transcript_id."""
        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            make_audio_progress_logger,
        )

        ctx = MagicMock()

        with patch(
            "reflector.hatchet.broadcast.broadcast_event",
            new_callable=AsyncMock,
        ) as mock_broadcast:
            callback = make_audio_progress_logger(
                ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id="t-123"
            )
            callback(None, 100.0)

        mock_broadcast.assert_not_called()