mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-04-18 11:16:55 +00:00
feat: add broadcast_dag_status, decorator integration, and mixdown progress
- Add broadcast_dag_status() to dag_progress.py: fetches Hatchet run details, transforms to DagStatusData, and broadcasts DAG_STATUS event via WebSocket. Fire-and-forget with exception swallowing. - Modify with_error_handling decorator to call broadcast_dag_status on both task success and failure. - Add DAG_STATUS to USER_ROOM_EVENTS (broadcast.py) and reconnect filter (transcripts_websocket.py) to avoid replaying stale DAG state. - Add initial DAG broadcast at workflow dispatch (transcript_process.py). - Extend make_audio_progress_logger with optional transcript_id param for transient DAG_TASK_PROGRESS events during mixdown. - All deferred imports for fork-safety, all broadcasts fire-and-forget.
This commit is contained in:
@@ -5,8 +5,11 @@ into structured DagTask list for frontend consumption.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from reflector.hatchet.constants import TaskName
|
||||
from reflector.hatchet.dag_progress import (
|
||||
DagStatusData,
|
||||
DagTask,
|
||||
@@ -387,3 +390,234 @@ class TestDagStatusData:
|
||||
assert dumped["tasks"][0]["name"] == "get_recording"
|
||||
assert dumped["tasks"][0]["status"] == "completed"
|
||||
assert dumped["tasks"][0]["duration_seconds"] == 1.0
|
||||
|
||||
|
||||
class AsyncContextManager:
|
||||
"""No-op async context manager for mocking fresh_db_connection."""
|
||||
|
||||
async def __aenter__(self):
|
||||
return None
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
return None
|
||||
|
||||
|
||||
class TestBroadcastDagStatus:
|
||||
"""Test broadcast_dag_status function.
|
||||
|
||||
broadcast_dag_status uses deferred imports inside its function body.
|
||||
We mock the source modules/objects before calling the function.
|
||||
Importing daily_multitrack_pipeline triggers a cascade
|
||||
(subject_processing -> HatchetClientManager.get_client at module level),
|
||||
so we set _instance before the import to prevent real SDK init.
|
||||
"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _setup_hatchet_mock(self):
|
||||
"""Set HatchetClientManager._instance to a mock to prevent real SDK init.
|
||||
|
||||
Module-level code in workflow files calls get_client() during import.
|
||||
Setting _instance before import avoids ClientConfig validation.
|
||||
"""
|
||||
from reflector.hatchet.client import HatchetClientManager
|
||||
|
||||
original = HatchetClientManager._instance
|
||||
HatchetClientManager._instance = MagicMock()
|
||||
yield
|
||||
HatchetClientManager._instance = original
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcasts_dag_status(self):
|
||||
"""broadcast_dag_status fetches run, transforms, and broadcasts."""
|
||||
mock_transcript = MagicMock()
|
||||
mock_transcript.id = "t-123"
|
||||
|
||||
mock_details = _make_details(
|
||||
shape=[_make_shape_item("s1", "get_recording")],
|
||||
tasks=[_make_task_summary("s1", status="COMPLETED")],
|
||||
run_id="wf-abc",
|
||||
)
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.runs.aio_get = AsyncMock(return_value=mock_details)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"reflector.hatchet.client.HatchetClientManager.get_client",
|
||||
return_value=mock_client,
|
||||
),
|
||||
patch(
|
||||
"reflector.hatchet.broadcast.append_event_and_broadcast",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_broadcast,
|
||||
patch(
|
||||
"reflector.db.transcripts.transcripts_controller.get_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_transcript,
|
||||
),
|
||||
patch(
|
||||
"reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection",
|
||||
return_value=AsyncContextManager(),
|
||||
),
|
||||
):
|
||||
from reflector.hatchet.dag_progress import broadcast_dag_status
|
||||
|
||||
await broadcast_dag_status("t-123", "wf-abc")
|
||||
|
||||
mock_client.runs.aio_get.assert_called_once_with("wf-abc")
|
||||
mock_broadcast.assert_called_once()
|
||||
call_args = mock_broadcast.call_args
|
||||
assert call_args[0][0] == "t-123" # transcript_id
|
||||
assert call_args[0][1] is mock_transcript # transcript
|
||||
assert call_args[0][2] == "DAG_STATUS" # event_name
|
||||
data = call_args[0][3]
|
||||
assert data["workflow_run_id"] == "wf-abc"
|
||||
assert len(data["tasks"]) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_swallows_exceptions(self):
|
||||
"""broadcast_dag_status never raises even when internals fail."""
|
||||
from reflector.hatchet.dag_progress import broadcast_dag_status
|
||||
|
||||
with patch(
|
||||
"reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection",
|
||||
side_effect=RuntimeError("db exploded"),
|
||||
):
|
||||
# Should not raise
|
||||
await broadcast_dag_status("t-123", "wf-abc")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_broadcast_when_transcript_not_found(self):
|
||||
"""broadcast_dag_status does not broadcast if transcript is None."""
|
||||
mock_details = _make_details(
|
||||
shape=[_make_shape_item("s1", "get_recording")],
|
||||
tasks=[_make_task_summary("s1", status="COMPLETED")],
|
||||
)
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.runs.aio_get = AsyncMock(return_value=mock_details)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"reflector.hatchet.client.HatchetClientManager.get_client",
|
||||
return_value=mock_client,
|
||||
),
|
||||
patch(
|
||||
"reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection",
|
||||
return_value=AsyncContextManager(),
|
||||
),
|
||||
patch(
|
||||
"reflector.db.transcripts.transcripts_controller.get_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"reflector.hatchet.broadcast.append_event_and_broadcast",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_broadcast,
|
||||
):
|
||||
from reflector.hatchet.dag_progress import broadcast_dag_status
|
||||
|
||||
await broadcast_dag_status("t-123", "wf-abc")
|
||||
|
||||
mock_broadcast.assert_not_called()
|
||||
|
||||
|
||||
class TestMakeAudioProgressLoggerWithBroadcast:
|
||||
"""Test make_audio_progress_logger with transcript_id for transient broadcasts."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _setup_hatchet_mock(self):
|
||||
"""Set HatchetClientManager._instance to prevent real SDK init on import."""
|
||||
from reflector.hatchet.client import HatchetClientManager
|
||||
|
||||
original = HatchetClientManager._instance
|
||||
if original is None:
|
||||
HatchetClientManager._instance = MagicMock()
|
||||
yield
|
||||
HatchetClientManager._instance = original
|
||||
|
||||
def test_broadcasts_transient_progress_event(self):
|
||||
"""When transcript_id provided and progress_pct not None, broadcasts event."""
|
||||
import asyncio
|
||||
|
||||
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||
make_audio_progress_logger,
|
||||
)
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.log = MagicMock()
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
mock_broadcast = AsyncMock()
|
||||
tasks_created = []
|
||||
|
||||
original_create_task = loop.create_task
|
||||
|
||||
def capture_create_task(coro):
|
||||
task = original_create_task(coro)
|
||||
tasks_created.append(task)
|
||||
return task
|
||||
|
||||
try:
|
||||
with (
|
||||
patch(
|
||||
"reflector.hatchet.broadcast.broadcast_event",
|
||||
mock_broadcast,
|
||||
),
|
||||
patch.object(loop, "create_task", side_effect=capture_create_task),
|
||||
):
|
||||
callback = make_audio_progress_logger(
|
||||
ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id="t-123"
|
||||
)
|
||||
callback(50.0, 100.0)
|
||||
|
||||
# Run pending tasks
|
||||
if tasks_created:
|
||||
loop.run_until_complete(asyncio.gather(*tasks_created))
|
||||
|
||||
mock_broadcast.assert_called_once()
|
||||
event_arg = mock_broadcast.call_args[0][1]
|
||||
assert event_arg.event == "DAG_TASK_PROGRESS"
|
||||
assert event_arg.data["task_name"] == TaskName.MIXDOWN_TRACKS
|
||||
assert event_arg.data["progress_pct"] == 50.0
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
def test_no_broadcast_without_transcript_id(self):
|
||||
"""When transcript_id is None, no broadcast happens."""
|
||||
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||
make_audio_progress_logger,
|
||||
)
|
||||
|
||||
ctx = MagicMock()
|
||||
|
||||
with patch(
|
||||
"reflector.hatchet.broadcast.broadcast_event",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_broadcast:
|
||||
callback = make_audio_progress_logger(
|
||||
ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id=None
|
||||
)
|
||||
callback(50.0, 100.0)
|
||||
mock_broadcast.assert_not_called()
|
||||
|
||||
def test_no_broadcast_when_progress_pct_is_none(self):
|
||||
"""When progress_pct is None, no broadcast happens even with transcript_id."""
|
||||
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||
make_audio_progress_logger,
|
||||
)
|
||||
|
||||
ctx = MagicMock()
|
||||
|
||||
with patch(
|
||||
"reflector.hatchet.broadcast.broadcast_event",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_broadcast:
|
||||
callback = make_audio_progress_logger(
|
||||
ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id="t-123"
|
||||
)
|
||||
callback(None, 100.0)
|
||||
mock_broadcast.assert_not_called()
|
||||
|
||||
Reference in New Issue
Block a user