mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-03-22 07:06:47 +00:00
feat: add broadcast_dag_status, decorator integration, and mixdown progress
- Add broadcast_dag_status() to dag_progress.py: fetches Hatchet run details, transforms to DagStatusData, and broadcasts DAG_STATUS event via WebSocket. Fire-and-forget with exception swallowing. - Modify with_error_handling decorator to call broadcast_dag_status on both task success and failure. - Add DAG_STATUS to USER_ROOM_EVENTS (broadcast.py) and reconnect filter (transcripts_websocket.py) to avoid replaying stale DAG state. - Add initial DAG broadcast at workflow dispatch (transcript_process.py). - Extend make_audio_progress_logger with optional transcript_id param for transient DAG_TASK_PROGRESS events during mixdown. - All deferred imports for fork-safety, all broadcasts fire-and-forget.
This commit is contained in:
@@ -15,7 +15,7 @@ from reflector.utils.string import NonEmptyString
|
||||
from reflector.ws_manager import get_ws_manager
|
||||
|
||||
# Events that should also be sent to user room (matches Celery behavior)
|
||||
USER_ROOM_EVENTS = {"STATUS", "FINAL_TITLE", "DURATION"}
|
||||
USER_ROOM_EVENTS = {"STATUS", "FINAL_TITLE", "DURATION", "DAG_STATUS"}
|
||||
|
||||
|
||||
async def broadcast_event(
|
||||
|
||||
@@ -187,3 +187,44 @@ def extract_dag_tasks(details: V1WorkflowRunDetails) -> list[DagTask]:
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def broadcast_dag_status(transcript_id: str, workflow_run_id: str) -> None:
    """Publish the current Hatchet DAG state of a workflow run over WebSocket.

    The run details are fetched from Hatchet, converted into a
    ``DagStatusData`` payload and emitted as a ``DAG_STATUS`` event for the
    transcript. Nothing is emitted when the transcript lookup returns a
    falsy value.

    Best-effort by design: any exception is logged as a warning and
    swallowed. Imports are deferred into the function body for fork-safety
    (Hatchet workers fork processes).

    Args:
        transcript_id: Id of the transcript the event is attached to.
        workflow_run_id: Hatchet workflow run whose state is broadcast.
    """
    try:
        from reflector.db.transcripts import transcripts_controller  # noqa: I001, PLC0415
        from reflector.hatchet.broadcast import append_event_and_broadcast  # noqa: PLC0415
        from reflector.hatchet.client import HatchetClientManager  # noqa: PLC0415
        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            fresh_db_connection,
        )  # noqa: PLC0415
        from reflector.logger import logger  # noqa: PLC0415

        async with fresh_db_connection():
            hatchet = HatchetClientManager.get_client()
            run_details = await hatchet.runs.aio_get(workflow_run_id)
            status = DagStatusData(
                workflow_run_id=workflow_run_id,
                tasks=extract_dag_tasks(run_details),
            )

            transcript = await transcripts_controller.get_by_id(transcript_id)
            if not transcript:
                # Unknown transcript: there is nobody to notify.
                return

            # Serialize only once we know the event will actually be sent.
            await append_event_and_broadcast(
                transcript_id,
                transcript,
                "DAG_STATUS",
                status.model_dump(mode="json"),
                logger,
            )
    except Exception:
        # Re-imported here because the failure may have happened before the
        # logger import inside the try block was reached.
        from reflector.logger import logger  # noqa: PLC0415

        logger.warning(
            "[DAG Progress] Failed to broadcast DAG status",
            transcript_id=transcript_id,
            workflow_run_id=workflow_run_id,
            exc_info=True,
        )
|
||||
|
||||
@@ -184,7 +184,10 @@ class Loggable(Protocol):
|
||||
|
||||
|
||||
def make_audio_progress_logger(
|
||||
ctx: Loggable, task_name: TaskName, interval: float = 5.0
|
||||
ctx: Loggable,
|
||||
task_name: TaskName,
|
||||
interval: float = 5.0,
|
||||
transcript_id: str | None = None,
|
||||
) -> Callable[[float | None, float], None]:
|
||||
"""Create a throttled progress logger callback for audio processing.
|
||||
|
||||
@@ -192,6 +195,7 @@ def make_audio_progress_logger(
|
||||
ctx: Object with .log() method (e.g., Hatchet Context).
|
||||
task_name: Name to prefix in log messages.
|
||||
interval: Minimum seconds between log messages.
|
||||
transcript_id: If provided, broadcasts transient DAG_TASK_PROGRESS events.
|
||||
|
||||
Returns:
|
||||
Callback(progress_pct, audio_position) that logs at most every `interval` seconds.
|
||||
@@ -213,6 +217,27 @@ def make_audio_progress_logger(
|
||||
)
|
||||
last_log_time[0] = now
|
||||
|
||||
if transcript_id and progress_pct is not None:
|
||||
try:
|
||||
import asyncio # noqa: PLC0415
|
||||
|
||||
from reflector.db.transcripts import TranscriptEvent # noqa: PLC0415
|
||||
from reflector.hatchet.broadcast import broadcast_event # noqa: PLC0415
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.create_task(
|
||||
broadcast_event(
|
||||
transcript_id,
|
||||
TranscriptEvent(
|
||||
event="DAG_TASK_PROGRESS",
|
||||
data={"task_name": task_name, "progress_pct": progress_pct},
|
||||
),
|
||||
logger=logger,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
pass # transient, never fail the callback
|
||||
|
||||
return callback
|
||||
|
||||
|
||||
@@ -237,8 +262,15 @@ def with_error_handling(
|
||||
) -> Callable[[PipelineInput, Context], Coroutine[Any, Any, R]]:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(input: PipelineInput, ctx: Context) -> R:
|
||||
from reflector.hatchet.dag_progress import broadcast_dag_status # noqa: I001, PLC0415
|
||||
|
||||
try:
|
||||
return await func(input, ctx)
|
||||
result = await func(input, ctx)
|
||||
try:
|
||||
await broadcast_dag_status(input.transcript_id, ctx.workflow_run_id)
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[Hatchet] {step_name} failed",
|
||||
@@ -246,6 +278,10 @@ def with_error_handling(
|
||||
error=str(e),
|
||||
exc_info=True,
|
||||
)
|
||||
try:
|
||||
await broadcast_dag_status(input.transcript_id, ctx.workflow_run_id)
|
||||
except Exception:
|
||||
pass
|
||||
if set_error_status:
|
||||
await set_workflow_error_status(input.transcript_id)
|
||||
raise
|
||||
@@ -560,7 +596,9 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
|
||||
target_sample_rate,
|
||||
offsets_seconds=None,
|
||||
logger=logger,
|
||||
progress_callback=make_audio_progress_logger(ctx, TaskName.MIXDOWN_TRACKS),
|
||||
progress_callback=make_audio_progress_logger(
|
||||
ctx, TaskName.MIXDOWN_TRACKS, transcript_id=input.transcript_id
|
||||
),
|
||||
expected_duration_sec=recording_duration if recording_duration > 0 else None,
|
||||
)
|
||||
await writer.flush()
|
||||
|
||||
@@ -267,6 +267,19 @@ async def dispatch_transcript_processing(
|
||||
)
|
||||
|
||||
logger.info("Hatchet workflow dispatched", workflow_id=workflow_id)
|
||||
|
||||
try:
|
||||
from reflector.hatchet.dag_progress import broadcast_dag_status # noqa: I001, PLC0415
|
||||
|
||||
await broadcast_dag_status(config.transcript_id, workflow_id)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"[DAG Progress] Failed initial broadcast",
|
||||
transcript_id=config.transcript_id,
|
||||
workflow_id=workflow_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
elif isinstance(config, FileProcessingConfig):
|
||||
|
||||
@@ -45,7 +45,7 @@ async def transcript_events_websocket(
|
||||
# for now, do not send TRANSCRIPT or STATUS events - these are live events
|
||||
# not necessary to be sent to the client; but keep the rest
|
||||
name = event.event
|
||||
if name in ("TRANSCRIPT", "STATUS"):
|
||||
if name in ("TRANSCRIPT", "STATUS", "DAG_STATUS"):
|
||||
continue
|
||||
await websocket.send_json(event.model_dump(mode="json"))
|
||||
|
||||
|
||||
@@ -5,8 +5,11 @@ into structured DagTask list for frontend consumption.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from reflector.hatchet.constants import TaskName
|
||||
from reflector.hatchet.dag_progress import (
|
||||
DagStatusData,
|
||||
DagTask,
|
||||
@@ -387,3 +390,234 @@ class TestDagStatusData:
|
||||
assert dumped["tasks"][0]["name"] == "get_recording"
|
||||
assert dumped["tasks"][0]["status"] == "completed"
|
||||
assert dumped["tasks"][0]["duration_seconds"] == 1.0
|
||||
|
||||
|
||||
class AsyncContextManager:
    """Do-nothing async context manager used to stand in for fresh_db_connection.

    Entering binds ``None`` and exiting returns ``None``, so an exception
    raised inside the ``async with`` body is not suppressed.
    """

    async def __aenter__(self):
        # No resource is acquired; ``async with ... as x`` binds None.
        return

    async def __aexit__(self, *_exc_details):
        # Falsy result: exceptions from the body propagate to the caller.
        return
|
||||
|
||||
|
||||
class TestBroadcastDagStatus:
    """Tests for broadcast_dag_status.

    The function under test resolves its collaborators through deferred
    imports in its own body, so each test patches the *source* modules
    before calling it.

    Importing daily_multitrack_pipeline triggers a cascade
    (subject_processing -> HatchetClientManager.get_client at module level),
    so HatchetClientManager._instance is pre-populated with a mock to keep
    the real SDK from initialising.
    """

    @pytest.fixture(autouse=True)
    def _setup_hatchet_mock(self):
        """Swap HatchetClientManager._instance for a mock during each test.

        Workflow modules call get_client() at import time; having an
        instance already in place avoids ClientConfig validation. The
        previous value is put back after the test.
        """
        from reflector.hatchet.client import HatchetClientManager

        saved_instance = HatchetClientManager._instance
        HatchetClientManager._instance = MagicMock()
        yield
        HatchetClientManager._instance = saved_instance

    @staticmethod
    def _client_returning(details):
        """Build a mock Hatchet client whose runs.aio_get resolves to *details*."""
        client = MagicMock()
        client.runs.aio_get = AsyncMock(return_value=details)
        return client

    @pytest.mark.asyncio
    async def test_broadcasts_dag_status(self):
        """broadcast_dag_status fetches run, transforms, and broadcasts."""
        transcript = MagicMock()
        transcript.id = "t-123"

        hatchet_client = self._client_returning(
            _make_details(
                shape=[_make_shape_item("s1", "get_recording")],
                tasks=[_make_task_summary("s1", status="COMPLETED")],
                run_id="wf-abc",
            )
        )

        with (
            patch(
                "reflector.hatchet.client.HatchetClientManager.get_client",
                return_value=hatchet_client,
            ),
            patch(
                "reflector.hatchet.broadcast.append_event_and_broadcast",
                new_callable=AsyncMock,
            ) as mock_broadcast,
            patch(
                "reflector.db.transcripts.transcripts_controller.get_by_id",
                new_callable=AsyncMock,
                return_value=transcript,
            ),
            patch(
                "reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection",
                return_value=AsyncContextManager(),
            ),
        ):
            from reflector.hatchet.dag_progress import broadcast_dag_status

            await broadcast_dag_status("t-123", "wf-abc")

        hatchet_client.runs.aio_get.assert_called_once_with("wf-abc")
        mock_broadcast.assert_called_once()

        # Arguments are passed positionally:
        # (transcript_id, transcript, event_name, data, logger)
        positional = mock_broadcast.call_args.args
        sent_id, sent_transcript, event_name, payload = positional[:4]
        assert sent_id == "t-123"
        assert sent_transcript is transcript
        assert event_name == "DAG_STATUS"
        assert payload["workflow_run_id"] == "wf-abc"
        assert len(payload["tasks"]) == 1

    @pytest.mark.asyncio
    async def test_swallows_exceptions(self):
        """broadcast_dag_status never raises even when internals fail."""
        from reflector.hatchet.dag_progress import broadcast_dag_status

        failing_connection = patch(
            "reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection",
            side_effect=RuntimeError("db exploded"),
        )
        with failing_connection:
            # Completing without an exception is the assertion.
            await broadcast_dag_status("t-123", "wf-abc")

    @pytest.mark.asyncio
    async def test_no_broadcast_when_transcript_not_found(self):
        """broadcast_dag_status does not broadcast if transcript is None."""
        hatchet_client = self._client_returning(
            _make_details(
                shape=[_make_shape_item("s1", "get_recording")],
                tasks=[_make_task_summary("s1", status="COMPLETED")],
            )
        )

        with (
            patch(
                "reflector.hatchet.client.HatchetClientManager.get_client",
                return_value=hatchet_client,
            ),
            patch(
                "reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection",
                return_value=AsyncContextManager(),
            ),
            patch(
                "reflector.db.transcripts.transcripts_controller.get_by_id",
                new_callable=AsyncMock,
                return_value=None,
            ),
            patch(
                "reflector.hatchet.broadcast.append_event_and_broadcast",
                new_callable=AsyncMock,
            ) as mock_broadcast,
        ):
            from reflector.hatchet.dag_progress import broadcast_dag_status

            await broadcast_dag_status("t-123", "wf-abc")

        mock_broadcast.assert_not_called()
|
||||
|
||||
|
||||
class TestMakeAudioProgressLoggerWithBroadcast:
    """Test make_audio_progress_logger with transcript_id for transient broadcasts."""

    @pytest.fixture(autouse=True)
    def _setup_hatchet_mock(self):
        """Set HatchetClientManager._instance to prevent real SDK init on import.

        A mock is installed only when no instance exists yet; whatever was
        there before the test is restored afterwards.
        """
        from reflector.hatchet.client import HatchetClientManager

        original = HatchetClientManager._instance
        if original is None:
            HatchetClientManager._instance = MagicMock()
        yield
        HatchetClientManager._instance = original

    def test_broadcasts_transient_progress_event(self):
        """When transcript_id provided and progress_pct not None, broadcasts event."""
        import asyncio

        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            make_audio_progress_logger,
        )

        ctx = MagicMock()
        ctx.log = MagicMock()

        # The callback is synchronous and schedules its broadcast on the
        # current event loop, so install a private loop for it to find.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        mock_broadcast = AsyncMock()
        tasks_created = []

        original_create_task = loop.create_task

        def capture_create_task(coro):
            # Delegate to the real create_task but remember the task so the
            # test can drive it to completion afterwards.
            task = original_create_task(coro)
            tasks_created.append(task)
            return task

        try:
            with (
                patch(
                    "reflector.hatchet.broadcast.broadcast_event",
                    mock_broadcast,
                ),
                patch.object(loop, "create_task", side_effect=capture_create_task),
            ):
                callback = make_audio_progress_logger(
                    ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id="t-123"
                )
                callback(50.0, 100.0)

                # Run pending tasks
                if tasks_created:
                    loop.run_until_complete(asyncio.gather(*tasks_created))

            mock_broadcast.assert_called_once()
            event_arg = mock_broadcast.call_args[0][1]
            assert event_arg.event == "DAG_TASK_PROGRESS"
            assert event_arg.data["task_name"] == TaskName.MIXDOWN_TRACKS
            assert event_arg.data["progress_pct"] == 50.0
        finally:
            # Uninstall the loop before closing it so a closed loop is never
            # left registered as the current event loop for later tests.
            asyncio.set_event_loop(None)
            loop.close()

    def test_no_broadcast_without_transcript_id(self):
        """When transcript_id is None, no broadcast happens."""
        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            make_audio_progress_logger,
        )

        ctx = MagicMock()

        with patch(
            "reflector.hatchet.broadcast.broadcast_event",
            new_callable=AsyncMock,
        ) as mock_broadcast:
            callback = make_audio_progress_logger(
                ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id=None
            )
            callback(50.0, 100.0)
            mock_broadcast.assert_not_called()

    def test_no_broadcast_when_progress_pct_is_none(self):
        """When progress_pct is None, no broadcast happens even with transcript_id."""
        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            make_audio_progress_logger,
        )

        ctx = MagicMock()

        with patch(
            "reflector.hatchet.broadcast.broadcast_event",
            new_callable=AsyncMock,
        ) as mock_broadcast:
            callback = make_audio_progress_logger(
                ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id="t-123"
            )
            callback(None, 100.0)
            mock_broadcast.assert_not_called()
|
||||
|
||||
Reference in New Issue
Block a user