feat: add broadcast_dag_status, decorator integration, and mixdown progress

- Add broadcast_dag_status() to dag_progress.py: fetches Hatchet run
  details, transforms to DagStatusData, and broadcasts DAG_STATUS event
  via WebSocket. Fire-and-forget with exception swallowing.
- Modify with_error_handling decorator to call broadcast_dag_status on
  both task success and failure.
- Add DAG_STATUS to USER_ROOM_EVENTS (broadcast.py) and reconnect
  filter (transcripts_websocket.py) to avoid replaying stale DAG state.
- Add initial DAG broadcast at workflow dispatch (transcript_process.py).
- Extend make_audio_progress_logger with optional transcript_id param
  for transient DAG_TASK_PROGRESS events during mixdown.
- All deferred imports for fork-safety, all broadcasts fire-and-forget.
This commit is contained in:
Igor Loskutov
2026-02-09 13:09:26 -05:00
parent a359c845ff
commit 4b79b0c989
6 changed files with 332 additions and 6 deletions

View File

@@ -15,7 +15,7 @@ from reflector.utils.string import NonEmptyString
from reflector.ws_manager import get_ws_manager
# Events that should also be sent to user room (matches Celery behavior)
USER_ROOM_EVENTS = {"STATUS", "FINAL_TITLE", "DURATION"}
USER_ROOM_EVENTS = {"STATUS", "FINAL_TITLE", "DURATION", "DAG_STATUS"}
async def broadcast_event(

View File

@@ -187,3 +187,44 @@ def extract_dag_tasks(details: V1WorkflowRunDetails) -> list[DagTask]:
)
return result
async def broadcast_dag_status(transcript_id: str, workflow_run_id: str) -> None:
"""Fetch current DAG state from Hatchet and broadcast via WebSocket.
Fire-and-forget: exceptions are logged but never raised.
All imports are deferred for fork-safety (Hatchet workers fork processes).
"""
try:
from reflector.db.transcripts import transcripts_controller # noqa: I001, PLC0415
from reflector.hatchet.broadcast import append_event_and_broadcast # noqa: PLC0415
from reflector.hatchet.client import HatchetClientManager # noqa: PLC0415
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
fresh_db_connection,
) # noqa: PLC0415
from reflector.logger import logger # noqa: PLC0415
async with fresh_db_connection():
client = HatchetClientManager.get_client()
details = await client.runs.aio_get(workflow_run_id)
dag_tasks = extract_dag_tasks(details)
dag_status = DagStatusData(workflow_run_id=workflow_run_id, tasks=dag_tasks)
transcript = await transcripts_controller.get_by_id(transcript_id)
if transcript:
await append_event_and_broadcast(
transcript_id,
transcript,
"DAG_STATUS",
dag_status.model_dump(mode="json"),
logger,
)
except Exception:
from reflector.logger import logger # noqa: PLC0415
logger.warning(
"[DAG Progress] Failed to broadcast DAG status",
transcript_id=transcript_id,
workflow_run_id=workflow_run_id,
exc_info=True,
)

View File

@@ -184,7 +184,10 @@ class Loggable(Protocol):
def make_audio_progress_logger(
ctx: Loggable, task_name: TaskName, interval: float = 5.0
ctx: Loggable,
task_name: TaskName,
interval: float = 5.0,
transcript_id: str | None = None,
) -> Callable[[float | None, float], None]:
"""Create a throttled progress logger callback for audio processing.
@@ -192,6 +195,7 @@ def make_audio_progress_logger(
ctx: Object with .log() method (e.g., Hatchet Context).
task_name: Name to prefix in log messages.
interval: Minimum seconds between log messages.
transcript_id: If provided, broadcasts transient DAG_TASK_PROGRESS events.
Returns:
Callback(progress_pct, audio_position) that logs at most every `interval` seconds.
@@ -213,6 +217,27 @@ def make_audio_progress_logger(
)
last_log_time[0] = now
if transcript_id and progress_pct is not None:
try:
import asyncio # noqa: PLC0415
from reflector.db.transcripts import TranscriptEvent # noqa: PLC0415
from reflector.hatchet.broadcast import broadcast_event # noqa: PLC0415
loop = asyncio.get_event_loop()
loop.create_task(
broadcast_event(
transcript_id,
TranscriptEvent(
event="DAG_TASK_PROGRESS",
data={"task_name": task_name, "progress_pct": progress_pct},
),
logger=logger,
)
)
except Exception:
pass # transient, never fail the callback
return callback
@@ -237,8 +262,15 @@ def with_error_handling(
) -> Callable[[PipelineInput, Context], Coroutine[Any, Any, R]]:
@functools.wraps(func)
async def wrapper(input: PipelineInput, ctx: Context) -> R:
from reflector.hatchet.dag_progress import broadcast_dag_status # noqa: I001, PLC0415
try:
return await func(input, ctx)
result = await func(input, ctx)
try:
await broadcast_dag_status(input.transcript_id, ctx.workflow_run_id)
except Exception:
pass
return result
except Exception as e:
logger.error(
f"[Hatchet] {step_name} failed",
@@ -246,6 +278,10 @@ def with_error_handling(
error=str(e),
exc_info=True,
)
try:
await broadcast_dag_status(input.transcript_id, ctx.workflow_run_id)
except Exception:
pass
if set_error_status:
await set_workflow_error_status(input.transcript_id)
raise
@@ -560,7 +596,9 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
target_sample_rate,
offsets_seconds=None,
logger=logger,
progress_callback=make_audio_progress_logger(ctx, TaskName.MIXDOWN_TRACKS),
progress_callback=make_audio_progress_logger(
ctx, TaskName.MIXDOWN_TRACKS, transcript_id=input.transcript_id
),
expected_duration_sec=recording_duration if recording_duration > 0 else None,
)
await writer.flush()

View File

@@ -267,6 +267,19 @@ async def dispatch_transcript_processing(
)
logger.info("Hatchet workflow dispatched", workflow_id=workflow_id)
try:
from reflector.hatchet.dag_progress import broadcast_dag_status # noqa: I001, PLC0415
await broadcast_dag_status(config.transcript_id, workflow_id)
except Exception:
logger.warning(
"[DAG Progress] Failed initial broadcast",
transcript_id=config.transcript_id,
workflow_id=workflow_id,
exc_info=True,
)
return None
elif isinstance(config, FileProcessingConfig):

View File

@@ -45,7 +45,7 @@ async def transcript_events_websocket(
# for now, do not send TRANSCRIPT or STATUS options - theses are live event
# not necessary to be sent to the client; but keep the rest
name = event.event
if name in ("TRANSCRIPT", "STATUS"):
if name in ("TRANSCRIPT", "STATUS", "DAG_STATUS"):
continue
await websocket.send_json(event.model_dump(mode="json"))