feat: add broadcast_dag_status, decorator integration, and mixdown progress

- Add broadcast_dag_status() to dag_progress.py: fetches Hatchet run
  details, transforms to DagStatusData, and broadcasts DAG_STATUS event
  via WebSocket. Fire-and-forget with exception swallowing.
- Modify with_error_handling decorator to call broadcast_dag_status on
  both task success and failure.
- Add DAG_STATUS to USER_ROOM_EVENTS (broadcast.py) and reconnect
  filter (transcripts_websocket.py) to avoid replaying stale DAG state.
- Add initial DAG broadcast at workflow dispatch (transcript_process.py).
- Extend make_audio_progress_logger with optional transcript_id param
  for transient DAG_TASK_PROGRESS events during mixdown.
- All deferred imports for fork-safety, all broadcasts fire-and-forget.
This commit is contained in:
Igor Loskutov
2026-02-09 13:09:26 -05:00
parent a359c845ff
commit 4b79b0c989
6 changed files with 332 additions and 6 deletions

View File

@@ -15,7 +15,7 @@ from reflector.utils.string import NonEmptyString
from reflector.ws_manager import get_ws_manager
# Events that should also be sent to user room (matches Celery behavior)
USER_ROOM_EVENTS = {"STATUS", "FINAL_TITLE", "DURATION"}
USER_ROOM_EVENTS = {"STATUS", "FINAL_TITLE", "DURATION", "DAG_STATUS"}
async def broadcast_event(

View File

@@ -187,3 +187,44 @@ def extract_dag_tasks(details: V1WorkflowRunDetails) -> list[DagTask]:
)
return result
async def broadcast_dag_status(transcript_id: str, workflow_run_id: str) -> None:
"""Fetch current DAG state from Hatchet and broadcast via WebSocket.
Fire-and-forget: exceptions are logged but never raised.
All imports are deferred for fork-safety (Hatchet workers fork processes).
"""
try:
from reflector.db.transcripts import transcripts_controller # noqa: I001, PLC0415
from reflector.hatchet.broadcast import append_event_and_broadcast # noqa: PLC0415
from reflector.hatchet.client import HatchetClientManager # noqa: PLC0415
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
fresh_db_connection,
) # noqa: PLC0415
from reflector.logger import logger # noqa: PLC0415
async with fresh_db_connection():
client = HatchetClientManager.get_client()
details = await client.runs.aio_get(workflow_run_id)
dag_tasks = extract_dag_tasks(details)
dag_status = DagStatusData(workflow_run_id=workflow_run_id, tasks=dag_tasks)
transcript = await transcripts_controller.get_by_id(transcript_id)
if transcript:
await append_event_and_broadcast(
transcript_id,
transcript,
"DAG_STATUS",
dag_status.model_dump(mode="json"),
logger,
)
except Exception:
from reflector.logger import logger # noqa: PLC0415
logger.warning(
"[DAG Progress] Failed to broadcast DAG status",
transcript_id=transcript_id,
workflow_run_id=workflow_run_id,
exc_info=True,
)

View File

@@ -184,7 +184,10 @@ class Loggable(Protocol):
def make_audio_progress_logger(
ctx: Loggable, task_name: TaskName, interval: float = 5.0
ctx: Loggable,
task_name: TaskName,
interval: float = 5.0,
transcript_id: str | None = None,
) -> Callable[[float | None, float], None]:
"""Create a throttled progress logger callback for audio processing.
@@ -192,6 +195,7 @@ def make_audio_progress_logger(
ctx: Object with .log() method (e.g., Hatchet Context).
task_name: Name to prefix in log messages.
interval: Minimum seconds between log messages.
transcript_id: If provided, broadcasts transient DAG_TASK_PROGRESS events.
Returns:
Callback(progress_pct, audio_position) that logs at most every `interval` seconds.
@@ -213,6 +217,27 @@ def make_audio_progress_logger(
)
last_log_time[0] = now
if transcript_id and progress_pct is not None:
try:
import asyncio # noqa: PLC0415
from reflector.db.transcripts import TranscriptEvent # noqa: PLC0415
from reflector.hatchet.broadcast import broadcast_event # noqa: PLC0415
loop = asyncio.get_event_loop()
loop.create_task(
broadcast_event(
transcript_id,
TranscriptEvent(
event="DAG_TASK_PROGRESS",
data={"task_name": task_name, "progress_pct": progress_pct},
),
logger=logger,
)
)
except Exception:
pass # transient, never fail the callback
return callback
@@ -237,8 +262,15 @@ def with_error_handling(
) -> Callable[[PipelineInput, Context], Coroutine[Any, Any, R]]:
@functools.wraps(func)
async def wrapper(input: PipelineInput, ctx: Context) -> R:
from reflector.hatchet.dag_progress import broadcast_dag_status # noqa: I001, PLC0415
try:
return await func(input, ctx)
result = await func(input, ctx)
try:
await broadcast_dag_status(input.transcript_id, ctx.workflow_run_id)
except Exception:
pass
return result
except Exception as e:
logger.error(
f"[Hatchet] {step_name} failed",
@@ -246,6 +278,10 @@ def with_error_handling(
error=str(e),
exc_info=True,
)
try:
await broadcast_dag_status(input.transcript_id, ctx.workflow_run_id)
except Exception:
pass
if set_error_status:
await set_workflow_error_status(input.transcript_id)
raise
@@ -560,7 +596,9 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
target_sample_rate,
offsets_seconds=None,
logger=logger,
progress_callback=make_audio_progress_logger(ctx, TaskName.MIXDOWN_TRACKS),
progress_callback=make_audio_progress_logger(
ctx, TaskName.MIXDOWN_TRACKS, transcript_id=input.transcript_id
),
expected_duration_sec=recording_duration if recording_duration > 0 else None,
)
await writer.flush()

View File

@@ -267,6 +267,19 @@ async def dispatch_transcript_processing(
)
logger.info("Hatchet workflow dispatched", workflow_id=workflow_id)
try:
from reflector.hatchet.dag_progress import broadcast_dag_status # noqa: I001, PLC0415
await broadcast_dag_status(config.transcript_id, workflow_id)
except Exception:
logger.warning(
"[DAG Progress] Failed initial broadcast",
transcript_id=config.transcript_id,
workflow_id=workflow_id,
exc_info=True,
)
return None
elif isinstance(config, FileProcessingConfig):

View File

@@ -45,7 +45,7 @@ async def transcript_events_websocket(
# for now, do not send TRANSCRIPT or STATUS options - theses are live event
# not necessary to be sent to the client; but keep the rest
name = event.event
if name in ("TRANSCRIPT", "STATUS"):
if name in ("TRANSCRIPT", "STATUS", "DAG_STATUS"):
continue
await websocket.send_json(event.model_dump(mode="json"))

View File

@@ -5,8 +5,11 @@ into structured DagTask list for frontend consumption.
"""
from datetime import datetime, timezone
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from reflector.hatchet.constants import TaskName
from reflector.hatchet.dag_progress import (
DagStatusData,
DagTask,
@@ -387,3 +390,234 @@ class TestDagStatusData:
assert dumped["tasks"][0]["name"] == "get_recording"
assert dumped["tasks"][0]["status"] == "completed"
assert dumped["tasks"][0]["duration_seconds"] == 1.0
class AsyncContextManager:
"""No-op async context manager for mocking fresh_db_connection."""
async def __aenter__(self):
return None
async def __aexit__(self, *args):
return None
class TestBroadcastDagStatus:
"""Test broadcast_dag_status function.
broadcast_dag_status uses deferred imports inside its function body.
We mock the source modules/objects before calling the function.
Importing daily_multitrack_pipeline triggers a cascade
(subject_processing -> HatchetClientManager.get_client at module level),
so we set _instance before the import to prevent real SDK init.
"""
@pytest.fixture(autouse=True)
def _setup_hatchet_mock(self):
"""Set HatchetClientManager._instance to a mock to prevent real SDK init.
Module-level code in workflow files calls get_client() during import.
Setting _instance before import avoids ClientConfig validation.
"""
from reflector.hatchet.client import HatchetClientManager
original = HatchetClientManager._instance
HatchetClientManager._instance = MagicMock()
yield
HatchetClientManager._instance = original
@pytest.mark.asyncio
async def test_broadcasts_dag_status(self):
"""broadcast_dag_status fetches run, transforms, and broadcasts."""
mock_transcript = MagicMock()
mock_transcript.id = "t-123"
mock_details = _make_details(
shape=[_make_shape_item("s1", "get_recording")],
tasks=[_make_task_summary("s1", status="COMPLETED")],
run_id="wf-abc",
)
mock_client = MagicMock()
mock_client.runs.aio_get = AsyncMock(return_value=mock_details)
with (
patch(
"reflector.hatchet.client.HatchetClientManager.get_client",
return_value=mock_client,
),
patch(
"reflector.hatchet.broadcast.append_event_and_broadcast",
new_callable=AsyncMock,
) as mock_broadcast,
patch(
"reflector.db.transcripts.transcripts_controller.get_by_id",
new_callable=AsyncMock,
return_value=mock_transcript,
),
patch(
"reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection",
return_value=AsyncContextManager(),
),
):
from reflector.hatchet.dag_progress import broadcast_dag_status
await broadcast_dag_status("t-123", "wf-abc")
mock_client.runs.aio_get.assert_called_once_with("wf-abc")
mock_broadcast.assert_called_once()
call_args = mock_broadcast.call_args
assert call_args[0][0] == "t-123" # transcript_id
assert call_args[0][1] is mock_transcript # transcript
assert call_args[0][2] == "DAG_STATUS" # event_name
data = call_args[0][3]
assert data["workflow_run_id"] == "wf-abc"
assert len(data["tasks"]) == 1
@pytest.mark.asyncio
async def test_swallows_exceptions(self):
"""broadcast_dag_status never raises even when internals fail."""
from reflector.hatchet.dag_progress import broadcast_dag_status
with patch(
"reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection",
side_effect=RuntimeError("db exploded"),
):
# Should not raise
await broadcast_dag_status("t-123", "wf-abc")
@pytest.mark.asyncio
async def test_no_broadcast_when_transcript_not_found(self):
"""broadcast_dag_status does not broadcast if transcript is None."""
mock_details = _make_details(
shape=[_make_shape_item("s1", "get_recording")],
tasks=[_make_task_summary("s1", status="COMPLETED")],
)
mock_client = MagicMock()
mock_client.runs.aio_get = AsyncMock(return_value=mock_details)
with (
patch(
"reflector.hatchet.client.HatchetClientManager.get_client",
return_value=mock_client,
),
patch(
"reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection",
return_value=AsyncContextManager(),
),
patch(
"reflector.db.transcripts.transcripts_controller.get_by_id",
new_callable=AsyncMock,
return_value=None,
),
patch(
"reflector.hatchet.broadcast.append_event_and_broadcast",
new_callable=AsyncMock,
) as mock_broadcast,
):
from reflector.hatchet.dag_progress import broadcast_dag_status
await broadcast_dag_status("t-123", "wf-abc")
mock_broadcast.assert_not_called()
class TestMakeAudioProgressLoggerWithBroadcast:
"""Test make_audio_progress_logger with transcript_id for transient broadcasts."""
@pytest.fixture(autouse=True)
def _setup_hatchet_mock(self):
"""Set HatchetClientManager._instance to prevent real SDK init on import."""
from reflector.hatchet.client import HatchetClientManager
original = HatchetClientManager._instance
if original is None:
HatchetClientManager._instance = MagicMock()
yield
HatchetClientManager._instance = original
def test_broadcasts_transient_progress_event(self):
"""When transcript_id provided and progress_pct not None, broadcasts event."""
import asyncio
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
make_audio_progress_logger,
)
ctx = MagicMock()
ctx.log = MagicMock()
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
mock_broadcast = AsyncMock()
tasks_created = []
original_create_task = loop.create_task
def capture_create_task(coro):
task = original_create_task(coro)
tasks_created.append(task)
return task
try:
with (
patch(
"reflector.hatchet.broadcast.broadcast_event",
mock_broadcast,
),
patch.object(loop, "create_task", side_effect=capture_create_task),
):
callback = make_audio_progress_logger(
ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id="t-123"
)
callback(50.0, 100.0)
# Run pending tasks
if tasks_created:
loop.run_until_complete(asyncio.gather(*tasks_created))
mock_broadcast.assert_called_once()
event_arg = mock_broadcast.call_args[0][1]
assert event_arg.event == "DAG_TASK_PROGRESS"
assert event_arg.data["task_name"] == TaskName.MIXDOWN_TRACKS
assert event_arg.data["progress_pct"] == 50.0
finally:
loop.close()
def test_no_broadcast_without_transcript_id(self):
"""When transcript_id is None, no broadcast happens."""
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
make_audio_progress_logger,
)
ctx = MagicMock()
with patch(
"reflector.hatchet.broadcast.broadcast_event",
new_callable=AsyncMock,
) as mock_broadcast:
callback = make_audio_progress_logger(
ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id=None
)
callback(50.0, 100.0)
mock_broadcast.assert_not_called()
def test_no_broadcast_when_progress_pct_is_none(self):
"""When progress_pct is None, no broadcast happens even with transcript_id."""
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
make_audio_progress_logger,
)
ctx = MagicMock()
with patch(
"reflector.hatchet.broadcast.broadcast_event",
new_callable=AsyncMock,
) as mock_broadcast:
callback = make_audio_progress_logger(
ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id="t-123"
)
callback(None, 100.0)
mock_broadcast.assert_not_called()