mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-04-26 07:05:19 +00:00
feat: add broadcast_dag_status, decorator integration, and mixdown progress
- Add broadcast_dag_status() to dag_progress.py: fetches Hatchet run details, transforms to DagStatusData, and broadcasts DAG_STATUS event via WebSocket. Fire-and-forget with exception swallowing. - Modify with_error_handling decorator to call broadcast_dag_status on both task success and failure. - Add DAG_STATUS to USER_ROOM_EVENTS (broadcast.py) and reconnect filter (transcripts_websocket.py) to avoid replaying stale DAG state. - Add initial DAG broadcast at workflow dispatch (transcript_process.py). - Extend make_audio_progress_logger with optional transcript_id param for transient DAG_TASK_PROGRESS events during mixdown. - All deferred imports for fork-safety, all broadcasts fire-and-forget.
This commit is contained in:
@@ -15,7 +15,7 @@ from reflector.utils.string import NonEmptyString
|
|||||||
from reflector.ws_manager import get_ws_manager
|
from reflector.ws_manager import get_ws_manager
|
||||||
|
|
||||||
# Events that should also be sent to user room (matches Celery behavior)
|
# Events that should also be sent to user room (matches Celery behavior)
|
||||||
USER_ROOM_EVENTS = {"STATUS", "FINAL_TITLE", "DURATION"}
|
USER_ROOM_EVENTS = {"STATUS", "FINAL_TITLE", "DURATION", "DAG_STATUS"}
|
||||||
|
|
||||||
|
|
||||||
async def broadcast_event(
|
async def broadcast_event(
|
||||||
|
|||||||
@@ -187,3 +187,44 @@ def extract_dag_tasks(details: V1WorkflowRunDetails) -> list[DagTask]:
|
|||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
async def broadcast_dag_status(transcript_id: str, workflow_run_id: str) -> None:
|
||||||
|
"""Fetch current DAG state from Hatchet and broadcast via WebSocket.
|
||||||
|
|
||||||
|
Fire-and-forget: exceptions are logged but never raised.
|
||||||
|
All imports are deferred for fork-safety (Hatchet workers fork processes).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from reflector.db.transcripts import transcripts_controller # noqa: I001, PLC0415
|
||||||
|
from reflector.hatchet.broadcast import append_event_and_broadcast # noqa: PLC0415
|
||||||
|
from reflector.hatchet.client import HatchetClientManager # noqa: PLC0415
|
||||||
|
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||||
|
fresh_db_connection,
|
||||||
|
) # noqa: PLC0415
|
||||||
|
from reflector.logger import logger # noqa: PLC0415
|
||||||
|
|
||||||
|
async with fresh_db_connection():
|
||||||
|
client = HatchetClientManager.get_client()
|
||||||
|
details = await client.runs.aio_get(workflow_run_id)
|
||||||
|
dag_tasks = extract_dag_tasks(details)
|
||||||
|
dag_status = DagStatusData(workflow_run_id=workflow_run_id, tasks=dag_tasks)
|
||||||
|
|
||||||
|
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||||
|
if transcript:
|
||||||
|
await append_event_and_broadcast(
|
||||||
|
transcript_id,
|
||||||
|
transcript,
|
||||||
|
"DAG_STATUS",
|
||||||
|
dag_status.model_dump(mode="json"),
|
||||||
|
logger,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
from reflector.logger import logger # noqa: PLC0415
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
"[DAG Progress] Failed to broadcast DAG status",
|
||||||
|
transcript_id=transcript_id,
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|||||||
@@ -184,7 +184,10 @@ class Loggable(Protocol):
|
|||||||
|
|
||||||
|
|
||||||
def make_audio_progress_logger(
|
def make_audio_progress_logger(
|
||||||
ctx: Loggable, task_name: TaskName, interval: float = 5.0
|
ctx: Loggable,
|
||||||
|
task_name: TaskName,
|
||||||
|
interval: float = 5.0,
|
||||||
|
transcript_id: str | None = None,
|
||||||
) -> Callable[[float | None, float], None]:
|
) -> Callable[[float | None, float], None]:
|
||||||
"""Create a throttled progress logger callback for audio processing.
|
"""Create a throttled progress logger callback for audio processing.
|
||||||
|
|
||||||
@@ -192,6 +195,7 @@ def make_audio_progress_logger(
|
|||||||
ctx: Object with .log() method (e.g., Hatchet Context).
|
ctx: Object with .log() method (e.g., Hatchet Context).
|
||||||
task_name: Name to prefix in log messages.
|
task_name: Name to prefix in log messages.
|
||||||
interval: Minimum seconds between log messages.
|
interval: Minimum seconds between log messages.
|
||||||
|
transcript_id: If provided, broadcasts transient DAG_TASK_PROGRESS events.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Callback(progress_pct, audio_position) that logs at most every `interval` seconds.
|
Callback(progress_pct, audio_position) that logs at most every `interval` seconds.
|
||||||
@@ -213,6 +217,27 @@ def make_audio_progress_logger(
|
|||||||
)
|
)
|
||||||
last_log_time[0] = now
|
last_log_time[0] = now
|
||||||
|
|
||||||
|
if transcript_id and progress_pct is not None:
|
||||||
|
try:
|
||||||
|
import asyncio # noqa: PLC0415
|
||||||
|
|
||||||
|
from reflector.db.transcripts import TranscriptEvent # noqa: PLC0415
|
||||||
|
from reflector.hatchet.broadcast import broadcast_event # noqa: PLC0415
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
loop.create_task(
|
||||||
|
broadcast_event(
|
||||||
|
transcript_id,
|
||||||
|
TranscriptEvent(
|
||||||
|
event="DAG_TASK_PROGRESS",
|
||||||
|
data={"task_name": task_name, "progress_pct": progress_pct},
|
||||||
|
),
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # transient, never fail the callback
|
||||||
|
|
||||||
return callback
|
return callback
|
||||||
|
|
||||||
|
|
||||||
@@ -237,8 +262,15 @@ def with_error_handling(
|
|||||||
) -> Callable[[PipelineInput, Context], Coroutine[Any, Any, R]]:
|
) -> Callable[[PipelineInput, Context], Coroutine[Any, Any, R]]:
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
async def wrapper(input: PipelineInput, ctx: Context) -> R:
|
async def wrapper(input: PipelineInput, ctx: Context) -> R:
|
||||||
|
from reflector.hatchet.dag_progress import broadcast_dag_status # noqa: I001, PLC0415
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return await func(input, ctx)
|
result = await func(input, ctx)
|
||||||
|
try:
|
||||||
|
await broadcast_dag_status(input.transcript_id, ctx.workflow_run_id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[Hatchet] {step_name} failed",
|
f"[Hatchet] {step_name} failed",
|
||||||
@@ -246,6 +278,10 @@ def with_error_handling(
|
|||||||
error=str(e),
|
error=str(e),
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
|
try:
|
||||||
|
await broadcast_dag_status(input.transcript_id, ctx.workflow_run_id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
if set_error_status:
|
if set_error_status:
|
||||||
await set_workflow_error_status(input.transcript_id)
|
await set_workflow_error_status(input.transcript_id)
|
||||||
raise
|
raise
|
||||||
@@ -560,7 +596,9 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
|
|||||||
target_sample_rate,
|
target_sample_rate,
|
||||||
offsets_seconds=None,
|
offsets_seconds=None,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
progress_callback=make_audio_progress_logger(ctx, TaskName.MIXDOWN_TRACKS),
|
progress_callback=make_audio_progress_logger(
|
||||||
|
ctx, TaskName.MIXDOWN_TRACKS, transcript_id=input.transcript_id
|
||||||
|
),
|
||||||
expected_duration_sec=recording_duration if recording_duration > 0 else None,
|
expected_duration_sec=recording_duration if recording_duration > 0 else None,
|
||||||
)
|
)
|
||||||
await writer.flush()
|
await writer.flush()
|
||||||
|
|||||||
@@ -267,6 +267,19 @@ async def dispatch_transcript_processing(
|
|||||||
)
|
)
|
||||||
|
|
||||||
logger.info("Hatchet workflow dispatched", workflow_id=workflow_id)
|
logger.info("Hatchet workflow dispatched", workflow_id=workflow_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from reflector.hatchet.dag_progress import broadcast_dag_status # noqa: I001, PLC0415
|
||||||
|
|
||||||
|
await broadcast_dag_status(config.transcript_id, workflow_id)
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"[DAG Progress] Failed initial broadcast",
|
||||||
|
transcript_id=config.transcript_id,
|
||||||
|
workflow_id=workflow_id,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
elif isinstance(config, FileProcessingConfig):
|
elif isinstance(config, FileProcessingConfig):
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ async def transcript_events_websocket(
|
|||||||
# for now, do not send TRANSCRIPT or STATUS options - theses are live event
|
# for now, do not send TRANSCRIPT or STATUS options - theses are live event
|
||||||
# not necessary to be sent to the client; but keep the rest
|
# not necessary to be sent to the client; but keep the rest
|
||||||
name = event.event
|
name = event.event
|
||||||
if name in ("TRANSCRIPT", "STATUS"):
|
if name in ("TRANSCRIPT", "STATUS", "DAG_STATUS"):
|
||||||
continue
|
continue
|
||||||
await websocket.send_json(event.model_dump(mode="json"))
|
await websocket.send_json(event.model_dump(mode="json"))
|
||||||
|
|
||||||
|
|||||||
@@ -5,8 +5,11 @@ into structured DagTask list for frontend consumption.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from reflector.hatchet.constants import TaskName
|
||||||
from reflector.hatchet.dag_progress import (
|
from reflector.hatchet.dag_progress import (
|
||||||
DagStatusData,
|
DagStatusData,
|
||||||
DagTask,
|
DagTask,
|
||||||
@@ -387,3 +390,234 @@ class TestDagStatusData:
|
|||||||
assert dumped["tasks"][0]["name"] == "get_recording"
|
assert dumped["tasks"][0]["name"] == "get_recording"
|
||||||
assert dumped["tasks"][0]["status"] == "completed"
|
assert dumped["tasks"][0]["status"] == "completed"
|
||||||
assert dumped["tasks"][0]["duration_seconds"] == 1.0
|
assert dumped["tasks"][0]["duration_seconds"] == 1.0
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncContextManager:
|
||||||
|
"""No-op async context manager for mocking fresh_db_connection."""
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def __aexit__(self, *args):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class TestBroadcastDagStatus:
|
||||||
|
"""Test broadcast_dag_status function.
|
||||||
|
|
||||||
|
broadcast_dag_status uses deferred imports inside its function body.
|
||||||
|
We mock the source modules/objects before calling the function.
|
||||||
|
Importing daily_multitrack_pipeline triggers a cascade
|
||||||
|
(subject_processing -> HatchetClientManager.get_client at module level),
|
||||||
|
so we set _instance before the import to prevent real SDK init.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _setup_hatchet_mock(self):
|
||||||
|
"""Set HatchetClientManager._instance to a mock to prevent real SDK init.
|
||||||
|
|
||||||
|
Module-level code in workflow files calls get_client() during import.
|
||||||
|
Setting _instance before import avoids ClientConfig validation.
|
||||||
|
"""
|
||||||
|
from reflector.hatchet.client import HatchetClientManager
|
||||||
|
|
||||||
|
original = HatchetClientManager._instance
|
||||||
|
HatchetClientManager._instance = MagicMock()
|
||||||
|
yield
|
||||||
|
HatchetClientManager._instance = original
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_broadcasts_dag_status(self):
|
||||||
|
"""broadcast_dag_status fetches run, transforms, and broadcasts."""
|
||||||
|
mock_transcript = MagicMock()
|
||||||
|
mock_transcript.id = "t-123"
|
||||||
|
|
||||||
|
mock_details = _make_details(
|
||||||
|
shape=[_make_shape_item("s1", "get_recording")],
|
||||||
|
tasks=[_make_task_summary("s1", status="COMPLETED")],
|
||||||
|
run_id="wf-abc",
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.runs.aio_get = AsyncMock(return_value=mock_details)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"reflector.hatchet.client.HatchetClientManager.get_client",
|
||||||
|
return_value=mock_client,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"reflector.hatchet.broadcast.append_event_and_broadcast",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
) as mock_broadcast,
|
||||||
|
patch(
|
||||||
|
"reflector.db.transcripts.transcripts_controller.get_by_id",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=mock_transcript,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection",
|
||||||
|
return_value=AsyncContextManager(),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
from reflector.hatchet.dag_progress import broadcast_dag_status
|
||||||
|
|
||||||
|
await broadcast_dag_status("t-123", "wf-abc")
|
||||||
|
|
||||||
|
mock_client.runs.aio_get.assert_called_once_with("wf-abc")
|
||||||
|
mock_broadcast.assert_called_once()
|
||||||
|
call_args = mock_broadcast.call_args
|
||||||
|
assert call_args[0][0] == "t-123" # transcript_id
|
||||||
|
assert call_args[0][1] is mock_transcript # transcript
|
||||||
|
assert call_args[0][2] == "DAG_STATUS" # event_name
|
||||||
|
data = call_args[0][3]
|
||||||
|
assert data["workflow_run_id"] == "wf-abc"
|
||||||
|
assert len(data["tasks"]) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_swallows_exceptions(self):
|
||||||
|
"""broadcast_dag_status never raises even when internals fail."""
|
||||||
|
from reflector.hatchet.dag_progress import broadcast_dag_status
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection",
|
||||||
|
side_effect=RuntimeError("db exploded"),
|
||||||
|
):
|
||||||
|
# Should not raise
|
||||||
|
await broadcast_dag_status("t-123", "wf-abc")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_broadcast_when_transcript_not_found(self):
|
||||||
|
"""broadcast_dag_status does not broadcast if transcript is None."""
|
||||||
|
mock_details = _make_details(
|
||||||
|
shape=[_make_shape_item("s1", "get_recording")],
|
||||||
|
tasks=[_make_task_summary("s1", status="COMPLETED")],
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.runs.aio_get = AsyncMock(return_value=mock_details)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"reflector.hatchet.client.HatchetClientManager.get_client",
|
||||||
|
return_value=mock_client,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection",
|
||||||
|
return_value=AsyncContextManager(),
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"reflector.db.transcripts.transcripts_controller.get_by_id",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=None,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"reflector.hatchet.broadcast.append_event_and_broadcast",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
) as mock_broadcast,
|
||||||
|
):
|
||||||
|
from reflector.hatchet.dag_progress import broadcast_dag_status
|
||||||
|
|
||||||
|
await broadcast_dag_status("t-123", "wf-abc")
|
||||||
|
|
||||||
|
mock_broadcast.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
class TestMakeAudioProgressLoggerWithBroadcast:
|
||||||
|
"""Test make_audio_progress_logger with transcript_id for transient broadcasts."""
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _setup_hatchet_mock(self):
|
||||||
|
"""Set HatchetClientManager._instance to prevent real SDK init on import."""
|
||||||
|
from reflector.hatchet.client import HatchetClientManager
|
||||||
|
|
||||||
|
original = HatchetClientManager._instance
|
||||||
|
if original is None:
|
||||||
|
HatchetClientManager._instance = MagicMock()
|
||||||
|
yield
|
||||||
|
HatchetClientManager._instance = original
|
||||||
|
|
||||||
|
def test_broadcasts_transient_progress_event(self):
|
||||||
|
"""When transcript_id provided and progress_pct not None, broadcasts event."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||||
|
make_audio_progress_logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx = MagicMock()
|
||||||
|
ctx.log = MagicMock()
|
||||||
|
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
mock_broadcast = AsyncMock()
|
||||||
|
tasks_created = []
|
||||||
|
|
||||||
|
original_create_task = loop.create_task
|
||||||
|
|
||||||
|
def capture_create_task(coro):
|
||||||
|
task = original_create_task(coro)
|
||||||
|
tasks_created.append(task)
|
||||||
|
return task
|
||||||
|
|
||||||
|
try:
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"reflector.hatchet.broadcast.broadcast_event",
|
||||||
|
mock_broadcast,
|
||||||
|
),
|
||||||
|
patch.object(loop, "create_task", side_effect=capture_create_task),
|
||||||
|
):
|
||||||
|
callback = make_audio_progress_logger(
|
||||||
|
ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id="t-123"
|
||||||
|
)
|
||||||
|
callback(50.0, 100.0)
|
||||||
|
|
||||||
|
# Run pending tasks
|
||||||
|
if tasks_created:
|
||||||
|
loop.run_until_complete(asyncio.gather(*tasks_created))
|
||||||
|
|
||||||
|
mock_broadcast.assert_called_once()
|
||||||
|
event_arg = mock_broadcast.call_args[0][1]
|
||||||
|
assert event_arg.event == "DAG_TASK_PROGRESS"
|
||||||
|
assert event_arg.data["task_name"] == TaskName.MIXDOWN_TRACKS
|
||||||
|
assert event_arg.data["progress_pct"] == 50.0
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
def test_no_broadcast_without_transcript_id(self):
|
||||||
|
"""When transcript_id is None, no broadcast happens."""
|
||||||
|
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||||
|
make_audio_progress_logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx = MagicMock()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"reflector.hatchet.broadcast.broadcast_event",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
) as mock_broadcast:
|
||||||
|
callback = make_audio_progress_logger(
|
||||||
|
ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id=None
|
||||||
|
)
|
||||||
|
callback(50.0, 100.0)
|
||||||
|
mock_broadcast.assert_not_called()
|
||||||
|
|
||||||
|
def test_no_broadcast_when_progress_pct_is_none(self):
|
||||||
|
"""When progress_pct is None, no broadcast happens even with transcript_id."""
|
||||||
|
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||||
|
make_audio_progress_logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx = MagicMock()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"reflector.hatchet.broadcast.broadcast_event",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
) as mock_broadcast:
|
||||||
|
callback = make_audio_progress_logger(
|
||||||
|
ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id="t-123"
|
||||||
|
)
|
||||||
|
callback(None, 100.0)
|
||||||
|
mock_broadcast.assert_not_called()
|
||||||
|
|||||||
Reference in New Issue
Block a user