diff --git a/server/reflector/hatchet/broadcast.py b/server/reflector/hatchet/broadcast.py
index 6b42ddbd..84bcccb6 100644
--- a/server/reflector/hatchet/broadcast.py
+++ b/server/reflector/hatchet/broadcast.py
@@ -15,7 +15,7 @@ from reflector.utils.string import NonEmptyString
 from reflector.ws_manager import get_ws_manager
 
 # Events that should also be sent to user room (matches Celery behavior)
-USER_ROOM_EVENTS = {"STATUS", "FINAL_TITLE", "DURATION"}
+USER_ROOM_EVENTS = {"STATUS", "FINAL_TITLE", "DURATION", "DAG_STATUS"}
 
 
 async def broadcast_event(
diff --git a/server/reflector/hatchet/dag_progress.py b/server/reflector/hatchet/dag_progress.py
index 1dc65116..97d29ef1 100644
--- a/server/reflector/hatchet/dag_progress.py
+++ b/server/reflector/hatchet/dag_progress.py
@@ -187,3 +187,44 @@ def extract_dag_tasks(details: V1WorkflowRunDetails) -> list[DagTask]:
         )
 
     return result
+
+
+async def broadcast_dag_status(transcript_id: str, workflow_run_id: str) -> None:
+    """Fetch current DAG state from Hatchet and broadcast via WebSocket.
+
+    Fire-and-forget: exceptions are logged but never raised.
+    All imports are deferred for fork-safety (Hatchet workers fork processes).
+    """
+    try:
+        from reflector.db.transcripts import transcripts_controller  # noqa: I001, PLC0415
+        from reflector.hatchet.broadcast import append_event_and_broadcast  # noqa: PLC0415
+        from reflector.hatchet.client import HatchetClientManager  # noqa: PLC0415
+        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
+            fresh_db_connection,
+        )  # noqa: PLC0415
+        from reflector.logger import logger  # noqa: PLC0415
+
+        async with fresh_db_connection():
+            client = HatchetClientManager.get_client()
+            details = await client.runs.aio_get(workflow_run_id)
+            dag_tasks = extract_dag_tasks(details)
+            dag_status = DagStatusData(workflow_run_id=workflow_run_id, tasks=dag_tasks)
+
+            transcript = await transcripts_controller.get_by_id(transcript_id)
+            if transcript:
+                await append_event_and_broadcast(
+                    transcript_id,
+                    transcript,
+                    "DAG_STATUS",
+                    dag_status.model_dump(mode="json"),
+                    logger,
+                )
+    except Exception:
+        from reflector.logger import logger  # noqa: PLC0415
+
+        logger.warning(
+            "[DAG Progress] Failed to broadcast DAG status",
+            transcript_id=transcript_id,
+            workflow_run_id=workflow_run_id,
+            exc_info=True,
+        )
diff --git a/server/reflector/hatchet/workflows/daily_multitrack_pipeline.py b/server/reflector/hatchet/workflows/daily_multitrack_pipeline.py
index 188133c7..662d9512 100644
--- a/server/reflector/hatchet/workflows/daily_multitrack_pipeline.py
+++ b/server/reflector/hatchet/workflows/daily_multitrack_pipeline.py
@@ -184,7 +184,10 @@ class Loggable(Protocol):
 
 
 def make_audio_progress_logger(
-    ctx: Loggable, task_name: TaskName, interval: float = 5.0
+    ctx: Loggable,
+    task_name: TaskName,
+    interval: float = 5.0,
+    transcript_id: str | None = None,
 ) -> Callable[[float | None, float], None]:
     """Create a throttled progress logger callback for audio processing.
 
@@ -192,6 +195,7 @@ def make_audio_progress_logger(
         ctx: Object with .log() method (e.g., Hatchet Context).
         task_name: Name to prefix in log messages.
         interval: Minimum seconds between log messages.
+        transcript_id: If provided, broadcasts transient DAG_TASK_PROGRESS events.
 
     Returns:
         Callback(progress_pct, audio_position) that logs at most every `interval` seconds.
@@ -213,6 +217,27 @@
             )
             last_log_time[0] = now
 
+        if transcript_id and progress_pct is not None:
+            try:
+                import asyncio  # noqa: PLC0415
+
+                from reflector.db.transcripts import TranscriptEvent  # noqa: PLC0415
+                from reflector.hatchet.broadcast import broadcast_event  # noqa: PLC0415
+
+                loop = asyncio.get_event_loop()
+                loop.create_task(
+                    broadcast_event(
+                        transcript_id,
+                        TranscriptEvent(
+                            event="DAG_TASK_PROGRESS",
+                            data={"task_name": task_name, "progress_pct": progress_pct},
+                        ),
+                        logger=logger,
+                    )
+                )
+            except Exception:
+                pass  # transient, never fail the callback
+
     return callback
 
 
@@ -237,8 +262,15 @@
 ) -> Callable[[PipelineInput, Context], Coroutine[Any, Any, R]]:
     @functools.wraps(func)
     async def wrapper(input: PipelineInput, ctx: Context) -> R:
+        from reflector.hatchet.dag_progress import broadcast_dag_status  # noqa: I001, PLC0415
+
         try:
-            return await func(input, ctx)
+            result = await func(input, ctx)
+            try:
+                await broadcast_dag_status(input.transcript_id, ctx.workflow_run_id)
+            except Exception:
+                pass
+            return result
         except Exception as e:
             logger.error(
                 f"[Hatchet] {step_name} failed",
@@ -246,6 +278,10 @@
                 error=str(e),
                 exc_info=True,
             )
+            try:
+                await broadcast_dag_status(input.transcript_id, ctx.workflow_run_id)
+            except Exception:
+                pass
             if set_error_status:
                 await set_workflow_error_status(input.transcript_id)
             raise
@@ -560,7 +596,9 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
         target_sample_rate,
         offsets_seconds=None,
         logger=logger,
-        progress_callback=make_audio_progress_logger(ctx, TaskName.MIXDOWN_TRACKS),
+        progress_callback=make_audio_progress_logger(
+            ctx, TaskName.MIXDOWN_TRACKS, transcript_id=input.transcript_id
+        ),
         expected_duration_sec=recording_duration if recording_duration > 0 else None,
     )
     await writer.flush()
diff --git a/server/reflector/services/transcript_process.py b/server/reflector/services/transcript_process.py
index 13847a49..977704c4 100644
--- a/server/reflector/services/transcript_process.py
+++ b/server/reflector/services/transcript_process.py
@@ -267,6 +267,19 @@
         )
 
         logger.info("Hatchet workflow dispatched", workflow_id=workflow_id)
+
+        try:
+            from reflector.hatchet.dag_progress import broadcast_dag_status  # noqa: I001, PLC0415
+
+            await broadcast_dag_status(config.transcript_id, workflow_id)
+        except Exception:
+            logger.warning(
+                "[DAG Progress] Failed initial broadcast",
+                transcript_id=config.transcript_id,
+                workflow_id=workflow_id,
+                exc_info=True,
+            )
+
         return None
 
     elif isinstance(config, FileProcessingConfig):
diff --git a/server/reflector/views/transcripts_websocket.py b/server/reflector/views/transcripts_websocket.py
index ccb7d7ff..fae25eed 100644
--- a/server/reflector/views/transcripts_websocket.py
+++ b/server/reflector/views/transcripts_websocket.py
@@ -45,7 +45,7 @@ async def transcript_events_websocket(
         # for now, do not send TRANSCRIPT or STATUS options - theses are live event
         # not necessary to be sent to the client; but keep the rest
         name = event.event
-        if name in ("TRANSCRIPT", "STATUS"):
+        if name in ("TRANSCRIPT", "STATUS", "DAG_STATUS"):
             continue
         await websocket.send_json(event.model_dump(mode="json"))
 
diff --git a/server/tests/test_dag_progress.py b/server/tests/test_dag_progress.py
index b031b7f1..14a2d810 100644
--- a/server/tests/test_dag_progress.py
+++ b/server/tests/test_dag_progress.py
@@ -5,8 +5,11 @@ into structured DagTask list for frontend consumption.
 """
 
 from datetime import datetime, timezone
-from unittest.mock import MagicMock
-
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from reflector.hatchet.constants import TaskName
 from reflector.hatchet.dag_progress import (
     DagStatusData,
     DagTask,
@@ -387,3 +390,234 @@
         assert dumped["tasks"][0]["name"] == "get_recording"
         assert dumped["tasks"][0]["status"] == "completed"
         assert dumped["tasks"][0]["duration_seconds"] == 1.0
+
+
+class AsyncContextManager:
+    """No-op async context manager for mocking fresh_db_connection."""
+
+    async def __aenter__(self):
+        return None
+
+    async def __aexit__(self, *args):
+        return None
+
+
+class TestBroadcastDagStatus:
+    """Test broadcast_dag_status function.
+
+    broadcast_dag_status uses deferred imports inside its function body.
+    We mock the source modules/objects before calling the function.
+    Importing daily_multitrack_pipeline triggers a cascade
+    (subject_processing -> HatchetClientManager.get_client at module level),
+    so we set _instance before the import to prevent real SDK init.
+    """
+
+    @pytest.fixture(autouse=True)
+    def _setup_hatchet_mock(self):
+        """Set HatchetClientManager._instance to a mock to prevent real SDK init.
+
+        Module-level code in workflow files calls get_client() during import.
+        Setting _instance before import avoids ClientConfig validation.
+        """
+        from reflector.hatchet.client import HatchetClientManager
+
+        original = HatchetClientManager._instance
+        HatchetClientManager._instance = MagicMock()
+        yield
+        HatchetClientManager._instance = original
+
+    @pytest.mark.asyncio
+    async def test_broadcasts_dag_status(self):
+        """broadcast_dag_status fetches run, transforms, and broadcasts."""
+        mock_transcript = MagicMock()
+        mock_transcript.id = "t-123"
+
+        mock_details = _make_details(
+            shape=[_make_shape_item("s1", "get_recording")],
+            tasks=[_make_task_summary("s1", status="COMPLETED")],
+            run_id="wf-abc",
+        )
+
+        mock_client = MagicMock()
+        mock_client.runs.aio_get = AsyncMock(return_value=mock_details)
+
+        with (
+            patch(
+                "reflector.hatchet.client.HatchetClientManager.get_client",
+                return_value=mock_client,
+            ),
+            patch(
+                "reflector.hatchet.broadcast.append_event_and_broadcast",
+                new_callable=AsyncMock,
+            ) as mock_broadcast,
+            patch(
+                "reflector.db.transcripts.transcripts_controller.get_by_id",
+                new_callable=AsyncMock,
+                return_value=mock_transcript,
+            ),
+            patch(
+                "reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection",
+                return_value=AsyncContextManager(),
+            ),
+        ):
+            from reflector.hatchet.dag_progress import broadcast_dag_status
+
+            await broadcast_dag_status("t-123", "wf-abc")
+
+            mock_client.runs.aio_get.assert_called_once_with("wf-abc")
+            mock_broadcast.assert_called_once()
+            call_args = mock_broadcast.call_args
+            assert call_args[0][0] == "t-123"  # transcript_id
+            assert call_args[0][1] is mock_transcript  # transcript
+            assert call_args[0][2] == "DAG_STATUS"  # event_name
+            data = call_args[0][3]
+            assert data["workflow_run_id"] == "wf-abc"
+            assert len(data["tasks"]) == 1
+
+    @pytest.mark.asyncio
+    async def test_swallows_exceptions(self):
+        """broadcast_dag_status never raises even when internals fail."""
+        from reflector.hatchet.dag_progress import broadcast_dag_status
+
+        with patch(
+            "reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection",
+            side_effect=RuntimeError("db exploded"),
+        ):
+            # Should not raise
+            await broadcast_dag_status("t-123", "wf-abc")
+
+    @pytest.mark.asyncio
+    async def test_no_broadcast_when_transcript_not_found(self):
+        """broadcast_dag_status does not broadcast if transcript is None."""
+        mock_details = _make_details(
+            shape=[_make_shape_item("s1", "get_recording")],
+            tasks=[_make_task_summary("s1", status="COMPLETED")],
+        )
+
+        mock_client = MagicMock()
+        mock_client.runs.aio_get = AsyncMock(return_value=mock_details)
+
+        with (
+            patch(
+                "reflector.hatchet.client.HatchetClientManager.get_client",
+                return_value=mock_client,
+            ),
+            patch(
+                "reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection",
+                return_value=AsyncContextManager(),
+            ),
+            patch(
+                "reflector.db.transcripts.transcripts_controller.get_by_id",
+                new_callable=AsyncMock,
+                return_value=None,
+            ),
+            patch(
+                "reflector.hatchet.broadcast.append_event_and_broadcast",
+                new_callable=AsyncMock,
+            ) as mock_broadcast,
+        ):
+            from reflector.hatchet.dag_progress import broadcast_dag_status
+
+            await broadcast_dag_status("t-123", "wf-abc")
+
+            mock_broadcast.assert_not_called()
+
+
+class TestMakeAudioProgressLoggerWithBroadcast:
+    """Test make_audio_progress_logger with transcript_id for transient broadcasts."""
+
+    @pytest.fixture(autouse=True)
+    def _setup_hatchet_mock(self):
+        """Set HatchetClientManager._instance to prevent real SDK init on import."""
+        from reflector.hatchet.client import HatchetClientManager
+
+        original = HatchetClientManager._instance
+        if original is None:
+            HatchetClientManager._instance = MagicMock()
+        yield
+        HatchetClientManager._instance = original
+
+    def test_broadcasts_transient_progress_event(self):
+        """When transcript_id provided and progress_pct not None, broadcasts event."""
+        import asyncio
+
+        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
+            make_audio_progress_logger,
+        )
+
+        ctx = MagicMock()
+        ctx.log = MagicMock()
+
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
+        mock_broadcast = AsyncMock()
+        tasks_created = []
+
+        original_create_task = loop.create_task
+
+        def capture_create_task(coro):
+            task = original_create_task(coro)
+            tasks_created.append(task)
+            return task
+
+        try:
+            with (
+                patch(
+                    "reflector.hatchet.broadcast.broadcast_event",
+                    mock_broadcast,
+                ),
+                patch.object(loop, "create_task", side_effect=capture_create_task),
+            ):
+                callback = make_audio_progress_logger(
+                    ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id="t-123"
+                )
+                callback(50.0, 100.0)
+
+                # Run pending tasks
+                if tasks_created:
+                    loop.run_until_complete(asyncio.gather(*tasks_created))
+
+            mock_broadcast.assert_called_once()
+            event_arg = mock_broadcast.call_args[0][1]
+            assert event_arg.event == "DAG_TASK_PROGRESS"
+            assert event_arg.data["task_name"] == TaskName.MIXDOWN_TRACKS
+            assert event_arg.data["progress_pct"] == 50.0
+        finally:
+            loop.close()
+
+    def test_no_broadcast_without_transcript_id(self):
+        """When transcript_id is None, no broadcast happens."""
+        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
+            make_audio_progress_logger,
+        )
+
+        ctx = MagicMock()
+
+        with patch(
+            "reflector.hatchet.broadcast.broadcast_event",
+            new_callable=AsyncMock,
+        ) as mock_broadcast:
+            callback = make_audio_progress_logger(
+                ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id=None
+            )
+            callback(50.0, 100.0)
+            mock_broadcast.assert_not_called()
+
+    def test_no_broadcast_when_progress_pct_is_none(self):
+        """When progress_pct is None, no broadcast happens even with transcript_id."""
+        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
+            make_audio_progress_logger,
+        )
+
+        ctx = MagicMock()
+
+        with patch(
+            "reflector.hatchet.broadcast.broadcast_event",
+            new_callable=AsyncMock,
+        ) as mock_broadcast:
+            callback = make_audio_progress_logger(
+                ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id="t-123"
+            )
+            callback(None, 100.0)
+            mock_broadcast.assert_not_called()