Compare commits

...

17 Commits

Author SHA1 Message Date
Igor Loskutov
b1eeb651f6 fix: send last DAG_STATUS on WebSocket connect instead of skipping all
Previously all DAG_STATUS events were skipped during historical replay
on WS connect, so reconnecting clients (React strict mode remount,
page navigation) lost current DAG state. Now sends only the most
recent DAG_STATUS event on connect.
2026-02-09 15:26:59 -05:00
Igor Loskutov
499de45fdb fix: processing page reads DAG status from user room WS fallback
The processing page only read dagStatus from the transcript room WS,
which loses events during page navigation and React strict mode
double-mounting (WS torn down and reconnected, historical replay
skips DAG_STATUS). Now also consumes useDagStatusMap() from
UserEventsProvider (user room), which uses a singleton WS that
survives remounts.

Priority: transcript room WS > user room WS > REST API.
2026-02-09 15:09:29 -05:00
Igor Loskutov
b4ccbe6928 test: add WebSocket broadcast delivery tests for STATUS and DAG_STATUS
Exercises the full broadcast → pub/sub → WebSocket delivery chain
that DEBUG.md identified as potentially broken. Covers send_json
direct delivery, broadcast_event() end-to-end, and event ordering.
Also patches broadcast.py's get_ws_manager (missing from conftest).
2026-02-09 14:58:13 -05:00
Igor Loskutov
38f100a83e fix: invalidate individual transcript query on TRANSCRIPT_STATUS user event
The UserEventsProvider only invalidated the list query on status changes.
The detail page's useTranscriptGet was never refreshed, so it never
redirected to /processing on reprocess.
2026-02-09 14:28:11 -05:00
Igor Loskutov
faec509a33 fix: invalidate transcript query on STATUS websocket event
Without this, the page never redirects to /processing when a reprocess
changes status from ended->processing, because the redirect logic only
watches the REST query data, not the WebSocket status state.
2026-02-09 14:22:53 -05:00
Igor Loskutov
4d9f5fa4b4 test: remove impossible-scenario tests from DAG REST enrichment
Remove 6 tests that covered malformed data shapes (data=None,
data=string, data=list, non-dict event elements) that cannot
occur since our own code always writes well-formed dicts via
model_dump(mode="json").
2026-02-09 14:10:19 -05:00
Igor Loskutov
455cb3d099 fix: use mode="json" in add_event to serialize datetimes in event data
Prevents 'Object of type datetime is not JSON serializable' when
broadcasting DAG_STATUS events to user room via WebSocket.
2026-02-09 14:06:49 -05:00
Igor Loskutov
2410688559 fix: pass DagStatusData model instead of dict to append_event_and_broadcast
add_event() calls .model_dump() on data, so it needs a Pydantic model not a dict.
2026-02-09 14:00:43 -05:00
Igor Loskutov
6dd96bfa5e test: enhance Task 4 REST enrichment tests — malformed data, GET extraction, search integration 2026-02-09 13:56:04 -05:00
Igor Loskutov
0acaa0de93 test: add with_error_handling decorator tests for broadcast integration 2026-02-09 13:56:04 -05:00
Igor Loskutov
c45d3182ee test: enhance Task 1+3 tests — production DAG, throttling
Add 4 new tests to test_dag_progress.py:

- test_production_dag_shape: Real 15-task pipeline topology with mixed
  statuses, verifying all tasks present, topological order invariant,
  and correct parent relationships (e.g. finalize has 4 parents)
- test_topological_sort_invariant_complex_dag: 7-node DAG with wide
  branching/merging to stress-test that all parents precede children
- test_logging_throttled_by_interval: Mocks time.monotonic to verify
  ctx.log() is throttled by interval while broadcasts are not
- test_uses_broadcast_event_not_append_event_and_broadcast: Verifies
  progress uses transient broadcast_event, not persisted append variant
2026-02-09 13:56:04 -05:00
Igor Loskutov
0c06cdd117 fix: consolidate DagTask types, fix REST fallback shape, fix lint noqa
- Extract shared DagTask/DagTaskStatus types into www/app/lib/dagTypes.ts
- Re-export from useWebSockets.ts and UserEventsProvider.tsx
- Fix browse page REST fallback: dag_status is list[dict] directly, not {tasks: [...]}
- Add missing # noqa: PLC0415 for fork-safe deferred imports
2026-02-09 13:25:40 -05:00
Igor Loskutov
ebae9124b6 feat: add DAG progress dots to browse page via WebSocket events
- Add TRANSCRIPT_DAG_STATUS handler to UserEventsProvider with
  DagStatusContext and useDagStatusMap hook for live DAG task updates
- Clean up dagStatusMap entries when TRANSCRIPT_STATUS transitions
  away from "processing"
- Create DagProgressDots component rendering color-coded dots per
  DAG task (green=completed, blue pulsing=running, hollow=queued,
  red=failed, gray=cancelled) with humanized tooltip names
- Wire dagStatusMap through browse page -> TranscriptCards ->
  TranscriptStatusIcon, falling back to REST dag_status field
2026-02-09 13:22:29 -05:00
Igor Loskutov
a6a5d35e44 feat: add DAG progress WebSocket handlers and processing page table
Add DAG_STATUS and DAG_TASK_PROGRESS event handlers to useWebSockets
hook with exported DagTask/DagTaskStatus types. Create DagProgressTable
component with status icons, live elapsed timers, progress bars, and
expandable error rows. Wire into processing page with REST fallback.
2026-02-09 13:22:25 -05:00
Igor Loskutov
025e6da539 feat: add dag_status REST enrichment to search and transcript GET 2026-02-09 13:12:05 -05:00
Igor Loskutov
4b79b0c989 feat: add broadcast_dag_status, decorator integration, and mixdown progress
- Add broadcast_dag_status() to dag_progress.py: fetches Hatchet run
  details, transforms to DagStatusData, and broadcasts DAG_STATUS event
  via WebSocket. Fire-and-forget with exception swallowing.
- Modify with_error_handling decorator to call broadcast_dag_status on
  both task success and failure.
- Add DAG_STATUS to USER_ROOM_EVENTS (broadcast.py) and reconnect
  filter (transcripts_websocket.py) to avoid replaying stale DAG state.
- Add initial DAG broadcast at workflow dispatch (transcript_process.py).
- Extend make_audio_progress_logger with optional transcript_id param
  for transient DAG_TASK_PROGRESS events during mixdown.
- All deferred imports for fork-safety, all broadcasts fire-and-forget.
2026-02-09 13:12:01 -05:00
Igor Loskutov
a359c845ff feat: add DagTask models and extract_dag_tasks transform
Foundation for DAG progress reporting to frontend. Ported topo sort
and task extraction from render_hatchet_run.py (Zulip worktree) to
produce structured Pydantic models instead of markdown.
2026-02-09 12:50:53 -05:00
21 changed files with 2665 additions and 20 deletions

View File

@@ -1,6 +1,7 @@
"""Search functionality for transcripts and other entities."""
import itertools
import json
from dataclasses import dataclass
from datetime import datetime
from io import StringIO
@@ -172,6 +173,9 @@ class SearchResult(BaseModel):
total_match_count: NonNegativeInt = Field(
default=0, description="Total number of matches found in the transcript"
)
dag_status: list[dict] | None = Field(
default=None, description="Latest DAG task status for processing transcripts"
)
@field_serializer("created_at", when_used="json")
def serialize_datetime(self, dt: datetime) -> str:
@@ -328,6 +332,42 @@ class SnippetGenerator:
return summary_snippets + webvtt_snippets, total_matches
async def _fetch_dag_statuses(transcript_ids: list[str]) -> dict[str, list[dict]]:
    """Fetch latest DAG_STATUS event data for given transcript IDs.

    Returns dict mapping transcript_id -> tasks list from the last DAG_STATUS
    event. Transcripts with no events or no DAG_STATUS event are omitted.
    """
    if not transcript_ids:
        return {}
    db = get_database()
    # Positional column form: the legacy select([cols]) list form was
    # deprecated in SQLAlchemy 1.4 and removed in 2.0.
    query = sqlalchemy.select(
        transcripts.c.id,
        transcripts.c.events,
    ).where(transcripts.c.id.in_(transcript_ids))
    rows = await db.fetch_all(query)
    result: dict[str, list[dict]] = {}
    for row in rows:
        events_raw = row["events"]
        if not events_raw:
            continue
        # events is stored as JSON list; some drivers return the raw string
        events = events_raw if isinstance(events_raw, list) else json.loads(events_raw)
        # Find last DAG_STATUS event (scan newest-first)
        for ev in reversed(events):
            if isinstance(ev, dict) and ev.get("event") == "DAG_STATUS":
                # "data" may be present but None; the .get default only
                # applies when the key is absent, so guard explicitly.
                tasks = (ev.get("data") or {}).get("tasks")
                if tasks:
                    result[row["id"]] = tasks
                break
    return result
class SearchController:
"""Controller for search operations across different entities."""
@@ -470,6 +510,14 @@ class SearchController:
logger.error(f"Error processing search results: {e}", exc_info=True)
raise
# Enrich processing transcripts with DAG status
processing_ids = [r.id for r in results if r.status == "processing"]
if processing_ids:
dag_statuses = await _fetch_dag_statuses(processing_ids)
for r in results:
if r.id in dag_statuses:
r.dag_status = dag_statuses[r.id]
return results, total

View File

@@ -234,7 +234,7 @@ class Transcript(BaseModel):
return dt.isoformat()
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
ev = TranscriptEvent(event=event, data=data.model_dump())
ev = TranscriptEvent(event=event, data=data.model_dump(mode="json"))
self.events.append(ev)
return ev

View File

@@ -15,7 +15,7 @@ from reflector.utils.string import NonEmptyString
from reflector.ws_manager import get_ws_manager
# Events that should also be sent to user room (matches Celery behavior)
USER_ROOM_EVENTS = {"STATUS", "FINAL_TITLE", "DURATION"}
USER_ROOM_EVENTS = {"STATUS", "FINAL_TITLE", "DURATION", "DAG_STATUS"}
async def broadcast_event(

View File

@@ -0,0 +1,230 @@
"""
DAG Progress Reporting — models and transform.
Converts Hatchet V1WorkflowRunDetails into structured DagTask list
for frontend WebSocket/REST consumption.
Ported from render_hatchet_run.py (feat-dag-zulip) which renders markdown;
this module produces structured Pydantic models instead.
"""
import heapq
from datetime import datetime
from enum import StrEnum

from hatchet_sdk.clients.rest.models import (
    V1TaskStatus,
    V1WorkflowRunDetails,
    WorkflowRunShapeItemForWorkflowRunDetails,
)
from pydantic import BaseModel
class DagTaskStatus(StrEnum):
QUEUED = "queued"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
# Hatchet task states map 1:1 onto our DAG statuses by enum member name
# (QUEUED, RUNNING, COMPLETED, FAILED, CANCELLED), so build the table
# from the members instead of spelling out each pair.
_HATCHET_TO_DAG_STATUS: dict[V1TaskStatus, DagTaskStatus] = {
    V1TaskStatus[member.name]: member for member in DagTaskStatus
}
class DagTask(BaseModel):
    """One node of the workflow DAG, flattened for WebSocket/REST payloads."""

    name: str
    status: DagTaskStatus
    started_at: datetime | None
    finished_at: datetime | None
    duration_seconds: float | None
    # Task names of direct parents (resolved from step ids by the transform).
    parents: list[str]
    # First meaningful line of the failure message, if any.
    error: str | None
    # Fan-out bookkeeping: spawned-children count and how many completed.
    children_total: int | None
    children_completed: int | None
    # Always None from extract_dag_tasks; reserved for transient progress.
    progress_pct: float | None


class DagStatusData(BaseModel):
    """Payload of a DAG_STATUS event: the full task list for one run."""

    workflow_run_id: str
    tasks: list[DagTask]
def _topo_sort(
shape: list[WorkflowRunShapeItemForWorkflowRunDetails],
) -> list[str]:
"""Topological sort of step_ids from shape DAG (Kahn's algorithm).
Ported from render_hatchet_run.py.
"""
step_ids = {s.step_id for s in shape}
children_map: dict[str, list[str]] = {}
in_degree: dict[str, int] = {sid: 0 for sid in step_ids}
for s in shape:
children = [c for c in (s.children_step_ids or []) if c in step_ids]
children_map[s.step_id] = children
for c in children:
in_degree[c] += 1
queue = sorted(sid for sid, deg in in_degree.items() if deg == 0)
result: list[str] = []
while queue:
node = queue.pop(0)
result.append(node)
for c in children_map.get(node, []):
in_degree[c] -= 1
if in_degree[c] == 0:
queue.append(c)
queue.sort()
return result
def _extract_error_summary(error_message: str | None) -> str | None:
"""Extract first meaningful line from error message, skipping traceback frames."""
if not error_message or not error_message.strip():
return None
err_lines = error_message.strip().split("\n")
err_summary = err_lines[0]
for line in err_lines:
stripped = line.strip()
if stripped and not stripped.startswith(("Traceback", "File ", "{", ")")):
err_summary = stripped
return err_summary
def extract_dag_tasks(details: V1WorkflowRunDetails) -> list[DagTask]:
    """Extract structured DagTask list from Hatchet workflow run details.

    Returns tasks in topological order with status, timestamps, parents,
    error summaries, and fan-out children counts. Steps that have no task
    summary yet are reported as QUEUED placeholders.
    """
    shape = details.shape or []
    tasks = details.tasks or []
    if not shape:
        return []

    # Lookup: step_id -> task name (also serves as the known-step set).
    # NOTE: the previous step_to_shape dict was built but never read; removed.
    step_to_name: dict[str, str] = {s.step_id: s.task_name for s in shape}

    # Reverse edges: child step_id -> list of parent task *names*.
    parents_by_step: dict[str, list[str]] = {s.step_id: [] for s in shape}
    for s in shape:
        for child_id in s.children_step_ids or []:
            if child_id in parents_by_step:
                parents_by_step[child_id].append(step_to_name[s.step_id])

    # Join task summaries onto shape entries by step_id, ignoring unknown steps.
    from hatchet_sdk.clients.rest.models import V1TaskSummary  # noqa: PLC0415

    task_by_step: dict[str, V1TaskSummary] = {}
    for t in tasks:
        if t.step_id and t.step_id in step_to_name:
            task_by_step[t.step_id] = t

    ordered = _topo_sort(shape)
    result: list[DagTask] = []
    for step_id in ordered:
        name = step_to_name[step_id]
        t = task_by_step.get(step_id)
        if not t:
            # Not scheduled yet: synthesize a queued placeholder.
            result.append(
                DagTask(
                    name=name,
                    status=DagTaskStatus.QUEUED,
                    started_at=None,
                    finished_at=None,
                    duration_seconds=None,
                    parents=parents_by_step.get(step_id, []),
                    error=None,
                    children_total=None,
                    children_completed=None,
                    progress_pct=None,
                )
            )
            continue

        # Unknown Hatchet states degrade to QUEUED rather than raising.
        status = _HATCHET_TO_DAG_STATUS.get(t.status, DagTaskStatus.QUEUED)

        duration_seconds: float | None = None
        if t.duration is not None:
            # Hatchet reports duration in milliseconds.
            duration_seconds = t.duration / 1000.0

        # Fan-out children counts (only for tasks that spawned children).
        children_total: int | None = None
        children_completed: int | None = None
        if t.num_spawned_children and t.num_spawned_children > 0:
            children_total = t.num_spawned_children
            children_completed = sum(
                1 for c in (t.children or []) if c.status == V1TaskStatus.COMPLETED
            )

        result.append(
            DagTask(
                name=name,
                status=status,
                started_at=t.started_at,
                finished_at=t.finished_at,
                duration_seconds=duration_seconds,
                parents=parents_by_step.get(step_id, []),
                error=_extract_error_summary(t.error_message),
                children_total=children_total,
                children_completed=children_completed,
                progress_pct=None,
            )
        )
    return result
async def broadcast_dag_status(transcript_id: str, workflow_run_id: str) -> None:
    """Fetch current DAG state from Hatchet and broadcast via WebSocket.

    Fire-and-forget: exceptions are logged but never raised.
    All imports are deferred for fork-safety (Hatchet workers fork processes).
    """
    try:
        from reflector.db.transcripts import transcripts_controller  # noqa: I001, PLC0415
        from reflector.hatchet.broadcast import append_event_and_broadcast  # noqa: PLC0415
        from reflector.hatchet.client import HatchetClientManager  # noqa: PLC0415
        from reflector.hatchet.workflows.daily_multitrack_pipeline import (  # noqa: PLC0415
            fresh_db_connection,
        )
        from reflector.logger import logger  # noqa: PLC0415

        async with fresh_db_connection():
            # Pull the live run details and transform them to our payload.
            details = await HatchetClientManager.get_client().runs.aio_get(
                workflow_run_id
            )
            payload = DagStatusData(
                workflow_run_id=workflow_run_id,
                tasks=extract_dag_tasks(details),
            )
            transcript = await transcripts_controller.get_by_id(transcript_id)
            if not transcript:
                # Transcript vanished; nothing to persist or broadcast.
                return
            await append_event_and_broadcast(
                transcript_id,
                transcript,
                "DAG_STATUS",
                payload,
                logger,
            )
    except Exception:
        from reflector.logger import logger  # noqa: PLC0415

        logger.warning(
            "[DAG Progress] Failed to broadcast DAG status",
            transcript_id=transcript_id,
            workflow_run_id=workflow_run_id,
            exc_info=True,
        )

View File

@@ -184,7 +184,10 @@ class Loggable(Protocol):
def make_audio_progress_logger(
ctx: Loggable, task_name: TaskName, interval: float = 5.0
ctx: Loggable,
task_name: TaskName,
interval: float = 5.0,
transcript_id: str | None = None,
) -> Callable[[float | None, float], None]:
"""Create a throttled progress logger callback for audio processing.
@@ -192,6 +195,7 @@ def make_audio_progress_logger(
ctx: Object with .log() method (e.g., Hatchet Context).
task_name: Name to prefix in log messages.
interval: Minimum seconds between log messages.
transcript_id: If provided, broadcasts transient DAG_TASK_PROGRESS events.
Returns:
Callback(progress_pct, audio_position) that logs at most every `interval` seconds.
@@ -213,6 +217,27 @@ def make_audio_progress_logger(
)
last_log_time[0] = now
if transcript_id and progress_pct is not None:
try:
import asyncio # noqa: PLC0415
from reflector.db.transcripts import TranscriptEvent # noqa: PLC0415
from reflector.hatchet.broadcast import broadcast_event # noqa: PLC0415
loop = asyncio.get_event_loop()
loop.create_task(
broadcast_event(
transcript_id,
TranscriptEvent(
event="DAG_TASK_PROGRESS",
data={"task_name": task_name, "progress_pct": progress_pct},
),
logger=logger,
)
)
except Exception:
pass # transient, never fail the callback
return callback
@@ -237,8 +262,15 @@ def with_error_handling(
) -> Callable[[PipelineInput, Context], Coroutine[Any, Any, R]]:
@functools.wraps(func)
async def wrapper(input: PipelineInput, ctx: Context) -> R:
from reflector.hatchet.dag_progress import broadcast_dag_status # noqa: I001, PLC0415
try:
return await func(input, ctx)
result = await func(input, ctx)
try:
await broadcast_dag_status(input.transcript_id, ctx.workflow_run_id)
except Exception:
pass
return result
except Exception as e:
logger.error(
f"[Hatchet] {step_name} failed",
@@ -246,6 +278,10 @@ def with_error_handling(
error=str(e),
exc_info=True,
)
try:
await broadcast_dag_status(input.transcript_id, ctx.workflow_run_id)
except Exception:
pass
if set_error_status:
await set_workflow_error_status(input.transcript_id)
raise
@@ -560,7 +596,9 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
target_sample_rate,
offsets_seconds=None,
logger=logger,
progress_callback=make_audio_progress_logger(ctx, TaskName.MIXDOWN_TRACKS),
progress_callback=make_audio_progress_logger(
ctx, TaskName.MIXDOWN_TRACKS, transcript_id=input.transcript_id
),
expected_duration_sec=recording_duration if recording_duration > 0 else None,
)
await writer.flush()

View File

@@ -267,6 +267,19 @@ async def dispatch_transcript_processing(
)
logger.info("Hatchet workflow dispatched", workflow_id=workflow_id)
try:
from reflector.hatchet.dag_progress import broadcast_dag_status # noqa: I001, PLC0415
await broadcast_dag_status(config.transcript_id, workflow_id)
except Exception:
logger.warning(
"[DAG Progress] Failed initial broadcast",
transcript_id=config.transcript_id,
workflow_id=workflow_id,
exc_info=True,
)
return None
elif isinstance(config, FileProcessingConfig):

View File

@@ -111,6 +111,7 @@ class GetTranscriptMinimal(BaseModel):
room_id: str | None = None
room_name: str | None = None
audio_deleted: bool | None = None
dag_status: list[dict] | None = None
class TranscriptParticipantWithEmail(TranscriptParticipant):
@@ -491,6 +492,13 @@ async def transcript_get(
)
)
dag_status = None
if transcript.status == "processing" and transcript.events:
for ev in reversed(transcript.events):
if ev.event == "DAG_STATUS":
dag_status = ev.data.get("tasks") if isinstance(ev.data, dict) else None
break
base_data = {
"id": transcript.id,
"user_id": transcript.user_id,
@@ -512,6 +520,7 @@ async def transcript_get(
"room_id": transcript.room_id,
"room_name": room_name,
"audio_deleted": transcript.audio_deleted,
"dag_status": dag_status,
"participants": participants,
}

View File

@@ -41,13 +41,19 @@ async def transcript_events_websocket(
try:
# on first connection, send all events only to the current user
# Find the last DAG_STATUS to send after other historical events
last_dag_status = None
for event in transcript.events:
# for now, do not send TRANSCRIPT or STATUS options - these are live events
# not necessary to be sent to the client; but keep the rest
name = event.event
if name in ("TRANSCRIPT", "STATUS"):
continue
if name == "DAG_STATUS":
last_dag_status = event
continue
await websocket.send_json(event.model_dump(mode="json"))
# Send only the most recent DAG_STATUS so reconnecting clients get current state
if last_dag_status is not None:
await websocket.send_json(last_dag_status.model_dump(mode="json"))
# XXX if transcript is final (locked=True and status=ended)
# XXX send a final event to the client and close the connection

View File

@@ -0,0 +1,959 @@
"""Tests for DAG progress models and transform function.
Tests the extract_dag_tasks function that converts Hatchet V1WorkflowRunDetails
into structured DagTask list for frontend consumption.
"""
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from reflector.hatchet.constants import TaskName
from reflector.hatchet.dag_progress import (
DagStatusData,
DagTask,
DagTaskStatus,
extract_dag_tasks,
)
def _make_shape_item(
step_id: str,
task_name: str,
children_step_ids: list[str] | None = None,
) -> MagicMock:
"""Create a mock WorkflowRunShapeItemForWorkflowRunDetails."""
item = MagicMock()
item.step_id = step_id
item.task_name = task_name
item.children_step_ids = children_step_ids or []
return item
def _make_task_summary(
    step_id: str,
    status: str = "QUEUED",
    started_at: datetime | None = None,
    finished_at: datetime | None = None,
    duration: int | None = None,
    error_message: str | None = None,
    task_external_id: str | None = None,
    num_spawned_children: int | None = None,
    children: list | None = None,
) -> MagicMock:
    """Create a mock V1TaskSummary carrying a real V1TaskStatus enum value."""
    from hatchet_sdk.clients.rest.models import V1TaskStatus

    summary = MagicMock()
    summary.step_id = step_id
    # Coerce the plain string into the real enum, as production data has.
    summary.status = V1TaskStatus(status)
    summary.started_at = started_at
    summary.finished_at = finished_at
    summary.duration = duration
    summary.error_message = error_message
    summary.task_external_id = task_external_id or f"ext-{step_id}"
    summary.num_spawned_children = num_spawned_children
    summary.children = children or []
    return summary
def _make_details(
shape: list,
tasks: list,
run_id: str = "test-run-id",
) -> MagicMock:
"""Create a mock V1WorkflowRunDetails."""
details = MagicMock()
details.shape = shape
details.tasks = tasks
details.task_events = []
details.run = MagicMock()
details.run.metadata = MagicMock()
details.run.metadata.id = run_id
return details
class TestExtractDagTasksBasic:
    """Test basic extraction of DAG tasks from workflow run details."""

    def test_empty_shape_returns_empty_list(self):
        details = _make_details(shape=[], tasks=[])
        result = extract_dag_tasks(details)
        assert result == []

    def test_single_task_queued(self):
        shape = [_make_shape_item("s1", "get_recording")]
        tasks = [_make_task_summary("s1", status="QUEUED")]
        details = _make_details(shape, tasks)
        result = extract_dag_tasks(details)
        assert len(result) == 1
        assert result[0].name == "get_recording"
        assert result[0].status == DagTaskStatus.QUEUED
        assert result[0].parents == []
        assert result[0].started_at is None
        assert result[0].finished_at is None
        assert result[0].duration_seconds is None
        assert result[0].error is None
        assert result[0].children_total is None
        assert result[0].children_completed is None
        assert result[0].progress_pct is None

    def test_completed_task_with_duration(self):
        now = datetime.now(timezone.utc)
        shape = [_make_shape_item("s1", "get_recording")]
        tasks = [
            _make_task_summary(
                "s1",
                status="COMPLETED",
                started_at=now,
                finished_at=now,
                duration=1500,  # milliseconds
            )
        ]
        details = _make_details(shape, tasks)
        result = extract_dag_tasks(details)
        # Hatchet millisecond duration should convert to 1.5 seconds.
        assert result[0].status == DagTaskStatus.COMPLETED
        assert result[0].duration_seconds == 1.5
        assert result[0].started_at == now
        assert result[0].finished_at == now

    def test_failed_task_with_error(self):
        shape = [_make_shape_item("s1", "get_recording")]
        tasks = [
            _make_task_summary(
                "s1",
                status="FAILED",
                error_message="Traceback (most recent call last):\n File something\nConnectionError: connection refused",
            )
        ]
        details = _make_details(shape, tasks)
        result = extract_dag_tasks(details)
        assert result[0].status == DagTaskStatus.FAILED
        # The error summary keeps the last non-traceback line.
        assert result[0].error == "ConnectionError: connection refused"

    def test_running_task(self):
        now = datetime.now(timezone.utc)
        shape = [_make_shape_item("s1", "mixdown_tracks")]
        tasks = [
            _make_task_summary(
                "s1",
                status="RUNNING",
                started_at=now,
                duration=5000,
            )
        ]
        details = _make_details(shape, tasks)
        result = extract_dag_tasks(details)
        assert result[0].status == DagTaskStatus.RUNNING
        assert result[0].started_at == now
        assert result[0].duration_seconds == 5.0

    def test_cancelled_task(self):
        shape = [_make_shape_item("s1", "post_zulip")]
        tasks = [_make_task_summary("s1", status="CANCELLED")]
        details = _make_details(shape, tasks)
        result = extract_dag_tasks(details)
        assert result[0].status == DagTaskStatus.CANCELLED
class TestExtractDagTasksTopology:
"""Test topological ordering and parent extraction."""
def test_linear_chain_parents(self):
"""A -> B -> C should produce correct parents."""
shape = [
_make_shape_item("s1", "get_recording", children_step_ids=["s2"]),
_make_shape_item("s2", "get_participants", children_step_ids=["s3"]),
_make_shape_item("s3", "process_tracks"),
]
tasks = [
_make_task_summary("s1", status="COMPLETED"),
_make_task_summary("s2", status="COMPLETED"),
_make_task_summary("s3", status="QUEUED"),
]
details = _make_details(shape, tasks)
result = extract_dag_tasks(details)
assert [t.name for t in result] == [
"get_recording",
"get_participants",
"process_tracks",
]
assert result[0].parents == []
assert result[1].parents == ["get_recording"]
assert result[2].parents == ["get_participants"]
def test_diamond_dag(self):
"""
A -> B, A -> C, B -> D, C -> D
D should have parents [B, C] (or [C, B] depending on sort).
"""
shape = [
_make_shape_item("s1", "get_recording", children_step_ids=["s2", "s3"]),
_make_shape_item("s2", "mixdown_tracks", children_step_ids=["s4"]),
_make_shape_item("s3", "detect_topics", children_step_ids=["s4"]),
_make_shape_item("s4", "finalize"),
]
tasks = [
_make_task_summary("s1", status="COMPLETED"),
_make_task_summary("s2", status="RUNNING"),
_make_task_summary("s3", status="RUNNING"),
_make_task_summary("s4", status="QUEUED"),
]
details = _make_details(shape, tasks)
result = extract_dag_tasks(details)
# Topological: s1 first, s2/s3 in some order, s4 last
assert result[0].name == "get_recording"
assert result[-1].name == "finalize"
finalize = result[-1]
assert set(finalize.parents) == {"mixdown_tracks", "detect_topics"}
def test_topological_order_is_stable(self):
"""Verify deterministic ordering (sorted queue in Kahn's)."""
shape = [
_make_shape_item("s_c", "task_c"),
_make_shape_item("s_a", "task_a", children_step_ids=["s_c"]),
_make_shape_item("s_b", "task_b", children_step_ids=["s_c"]),
]
tasks = [
_make_task_summary("s_c", status="QUEUED"),
_make_task_summary("s_a", status="COMPLETED"),
_make_task_summary("s_b", status="COMPLETED"),
]
details = _make_details(shape, tasks)
result = extract_dag_tasks(details)
# s_a and s_b both roots with in-degree 0; sorted alphabetically by step_id
names = [t.name for t in result]
assert names[-1] == "task_c"
# First two should be task_a, task_b (sorted by step_id: s_a < s_b)
assert names[0] == "task_a"
assert names[1] == "task_b"
def test_production_dag_shape(self):
"""Test the real 15-task pipeline topology with mixed statuses.
Simulates a mid-pipeline state where early tasks completed,
middle tasks running, and later tasks still queued.
"""
# Production DAG edges (parent -> children):
# get_recording -> get_participants
# get_participants -> process_tracks
# process_tracks -> mixdown_tracks, detect_topics, finalize
# mixdown_tracks -> generate_waveform
# detect_topics -> generate_title, extract_subjects
# extract_subjects -> process_subjects, identify_action_items
# process_subjects -> generate_recap
# generate_title -> finalize
# generate_recap -> finalize
# identify_action_items -> finalize
# finalize -> cleanup_consent
# cleanup_consent -> post_zulip, send_webhook
shape = [
_make_shape_item(
"s_get_recording", TaskName.GET_RECORDING, ["s_get_participants"]
),
_make_shape_item(
"s_get_participants", TaskName.GET_PARTICIPANTS, ["s_process_tracks"]
),
_make_shape_item(
"s_process_tracks",
TaskName.PROCESS_TRACKS,
["s_mixdown_tracks", "s_detect_topics", "s_finalize"],
),
_make_shape_item(
"s_mixdown_tracks", TaskName.MIXDOWN_TRACKS, ["s_generate_waveform"]
),
_make_shape_item("s_generate_waveform", TaskName.GENERATE_WAVEFORM),
_make_shape_item(
"s_detect_topics",
TaskName.DETECT_TOPICS,
["s_generate_title", "s_extract_subjects"],
),
_make_shape_item(
"s_generate_title", TaskName.GENERATE_TITLE, ["s_finalize"]
),
_make_shape_item(
"s_extract_subjects",
TaskName.EXTRACT_SUBJECTS,
["s_process_subjects", "s_identify_action_items"],
),
_make_shape_item(
"s_process_subjects", TaskName.PROCESS_SUBJECTS, ["s_generate_recap"]
),
_make_shape_item(
"s_generate_recap", TaskName.GENERATE_RECAP, ["s_finalize"]
),
_make_shape_item(
"s_identify_action_items",
TaskName.IDENTIFY_ACTION_ITEMS,
["s_finalize"],
),
_make_shape_item("s_finalize", TaskName.FINALIZE, ["s_cleanup_consent"]),
_make_shape_item(
"s_cleanup_consent",
TaskName.CLEANUP_CONSENT,
["s_post_zulip", "s_send_webhook"],
),
_make_shape_item("s_post_zulip", TaskName.POST_ZULIP),
_make_shape_item("s_send_webhook", TaskName.SEND_WEBHOOK),
]
# Mid-pipeline: early tasks done, middle running, later queued
tasks = [
_make_task_summary("s_get_recording", status="COMPLETED"),
_make_task_summary("s_get_participants", status="COMPLETED"),
_make_task_summary("s_process_tracks", status="COMPLETED"),
_make_task_summary("s_mixdown_tracks", status="RUNNING"),
_make_task_summary("s_generate_waveform", status="QUEUED"),
_make_task_summary("s_detect_topics", status="RUNNING"),
_make_task_summary("s_generate_title", status="QUEUED"),
_make_task_summary("s_extract_subjects", status="QUEUED"),
_make_task_summary("s_process_subjects", status="QUEUED"),
_make_task_summary("s_generate_recap", status="QUEUED"),
_make_task_summary("s_identify_action_items", status="QUEUED"),
_make_task_summary("s_finalize", status="QUEUED"),
_make_task_summary("s_cleanup_consent", status="QUEUED"),
_make_task_summary("s_post_zulip", status="QUEUED"),
_make_task_summary("s_send_webhook", status="QUEUED"),
]
details = _make_details(shape, tasks)
result = extract_dag_tasks(details)
# All 15 tasks present
assert len(result) == 15
result_names = [t.name for t in result]
assert set(result_names) == {
TaskName.GET_RECORDING,
TaskName.GET_PARTICIPANTS,
TaskName.PROCESS_TRACKS,
TaskName.MIXDOWN_TRACKS,
TaskName.GENERATE_WAVEFORM,
TaskName.DETECT_TOPICS,
TaskName.GENERATE_TITLE,
TaskName.EXTRACT_SUBJECTS,
TaskName.PROCESS_SUBJECTS,
TaskName.GENERATE_RECAP,
TaskName.IDENTIFY_ACTION_ITEMS,
TaskName.FINALIZE,
TaskName.CLEANUP_CONSENT,
TaskName.POST_ZULIP,
TaskName.SEND_WEBHOOK,
}
# Topological order invariant: no task appears before its parents
name_to_index = {t.name: i for i, t in enumerate(result)}
for task in result:
for parent_name in task.parents:
assert name_to_index[parent_name] < name_to_index[task.name], (
f"Parent {parent_name} (idx {name_to_index[parent_name]}) "
f"must appear before {task.name} (idx {name_to_index[task.name]})"
)
# finalize has exactly 4 parents
finalize = next(t for t in result if t.name == TaskName.FINALIZE)
assert set(finalize.parents) == {
TaskName.PROCESS_TRACKS,
TaskName.GENERATE_TITLE,
TaskName.GENERATE_RECAP,
TaskName.IDENTIFY_ACTION_ITEMS,
}
# cleanup_consent has 1 parent (finalize)
cleanup = next(t for t in result if t.name == TaskName.CLEANUP_CONSENT)
assert cleanup.parents == [TaskName.FINALIZE]
# post_zulip and send_webhook both have cleanup_consent as parent
post_zulip = next(t for t in result if t.name == TaskName.POST_ZULIP)
send_webhook = next(t for t in result if t.name == TaskName.SEND_WEBHOOK)
assert post_zulip.parents == [TaskName.CLEANUP_CONSENT]
assert send_webhook.parents == [TaskName.CLEANUP_CONSENT]
# Verify statuses propagated correctly
assert (
next(t for t in result if t.name == TaskName.GET_RECORDING).status
== DagTaskStatus.COMPLETED
)
assert (
next(t for t in result if t.name == TaskName.MIXDOWN_TRACKS).status
== DagTaskStatus.RUNNING
)
assert (
next(t for t in result if t.name == TaskName.FINALIZE).status
== DagTaskStatus.QUEUED
)
def test_topological_sort_invariant_complex_dag(self):
    """Every task's parents must precede it in the extracted list.

    Uses a wide branch/merge DAG (A fans out to B/C/D, merging through
    E/F into G) to stress the topological-order invariant harder than a
    simple diamond shape would.
    """
    # step_id -> (task name, child step ids or None for the sink)
    topology = {
        "s_a": ("task_a", ["s_b", "s_c", "s_d"]),
        "s_b": ("task_b", ["s_e"]),
        "s_c": ("task_c", ["s_e", "s_f"]),
        "s_d": ("task_d", ["s_f"]),
        "s_e": ("task_e", ["s_g"]),
        "s_f": ("task_f", ["s_g"]),
        "s_g": ("task_g", None),
    }
    statuses = {
        "s_a": "COMPLETED",
        "s_b": "COMPLETED",
        "s_c": "RUNNING",
        "s_d": "COMPLETED",
        "s_e": "QUEUED",
        "s_f": "QUEUED",
        "s_g": "QUEUED",
    }
    shape = [
        _make_shape_item(step, name, children)
        if children is not None
        else _make_shape_item(step, name)
        for step, (name, children) in topology.items()
    ]
    tasks = [
        _make_task_summary(step, status=status)
        for step, status in statuses.items()
    ]
    result = extract_dag_tasks(_make_details(shape, tasks))

    assert len(result) == 7
    position = {task.name: idx for idx, task in enumerate(result)}
    # Invariant: a parent's index is always strictly below its child's.
    for child in result:
        for parent_name in child.parents:
            assert position[parent_name] < position[child.name], (
                f"Parent {parent_name} (idx {position[parent_name]}) "
                f"must appear before {child.name} (idx {position[child.name]})"
            )
    # task_g merges the two middle branches.
    task_g = next(t for t in result if t.name == "task_g")
    assert set(task_g.parents) == {"task_e", "task_f"}
    # task_e merges task_b and task_c.
    task_e = next(t for t in result if t.name == "task_e")
    assert set(task_e.parents) == {"task_b", "task_c"}
    # The single root sorts first and has no parents.
    assert result[0].name == "task_a"
    assert result[0].parents == []
class TestExtractDagTasksFanOut:
    """Fan-out tasks that spawn dynamic child workloads."""

    def test_fan_out_children_counts(self):
        """children_total mirrors num_spawned_children; completed are counted."""
        from hatchet_sdk.clients.rest.models import V1TaskStatus

        child_statuses = ("COMPLETED", "COMPLETED", "RUNNING", "QUEUED")
        children = []
        for raw_status in child_statuses:
            mock_child = MagicMock()
            mock_child.status = V1TaskStatus(raw_status)
            children.append(mock_child)
        details = _make_details(
            [_make_shape_item("s1", "process_tracks")],
            [
                _make_task_summary(
                    "s1",
                    status="RUNNING",
                    num_spawned_children=4,
                    children=children,
                )
            ],
        )
        result = extract_dag_tasks(details)
        assert result[0].children_total == 4
        # Only the two COMPLETED children count as done.
        assert result[0].children_completed == 2

    def test_no_children_when_no_spawn(self):
        """num_spawned_children=None means both counters stay None."""
        details = _make_details(
            [_make_shape_item("s1", "get_recording")],
            [_make_task_summary("s1", status="COMPLETED", num_spawned_children=None)],
        )
        result = extract_dag_tasks(details)
        assert result[0].children_total is None
        assert result[0].children_completed is None

    def test_zero_spawned_children(self):
        """Zero spawned children is treated the same as no fan-out at all."""
        details = _make_details(
            [_make_shape_item("s1", "process_tracks")],
            [_make_task_summary("s1", status="COMPLETED", num_spawned_children=0)],
        )
        result = extract_dag_tasks(details)
        assert result[0].children_total is None
        assert result[0].children_completed is None
class TestExtractDagTasksErrorExtraction:
    """Error message extraction from failed pipeline tasks."""

    @staticmethod
    def _single_task_error(name, status, error_message):
        """Extract the .error of a one-step DAG with the given task state."""
        details = _make_details(
            [_make_shape_item("s1", name)],
            [_make_task_summary("s1", status=status, error_message=error_message)],
        )
        return extract_dag_tasks(details)[0].error

    def test_simple_error(self):
        """A plain one-line message passes through unchanged."""
        assert (
            self._single_task_error(
                "mixdown_tracks", "FAILED", "ValueError: no tracks"
            )
            == "ValueError: no tracks"
        )

    def test_traceback_extracts_meaningful_line(self):
        """Only the final, meaningful line of a traceback is surfaced."""
        traceback_text = (
            "Traceback (most recent call last):\n"
            ' File "/app/something.py", line 42\n'
            "RuntimeError: out of memory"
        )
        assert (
            self._single_task_error("mixdown_tracks", "FAILED", traceback_text)
            == "RuntimeError: out of memory"
        )

    def test_no_error_when_none(self):
        """A task without an error message yields error=None."""
        assert self._single_task_error("get_recording", "COMPLETED", None) is None

    def test_empty_error_message(self):
        """An empty-string error is normalized to None rather than kept as ''."""
        assert self._single_task_error("get_recording", "FAILED", "") is None
class TestExtractDagTasksMissingData:
    """Edge cases where task data or the DAG shape is absent."""

    def test_shape_without_matching_task(self):
        """A shape step with no task summary defaults to a QUEUED placeholder."""
        details = _make_details([_make_shape_item("s1", "get_recording")], [])
        result = extract_dag_tasks(details)
        assert len(result) == 1
        placeholder = result[0]
        assert placeholder.name == "get_recording"
        # With no runtime data, the task is assumed not to have started.
        assert placeholder.status == DagTaskStatus.QUEUED
        assert placeholder.started_at is None

    def test_none_shape_returns_empty(self):
        """A details object whose shape is None yields no tasks at all."""
        details = _make_details(shape=[], tasks=[])
        details.shape = None
        assert extract_dag_tasks(details) == []
class TestDagStatusData:
    """JSON serialization of the DagStatusData model."""

    def test_serialization(self):
        """model_dump(mode='json') round-trips task fields to plain values."""
        started = datetime(2025, 1, 1, tzinfo=timezone.utc)
        finished = datetime(2025, 1, 1, 0, 0, 1, tzinfo=timezone.utc)
        task = DagTask(
            name="get_recording",
            status=DagTaskStatus.COMPLETED,
            started_at=started,
            finished_at=finished,
            duration_seconds=1.0,
            parents=[],
            error=None,
            children_total=None,
            children_completed=None,
            progress_pct=None,
        )
        payload = DagStatusData(
            workflow_run_id="test-123", tasks=[task]
        ).model_dump(mode="json")
        assert payload["workflow_run_id"] == "test-123"
        assert len(payload["tasks"]) == 1
        serialized = payload["tasks"][0]
        assert serialized["name"] == "get_recording"
        # Status enums serialize to their lowercase string values.
        assert serialized["status"] == "completed"
        assert serialized["duration_seconds"] == 1.0
class AsyncContextManager:
    """No-op async context manager for mocking fresh_db_connection.

    Provides the ``async with`` protocol and nothing else: entering yields
    None, and exiting returns None (falsy), so exceptions raised inside the
    block propagate normally instead of being suppressed.
    """

    async def __aenter__(self):
        return None

    async def __aexit__(self, *exc_info):
        return None
class TestBroadcastDagStatus:
    """Test broadcast_dag_status function.

    broadcast_dag_status uses deferred imports inside its function body.
    We mock the source modules/objects before calling the function.
    Importing daily_multitrack_pipeline triggers a cascade
    (subject_processing -> HatchetClientManager.get_client at module level),
    so we set _instance before the import to prevent real SDK init.
    """

    @pytest.fixture(autouse=True)
    def _setup_hatchet_mock(self):
        """Set HatchetClientManager._instance to a mock to prevent real SDK init.

        Module-level code in workflow files calls get_client() during import.
        Setting _instance before import avoids ClientConfig validation.
        """
        from reflector.hatchet.client import HatchetClientManager

        # Save and restore the singleton so other tests see the prior state.
        original = HatchetClientManager._instance
        HatchetClientManager._instance = MagicMock()
        yield
        HatchetClientManager._instance = original

    @pytest.mark.asyncio
    async def test_broadcasts_dag_status(self):
        """broadcast_dag_status fetches run, transforms, and broadcasts."""
        mock_transcript = MagicMock()
        mock_transcript.id = "t-123"
        # One-step DAG whose run id must flow through to DagStatusData.
        mock_details = _make_details(
            shape=[_make_shape_item("s1", "get_recording")],
            tasks=[_make_task_summary("s1", status="COMPLETED")],
            run_id="wf-abc",
        )
        mock_client = MagicMock()
        mock_client.runs.aio_get = AsyncMock(return_value=mock_details)
        with (
            patch(
                "reflector.hatchet.client.HatchetClientManager.get_client",
                return_value=mock_client,
            ),
            patch(
                "reflector.hatchet.broadcast.append_event_and_broadcast",
                new_callable=AsyncMock,
            ) as mock_broadcast,
            patch(
                "reflector.db.transcripts.transcripts_controller.get_by_id",
                new_callable=AsyncMock,
                return_value=mock_transcript,
            ),
            patch(
                "reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection",
                return_value=AsyncContextManager(),
            ),
        ):
            # Import after patching: broadcast_dag_status resolves its
            # collaborators lazily inside the function body.
            from reflector.hatchet.dag_progress import broadcast_dag_status

            await broadcast_dag_status("t-123", "wf-abc")
            # The Hatchet run is fetched exactly once by workflow run id.
            mock_client.runs.aio_get.assert_called_once_with("wf-abc")
            mock_broadcast.assert_called_once()
            call_args = mock_broadcast.call_args
            assert call_args[0][0] == "t-123"  # transcript_id
            assert call_args[0][1] is mock_transcript  # transcript
            assert call_args[0][2] == "DAG_STATUS"  # event_name
            # Fourth positional arg is the typed payload built from the run.
            data = call_args[0][3]
            assert isinstance(data, DagStatusData)
            assert data.workflow_run_id == "wf-abc"
            assert len(data.tasks) == 1

    @pytest.mark.asyncio
    async def test_swallows_exceptions(self):
        """broadcast_dag_status never raises even when internals fail."""
        from reflector.hatchet.dag_progress import broadcast_dag_status

        with patch(
            "reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection",
            side_effect=RuntimeError("db exploded"),
        ):
            # Should not raise — the function is best-effort by design.
            await broadcast_dag_status("t-123", "wf-abc")

    @pytest.mark.asyncio
    async def test_no_broadcast_when_transcript_not_found(self):
        """broadcast_dag_status does not broadcast if transcript is None."""
        mock_details = _make_details(
            shape=[_make_shape_item("s1", "get_recording")],
            tasks=[_make_task_summary("s1", status="COMPLETED")],
        )
        mock_client = MagicMock()
        mock_client.runs.aio_get = AsyncMock(return_value=mock_details)
        with (
            patch(
                "reflector.hatchet.client.HatchetClientManager.get_client",
                return_value=mock_client,
            ),
            patch(
                "reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection",
                return_value=AsyncContextManager(),
            ),
            # Lookup returns None: the transcript was deleted or never existed.
            patch(
                "reflector.db.transcripts.transcripts_controller.get_by_id",
                new_callable=AsyncMock,
                return_value=None,
            ),
            patch(
                "reflector.hatchet.broadcast.append_event_and_broadcast",
                new_callable=AsyncMock,
            ) as mock_broadcast,
        ):
            from reflector.hatchet.dag_progress import broadcast_dag_status

            await broadcast_dag_status("t-123", "wf-abc")
            mock_broadcast.assert_not_called()
class TestMakeAudioProgressLoggerWithBroadcast:
    """Test make_audio_progress_logger with transcript_id for transient broadcasts."""

    @pytest.fixture(autouse=True)
    def _setup_hatchet_mock(self):
        """Set HatchetClientManager._instance to prevent real SDK init on import."""
        from reflector.hatchet.client import HatchetClientManager

        # Only install the mock when no instance exists; restore afterwards.
        original = HatchetClientManager._instance
        if original is None:
            HatchetClientManager._instance = MagicMock()
        yield
        HatchetClientManager._instance = original

    def test_broadcasts_transient_progress_event(self):
        """When transcript_id provided and progress_pct not None, broadcasts event."""
        import asyncio

        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            make_audio_progress_logger,
        )

        ctx = MagicMock()
        ctx.log = MagicMock()
        # The callback fires broadcasts via loop.create_task; use a private
        # loop so the test can drain those tasks deterministically.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        mock_broadcast = AsyncMock()
        tasks_created = []
        original_create_task = loop.create_task

        def capture_create_task(coro):
            # Record every scheduled task so the test can await them all.
            task = original_create_task(coro)
            tasks_created.append(task)
            return task

        try:
            with (
                patch(
                    "reflector.hatchet.broadcast.broadcast_event",
                    mock_broadcast,
                ),
                patch.object(loop, "create_task", side_effect=capture_create_task),
            ):
                callback = make_audio_progress_logger(
                    ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id="t-123"
                )
                callback(50.0, 100.0)
                # Run pending tasks
                if tasks_created:
                    loop.run_until_complete(asyncio.gather(*tasks_created))
                mock_broadcast.assert_called_once()
                # Second positional arg is the broadcast event payload.
                event_arg = mock_broadcast.call_args[0][1]
                assert event_arg.event == "DAG_TASK_PROGRESS"
                assert event_arg.data["task_name"] == TaskName.MIXDOWN_TRACKS
                assert event_arg.data["progress_pct"] == 50.0
        finally:
            loop.close()

    def test_no_broadcast_without_transcript_id(self):
        """When transcript_id is None, no broadcast happens."""
        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            make_audio_progress_logger,
        )

        ctx = MagicMock()
        with patch(
            "reflector.hatchet.broadcast.broadcast_event",
            new_callable=AsyncMock,
        ) as mock_broadcast:
            callback = make_audio_progress_logger(
                ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id=None
            )
            callback(50.0, 100.0)
            mock_broadcast.assert_not_called()

    def test_no_broadcast_when_progress_pct_is_none(self):
        """When progress_pct is None, no broadcast happens even with transcript_id."""
        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            make_audio_progress_logger,
        )

        ctx = MagicMock()
        with patch(
            "reflector.hatchet.broadcast.broadcast_event",
            new_callable=AsyncMock,
        ) as mock_broadcast:
            callback = make_audio_progress_logger(
                ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id="t-123"
            )
            # First positional arg (progress_pct) is None: nothing to report.
            callback(None, 100.0)
            mock_broadcast.assert_not_called()

    def test_logging_throttled_by_interval(self):
        """With interval=5.0, rapid calls only log once until interval elapses.

        The throttle applies to ctx.log() calls. Broadcasts (fire-and-forget)
        are not throttled — they occur every call when transcript_id + progress_pct set.
        """
        import asyncio
        import time as time_mod

        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            make_audio_progress_logger,
        )

        ctx = MagicMock()
        ctx.log = MagicMock()
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        mock_broadcast = AsyncMock()
        tasks_created = []
        original_create_task = loop.create_task

        def capture_create_task(coro):
            # Track scheduled broadcast tasks so they can be awaited below.
            task = original_create_task(coro)
            tasks_created.append(task)
            return task

        # Controlled monotonic values for the 4 calls from make_audio_progress_logger:
        # init (start_time, last_log_time), call1 (now), call2 (now), call3 (now)
        # After those, fall back to real time.monotonic() for asyncio internals.
        controlled_values = [100.0, 100.0, 101.0, 106.0]
        call_index = [0]
        real_monotonic = time_mod.monotonic

        def mock_monotonic():
            # Serve the scripted timestamps first, then delegate to the real clock.
            if call_index[0] < len(controlled_values):
                val = controlled_values[call_index[0]]
                call_index[0] += 1
                return val
            return real_monotonic()

        try:
            with (
                patch(
                    "reflector.hatchet.workflows.daily_multitrack_pipeline.time.monotonic",
                    side_effect=mock_monotonic,
                ),
                patch(
                    "reflector.hatchet.broadcast.broadcast_event",
                    mock_broadcast,
                ),
                patch.object(loop, "create_task", side_effect=capture_create_task),
            ):
                callback = make_audio_progress_logger(
                    ctx, TaskName.MIXDOWN_TRACKS, interval=5.0, transcript_id="t-123"
                )
                # Call 1 at t=100.0: 100.0 - 100.0 = 0.0 < 5.0 => no log
                callback(25.0, 50.0)
                assert ctx.log.call_count == 0
                # Call 2 at t=101.0: 101.0 - 100.0 = 1.0 < 5.0 => no log
                callback(50.0, 100.0)
                assert ctx.log.call_count == 0
                # Call 3 at t=106.0: 106.0 - 100.0 = 6.0 >= 5.0 => logs
                callback(75.0, 150.0)
                assert ctx.log.call_count == 1
                # Run pending broadcast tasks
                if tasks_created:
                    loop.run_until_complete(asyncio.gather(*tasks_created))
                # Broadcasts happen on every call (not throttled) — 3 calls total
                assert mock_broadcast.call_count == 3
        finally:
            loop.close()

    def test_uses_broadcast_event_not_append_event_and_broadcast(self):
        """Progress events use broadcast_event (transient), not append_event_and_broadcast (persisted)."""
        import asyncio

        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            make_audio_progress_logger,
        )

        ctx = MagicMock()
        ctx.log = MagicMock()
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        mock_broadcast_event = AsyncMock()
        mock_append = AsyncMock()
        tasks_created = []
        original_create_task = loop.create_task

        def capture_create_task(coro):
            # Record scheduled tasks so the test can drain them.
            task = original_create_task(coro)
            tasks_created.append(task)
            return task

        try:
            with (
                patch(
                    "reflector.hatchet.broadcast.broadcast_event",
                    mock_broadcast_event,
                ),
                patch(
                    "reflector.hatchet.broadcast.append_event_and_broadcast",
                    mock_append,
                ),
                patch.object(loop, "create_task", side_effect=capture_create_task),
            ):
                callback = make_audio_progress_logger(
                    ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id="t-123"
                )
                callback(50.0, 100.0)
                if tasks_created:
                    loop.run_until_complete(asyncio.gather(*tasks_created))
                # broadcast_event (transient) IS called
                mock_broadcast_event.assert_called_once()
                # append_event_and_broadcast (persisted) is NOT called
                mock_append.assert_not_called()
        finally:
            loop.close()

View File

@@ -0,0 +1,181 @@
"""Tests for with_error_handling decorator integration with broadcast_dag_status.
The decorator wraps each pipeline task and calls broadcast_dag_status on both
success and failure paths. These tests verify that integration rather than
testing broadcast_dag_status in isolation (which test_dag_progress.py covers).
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from reflector.hatchet.constants import TaskName
class TestWithErrorHandlingBroadcast:
    """Test with_error_handling decorator's integration with broadcast_dag_status."""

    @pytest.fixture(autouse=True)
    def _setup_hatchet_mock(self):
        """Set HatchetClientManager._instance to a mock to prevent real SDK init.

        Module-level code in workflow files calls get_client() during import.
        Setting _instance before import avoids ClientConfig validation.
        """
        from reflector.hatchet.client import HatchetClientManager

        # Save and restore the singleton around each test.
        original = HatchetClientManager._instance
        HatchetClientManager._instance = MagicMock()
        yield
        HatchetClientManager._instance = original

    def _make_input(self, transcript_id: str = "t-123") -> MagicMock:
        """Create a mock PipelineInput with transcript_id."""
        inp = MagicMock()
        inp.transcript_id = transcript_id
        return inp

    def _make_ctx(self, workflow_run_id: str = "wf-abc") -> MagicMock:
        """Create a mock Context with workflow_run_id."""
        ctx = MagicMock()
        ctx.workflow_run_id = workflow_run_id
        return ctx

    @pytest.mark.asyncio
    async def test_calls_broadcast_on_success(self):
        """Decorator calls broadcast_dag_status once when task succeeds."""
        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            with_error_handling,
        )

        inner = AsyncMock(return_value="ok")
        wrapped = with_error_handling(TaskName.GET_RECORDING)(inner)
        with patch(
            "reflector.hatchet.dag_progress.broadcast_dag_status",
            new_callable=AsyncMock,
        ) as mock_broadcast:
            result = await wrapped(self._make_input(), self._make_ctx())
            # The inner task's return value passes through unchanged.
            assert result == "ok"
            mock_broadcast.assert_called_once_with("t-123", "wf-abc")

    @pytest.mark.asyncio
    async def test_calls_broadcast_on_failure(self):
        """Decorator calls broadcast_dag_status once when task raises."""
        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            with_error_handling,
        )

        inner = AsyncMock(side_effect=RuntimeError("boom"))
        wrapped = with_error_handling(TaskName.GET_RECORDING)(inner)
        with (
            patch(
                "reflector.hatchet.dag_progress.broadcast_dag_status",
                new_callable=AsyncMock,
            ) as mock_broadcast,
            # Stub the error-status writer so no DB is touched on failure.
            patch(
                "reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
                new_callable=AsyncMock,
            ),
        ):
            with pytest.raises(RuntimeError, match="boom"):
                await wrapped(self._make_input(), self._make_ctx())
            mock_broadcast.assert_called_once_with("t-123", "wf-abc")

    @pytest.mark.asyncio
    async def test_swallows_broadcast_exception_on_success(self):
        """Broadcast failure does not crash the task on the success path."""
        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            with_error_handling,
        )

        inner = AsyncMock(return_value="ok")
        wrapped = with_error_handling(TaskName.GET_RECORDING)(inner)
        with patch(
            "reflector.hatchet.dag_progress.broadcast_dag_status",
            new_callable=AsyncMock,
            side_effect=RuntimeError("broadcast exploded"),
        ):
            result = await wrapped(self._make_input(), self._make_ctx())
            assert result == "ok"

    @pytest.mark.asyncio
    async def test_swallows_broadcast_exception_on_failure(self):
        """Original task exception propagates even when broadcast also fails."""
        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            with_error_handling,
        )

        inner = AsyncMock(side_effect=ValueError("original error"))
        wrapped = with_error_handling(TaskName.GET_RECORDING)(inner)
        with (
            patch(
                "reflector.hatchet.dag_progress.broadcast_dag_status",
                new_callable=AsyncMock,
                side_effect=RuntimeError("broadcast exploded"),
            ),
            patch(
                "reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
                new_callable=AsyncMock,
            ),
        ):
            # The caller must see the task's ValueError, not the RuntimeError.
            with pytest.raises(ValueError, match="original error"):
                await wrapped(self._make_input(), self._make_ctx())

    @pytest.mark.asyncio
    async def test_calls_set_workflow_error_status_on_failure(self):
        """On task failure with set_error_status=True (default), calls set_workflow_error_status."""
        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            with_error_handling,
        )

        inner = AsyncMock(side_effect=RuntimeError("boom"))
        wrapped = with_error_handling(TaskName.GET_RECORDING)(inner)
        with (
            patch(
                "reflector.hatchet.dag_progress.broadcast_dag_status",
                new_callable=AsyncMock,
            ),
            patch(
                "reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
                new_callable=AsyncMock,
            ) as mock_set_error,
        ):
            with pytest.raises(RuntimeError, match="boom"):
                await wrapped(self._make_input(), self._make_ctx())
            # Error status is keyed by transcript id only.
            mock_set_error.assert_called_once_with("t-123")

    @pytest.mark.asyncio
    async def test_no_set_workflow_error_status_when_disabled(self):
        """With set_error_status=False, set_workflow_error_status is NOT called on failure."""
        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            with_error_handling,
        )

        inner = AsyncMock(side_effect=RuntimeError("boom"))
        wrapped = with_error_handling(TaskName.GET_RECORDING, set_error_status=False)(
            inner
        )
        with (
            patch(
                "reflector.hatchet.dag_progress.broadcast_dag_status",
                new_callable=AsyncMock,
            ),
            patch(
                "reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
                new_callable=AsyncMock,
            ) as mock_set_error,
        ):
            with pytest.raises(RuntimeError, match="boom"):
                await wrapped(self._make_input(), self._make_ctx())
            mock_set_error.assert_not_called()

View File

@@ -0,0 +1,421 @@
"""Tests for DAG status REST enrichment on search and transcript GET endpoints."""
from datetime import datetime, timezone
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
import pytest
import reflector.db.search as search_module
from reflector.db.search import SearchResult, _fetch_dag_statuses
from reflector.db.transcripts import TranscriptEvent
class TestFetchDagStatuses:
    """Test the _fetch_dag_statuses helper."""

    @pytest.mark.asyncio
    async def test_returns_empty_for_empty_ids(self):
        # Empty input short-circuits without touching the database.
        result = await _fetch_dag_statuses([])
        assert result == {}

    @pytest.mark.asyncio
    async def test_extracts_last_dag_status(self):
        """The most recent DAG_STATUS event supplies the task list."""
        events = [
            {"event": "STATUS", "data": {"value": "processing"}},
            {
                "event": "DAG_STATUS",
                "data": {
                    "workflow_run_id": "r1",
                    "tasks": [{"name": "get_recording", "status": "completed"}],
                },
            },
            {
                "event": "DAG_STATUS",
                "data": {
                    "workflow_run_id": "r1",
                    "tasks": [
                        {"name": "get_recording", "status": "completed"},
                        {"name": "process_tracks", "status": "running"},
                    ],
                },
            },
        ]
        mock_row = {"id": "t1", "events": events}
        with patch("reflector.db.search.get_database") as mock_db:
            mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row])
            result = await _fetch_dag_statuses(["t1"])
            assert "t1" in result
            assert len(result["t1"]) == 2  # Last DAG_STATUS had 2 tasks

    @pytest.mark.asyncio
    async def test_skips_transcripts_without_events(self):
        """A NULL events column contributes nothing to the result."""
        mock_row = {"id": "t1", "events": None}
        with patch("reflector.db.search.get_database") as mock_db:
            mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row])
            result = await _fetch_dag_statuses(["t1"])
            assert result == {}

    @pytest.mark.asyncio
    async def test_skips_transcripts_without_dag_status(self):
        """Transcripts with only non-DAG_STATUS events are omitted."""
        events = [
            {"event": "STATUS", "data": {"value": "processing"}},
            {"event": "DURATION", "data": {"duration": 1000}},
        ]
        mock_row = {"id": "t1", "events": events}
        with patch("reflector.db.search.get_database") as mock_db:
            mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row])
            result = await _fetch_dag_statuses(["t1"])
            assert result == {}

    @pytest.mark.asyncio
    async def test_handles_json_string_events(self):
        """Events stored as JSON string rather than already-parsed list."""
        import json

        events = [
            {
                "event": "DAG_STATUS",
                "data": {
                    "workflow_run_id": "r1",
                    "tasks": [{"name": "transcribe", "status": "running"}],
                },
            },
        ]
        # Simulate a DB driver returning the JSON column as raw text.
        mock_row = {"id": "t1", "events": json.dumps(events)}
        with patch("reflector.db.search.get_database") as mock_db:
            mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row])
            result = await _fetch_dag_statuses(["t1"])
            assert "t1" in result
            assert len(result["t1"]) == 1
            assert result["t1"][0]["name"] == "transcribe"

    @pytest.mark.asyncio
    async def test_multiple_transcripts(self):
        """Handles multiple transcripts in one call."""
        events_t1 = [
            {
                "event": "DAG_STATUS",
                "data": {
                    "workflow_run_id": "r1",
                    "tasks": [{"name": "a", "status": "completed"}],
                },
            },
        ]
        events_t2 = [
            {
                "event": "DAG_STATUS",
                "data": {
                    "workflow_run_id": "r2",
                    "tasks": [{"name": "b", "status": "running"}],
                },
            },
        ]
        mock_rows = [
            {"id": "t1", "events": events_t1},
            {"id": "t2", "events": events_t2},
        ]
        with patch("reflector.db.search.get_database") as mock_db:
            mock_db.return_value.fetch_all = AsyncMock(return_value=mock_rows)
            result = await _fetch_dag_statuses(["t1", "t2"])
            # Each transcript maps to its own task list.
            assert "t1" in result
            assert "t2" in result
            assert result["t1"][0]["name"] == "a"
            assert result["t2"][0]["name"] == "b"

    @pytest.mark.asyncio
    async def test_dag_status_without_tasks_key_skipped(self):
        """DAG_STATUS event with no tasks key in data should be skipped."""
        events = [
            {"event": "DAG_STATUS", "data": {"workflow_run_id": "r1"}},
        ]
        mock_row = {"id": "t1", "events": events}
        with patch("reflector.db.search.get_database") as mock_db:
            mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row])
            result = await _fetch_dag_statuses(["t1"])
            assert result == {}
def _extract_dag_status_from_transcript(transcript):
"""Replicate the dag_status extraction logic from transcript_get view.
This mirrors the code in reflector/views/transcripts.py lines 495-500:
dag_status = None
if transcript.status == "processing" and transcript.events:
for ev in reversed(transcript.events):
if ev.event == "DAG_STATUS":
dag_status = ev.data.get("tasks") if isinstance(ev.data, dict) else None
break
"""
dag_status = None
if transcript.status == "processing" and transcript.events:
for ev in reversed(transcript.events):
if ev.event == "DAG_STATUS":
dag_status = ev.data.get("tasks") if isinstance(ev.data, dict) else None
break
return dag_status
class TestTranscriptGetDagStatusExtraction:
    """Exercise the dag_status extraction pattern used by transcript_get.

    The real endpoint is complex to set up, so the extraction logic is
    tested directly through _extract_dag_status_from_transcript, which
    mirrors the view code.
    """

    @staticmethod
    def _transcript(status, events):
        """Minimal transcript stand-in with just the fields extraction reads."""
        return SimpleNamespace(status=status, events=events)

    @staticmethod
    def _dag_event(tasks):
        """Build a DAG_STATUS TranscriptEvent carrying the given task list."""
        return TranscriptEvent(
            event="DAG_STATUS",
            data={"workflow_run_id": "r1", "tasks": tasks},
        )

    def test_processing_transcript_with_dag_status_events(self):
        """Processing transcript with DAG_STATUS events returns tasks from last event."""
        transcript = self._transcript(
            "processing",
            [
                TranscriptEvent(event="STATUS", data={"value": "processing"}),
                self._dag_event([{"name": "get_recording", "status": "completed"}]),
                self._dag_event(
                    [
                        {"name": "get_recording", "status": "completed"},
                        {"name": "transcribe", "status": "running"},
                    ]
                ),
            ],
        )
        result = _extract_dag_status_from_transcript(transcript)
        assert result is not None
        assert len(result) == 2
        assert result[0]["name"] == "get_recording"
        assert result[1]["name"] == "transcribe"
        assert result[1]["status"] == "running"

    def test_processing_transcript_without_dag_status_events(self):
        """Processing transcript with only non-DAG_STATUS events returns None."""
        transcript = self._transcript(
            "processing",
            [
                TranscriptEvent(event="STATUS", data={"value": "processing"}),
                TranscriptEvent(event="DURATION", data={"duration": 1000}),
            ],
        )
        assert _extract_dag_status_from_transcript(transcript) is None

    def test_ended_transcript_with_dag_status_events(self):
        """Ended transcript with DAG_STATUS events returns None (status check)."""
        transcript = self._transcript(
            "ended",
            [self._dag_event([{"name": "transcribe", "status": "completed"}])],
        )
        assert _extract_dag_status_from_transcript(transcript) is None

    def test_processing_transcript_with_empty_events(self):
        """Processing transcript with empty events list returns None."""
        transcript = self._transcript("processing", [])
        assert _extract_dag_status_from_transcript(transcript) is None

    def test_processing_transcript_with_none_events(self):
        """Processing transcript with None events returns None."""
        transcript = self._transcript("processing", None)
        assert _extract_dag_status_from_transcript(transcript) is None

    def test_extracts_last_dag_status_not_first(self):
        """Should pick the last DAG_STATUS event (most recent), not the first."""
        transcript = self._transcript(
            "processing",
            [
                self._dag_event([{"name": "a", "status": "running"}]),
                TranscriptEvent(event="STATUS", data={"value": "processing"}),
                self._dag_event(
                    [
                        {"name": "a", "status": "completed"},
                        {"name": "b", "status": "running"},
                    ]
                ),
            ],
        )
        result = _extract_dag_status_from_transcript(transcript)
        assert len(result) == 2
        assert result[0]["status"] == "completed"
        assert result[1]["name"] == "b"
class TestSearchEnrichmentIntegration:
"""Test DAG status enrichment in search results.
The search function enriches processing transcripts with dag_status
by calling _fetch_dag_statuses for processing IDs and assigning results.
We test this enrichment logic by mocking _fetch_dag_statuses.
"""
def _make_search_result(self, id: str, status: str) -> SearchResult:
"""Create a minimal SearchResult for testing."""
return SearchResult(
id=id,
title=f"Transcript {id}",
user_id="u1",
room_id=None,
room_name=None,
source_kind="live",
created_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
status=status,
rank=1.0,
duration=60.0,
search_snippets=[],
total_match_count=0,
dag_status=None,
)
@pytest.mark.asyncio
async def test_processing_result_gets_dag_status(self):
"""SearchResult with status='processing' and matching DAG_STATUS events
gets dag_status populated."""
results = [self._make_search_result("t1", "processing")]
dag_tasks = [
{"name": "get_recording", "status": "completed"},
{"name": "transcribe", "status": "running"},
]
with patch.object(
search_module,
"_fetch_dag_statuses",
new_callable=AsyncMock,
return_value={"t1": dag_tasks},
) as mock_fetch:
# Replicate the enrichment logic from SearchController.search_transcripts
processing_ids = [r.id for r in results if r.status == "processing"]
if processing_ids:
dag_statuses = await search_module._fetch_dag_statuses(processing_ids)
for r in results:
if r.id in dag_statuses:
r.dag_status = dag_statuses[r.id]
mock_fetch.assert_called_once_with(["t1"])
assert results[0].dag_status == dag_tasks
@pytest.mark.asyncio
async def test_ended_result_does_not_trigger_fetch(self):
"""SearchResult with status='ended' does NOT trigger _fetch_dag_statuses."""
results = [self._make_search_result("t1", "ended")]
with patch.object(
search_module,
"_fetch_dag_statuses",
new_callable=AsyncMock,
return_value={},
) as mock_fetch:
processing_ids = [r.id for r in results if r.status == "processing"]
if processing_ids:
dag_statuses = await search_module._fetch_dag_statuses(processing_ids)
for r in results:
if r.id in dag_statuses:
r.dag_status = dag_statuses[r.id]
mock_fetch.assert_not_called()
assert results[0].dag_status is None
@pytest.mark.asyncio
async def test_mixed_processing_and_ended_results(self):
"""Only processing results get enriched; ended results stay None."""
results = [
self._make_search_result("t1", "processing"),
self._make_search_result("t2", "ended"),
self._make_search_result("t3", "processing"),
]
dag_tasks_t1 = [{"name": "transcribe", "status": "running"}]
dag_tasks_t3 = [{"name": "diarize", "status": "completed"}]
with patch.object(
search_module,
"_fetch_dag_statuses",
new_callable=AsyncMock,
return_value={"t1": dag_tasks_t1, "t3": dag_tasks_t3},
) as mock_fetch:
processing_ids = [r.id for r in results if r.status == "processing"]
if processing_ids:
dag_statuses = await search_module._fetch_dag_statuses(processing_ids)
for r in results:
if r.id in dag_statuses:
r.dag_status = dag_statuses[r.id]
mock_fetch.assert_called_once_with(["t1", "t3"])
assert results[0].dag_status == dag_tasks_t1
assert results[1].dag_status is None
assert results[2].dag_status == dag_tasks_t3
@pytest.mark.asyncio
async def test_processing_result_without_dag_events_stays_none(self):
    """Processing result with no DAG_STATUS events in DB stays dag_status=None."""
    results = [self._make_search_result("t1", "processing")]
    with patch.object(
        search_module,
        "_fetch_dag_statuses",
        new_callable=AsyncMock,
        return_value={},  # DB has no DAG_STATUS events for t1
    ) as mock_fetch:
        # Replicate the enrichment logic from SearchController.search_transcripts.
        processing_ids = [r.id for r in results if r.status == "processing"]
        if processing_ids:
            dag_statuses = await search_module._fetch_dag_statuses(processing_ids)
            for r in results:
                if r.id in dag_statuses:
                    r.dag_status = dag_statuses[r.id]
        # The fetch IS attempted (result is processing), but since the helper
        # returned nothing, dag_status must remain None.
        mock_fetch.assert_called_once_with(["t1"])
        assert results[0].dag_status is None

View File

@@ -0,0 +1,331 @@
"""WebSocket broadcast delivery tests for STATUS and DAG_STATUS events.
Tests the full chain identified in DEBUG.md:
broadcast_event() → ws_manager.send_json() → Redis/in-memory pub/sub
→ _pubsub_data_reader() → socket.send_json() → WebSocket client
Covers:
1. STATUS event delivery to transcript room WS
2. DAG_STATUS event delivery to transcript room WS
3. Full broadcast_event() chain (requires broadcast.py patching)
4. _pubsub_data_reader resilience when a client disconnects
"""
import asyncio
import threading
import time
import pytest
from httpx import AsyncClient
from httpx_ws import aconnect_ws
from uvicorn import Config, Server
@pytest.fixture
def appserver_ws_broadcast(setup_database, monkeypatch):
    """Start a real uvicorn server (own daemon thread + event loop) for WS tests.

    Also patches broadcast.py's get_ws_manager (missing from conftest autouse
    fixture). Without this, broadcast_event() creates a real Redis ws_manager.

    Yields:
        (host, port) of the running server.
    """
    # Patch broadcast.py's get_ws_manager — conftest.py misses this module.
    # Without this, broadcast_event() creates a real Redis ws_manager.
    import reflector.ws_manager as ws_mod
    from reflector.app import app
    from reflector.db import get_database

    monkeypatch.setattr(
        "reflector.hatchet.broadcast.get_ws_manager", ws_mod.get_ws_manager
    )
    host = "127.0.0.1"
    port = 1259  # fixed port — assumes it is free on the test machine
    server_started = threading.Event()
    server_exception = None
    server_instance = None

    def run_server():
        # Runs in a daemon thread with a dedicated event loop so the test's
        # own loop stays free for client-side WebSocket work.
        nonlocal server_exception, server_instance
        try:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            # NOTE(review): uvicorn's Config(loop=...) normally takes a loop
            # *name* (e.g. "asyncio"), not a loop object — confirm the
            # installed uvicorn version honors this.
            config = Config(app=app, host=host, port=port, loop=loop)
            server_instance = Server(config)

            async def start_server():
                # Keep the shared database connected for the server lifetime.
                database = get_database()
                await database.connect()
                try:
                    await server_instance.serve()
                finally:
                    await database.disconnect()

            # Set *before* serve() actually binds the socket; the sleep below
            # papers over that race.
            server_started.set()
            loop.run_until_complete(start_server())
        except Exception as e:
            server_exception = e
            server_started.set()  # unblock the waiting test thread on failure
        finally:
            loop.close()

    server_thread = threading.Thread(target=run_server, daemon=True)
    server_thread.start()
    server_started.wait(timeout=30)
    if server_exception:
        raise server_exception
    time.sleep(0.5)  # allow uvicorn to finish binding the port (see note above)
    yield host, port
    # Teardown: ask uvicorn to exit, join the thread, then reset the
    # ws_manager singleton so later tests get a fresh instance.
    if server_instance:
        server_instance.should_exit = True
    server_thread.join(timeout=2.0)
    from reflector.ws_manager import reset_ws_manager

    reset_ws_manager()
async def _create_transcript(host: str, port: int, name: str) -> str:
    """Create a transcript via ASGI transport and return its ID.

    Talks to the ASGI app in-process (not the running uvicorn server); both
    share the same database, so the transcript is visible to the WS endpoint.

    Args:
        host: Host used only to build the base URL.
        port: Port used only to build the base URL.
        name: Display name for the new transcript.

    Returns:
        The new transcript's id.
    """
    # Use an explicit ASGITransport: the AsyncClient(app=...) shortcut is
    # deprecated and removed in httpx >= 0.28; behavior is identical.
    from httpx import ASGITransport

    from reflector.app import app

    transport = ASGITransport(app=app)
    async with AsyncClient(
        transport=transport, base_url=f"http://{host}:{port}/v1"
    ) as ac:
        resp = await ac.post("/transcripts", json={"name": name})
        assert resp.status_code == 200, f"Failed to create transcript: {resp.text}"
        return resp.json()["id"]
async def _drain_historical_events(ws, timeout: float = 0.5) -> list[dict]:
"""Read all historical events sent on WS connect (non-blocking drain)."""
events = []
deadline = asyncio.get_event_loop().time() + timeout
while asyncio.get_event_loop().time() < deadline:
try:
msg = await asyncio.wait_for(ws.receive_json(), timeout=0.1)
events.append(msg)
except (asyncio.TimeoutError, Exception):
break
return events
# ---------------------------------------------------------------------------
# Test 1: STATUS event delivery via ws_manager.send_json
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_transcript_ws_receives_status_via_send_json(appserver_ws_broadcast):
    """STATUS event published via ws_manager.send_json() arrives at transcript room WS."""
    import reflector.ws_manager as ws_mod

    host, port = appserver_ws_broadcast
    transcript_id = await _create_transcript(host, port, "Status send_json test")
    ws_url = f"http://{host}:{port}/v1/transcripts/{transcript_id}/events"
    async with aconnect_ws(ws_url) as ws:
        # Flush the historical replay so the next frame is our published event.
        await _drain_historical_events(ws)
        manager = ws_mod.get_ws_manager()
        payload = {"event": "STATUS", "data": {"value": "processing"}}
        await manager.send_json(room_id=f"ts:{transcript_id}", message=payload)
        received = await asyncio.wait_for(ws.receive_json(), timeout=5.0)
    assert received["event"] == "STATUS"
    assert received["data"]["value"] == "processing"
# ---------------------------------------------------------------------------
# Test 2: DAG_STATUS event delivery via ws_manager.send_json
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_transcript_ws_receives_dag_status_via_send_json(appserver_ws_broadcast):
    """DAG_STATUS event published via ws_manager.send_json() arrives at transcript room WS."""
    host, port = appserver_ws_broadcast
    transcript_id = await _create_transcript(host, port, "DAG_STATUS send_json test")
    # Payload shape mirrors what other tests in this file broadcast: one
    # finished task plus one in-flight task with child-progress counters.
    dag_payload = {
        "event": "DAG_STATUS",
        "data": {
            "workflow_run_id": "test-run-123",
            "tasks": [
                {
                    "name": "get_recording",
                    "status": "completed",
                    "started_at": "2025-01-01T00:00:00Z",
                    "finished_at": "2025-01-01T00:00:05Z",
                    "duration_seconds": 5.0,
                    "parents": [],
                    "error": None,
                    "children_total": None,
                    "children_completed": None,
                    "progress_pct": None,
                },
                {
                    "name": "process_tracks",
                    "status": "running",
                    "started_at": "2025-01-01T00:00:05Z",
                    "finished_at": None,
                    "duration_seconds": None,
                    "parents": ["get_recording"],
                    "error": None,
                    "children_total": 3,
                    "children_completed": 1,
                    "progress_pct": 33.3,
                },
            ],
        },
    }
    ws_url = f"http://{host}:{port}/v1/transcripts/{transcript_id}/events"
    async with aconnect_ws(ws_url) as ws:
        # Flush the historical replay so the next frame is our published event.
        await _drain_historical_events(ws)
        import reflector.ws_manager as ws_mod

        ws_manager = ws_mod.get_ws_manager()
        await ws_manager.send_json(
            room_id=f"ts:{transcript_id}",
            message=dag_payload,
        )
        msg = await asyncio.wait_for(ws.receive_json(), timeout=5.0)
        # The payload must arrive intact, task details included.
        assert msg["event"] == "DAG_STATUS"
        assert msg["data"]["workflow_run_id"] == "test-run-123"
        assert len(msg["data"]["tasks"]) == 2
        assert msg["data"]["tasks"][0]["name"] == "get_recording"
        assert msg["data"]["tasks"][0]["status"] == "completed"
        assert msg["data"]["tasks"][1]["name"] == "process_tracks"
        assert msg["data"]["tasks"][1]["children_completed"] == 1
# ---------------------------------------------------------------------------
# Test 3: Full broadcast_event() chain for STATUS
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_broadcast_event_delivers_status_to_transcript_ws(appserver_ws_broadcast):
    """broadcast_event() end-to-end: STATUS event reaches transcript room WS."""
    from reflector.db.transcripts import TranscriptEvent
    from reflector.hatchet.broadcast import broadcast_event
    from reflector.logger import logger

    host, port = appserver_ws_broadcast
    transcript_id = await _create_transcript(host, port, "broadcast_event STATUS test")
    ws_url = f"http://{host}:{port}/v1/transcripts/{transcript_id}/events"
    async with aconnect_ws(ws_url) as ws:
        # Flush the historical replay before broadcasting.
        await _drain_historical_events(ws)
        bound_log = logger.bind(transcript_id=transcript_id)
        await broadcast_event(
            transcript_id,
            TranscriptEvent(event="STATUS", data={"value": "processing"}),
            logger=bound_log,
        )
        received = await asyncio.wait_for(ws.receive_json(), timeout=5.0)
    assert received["event"] == "STATUS"
    assert received["data"]["value"] == "processing"
# ---------------------------------------------------------------------------
# Test 4: Full broadcast_event() chain for DAG_STATUS
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_broadcast_event_delivers_dag_status_to_transcript_ws(
    appserver_ws_broadcast,
):
    """broadcast_event() end-to-end: DAG_STATUS event reaches transcript room WS."""
    host, port = appserver_ws_broadcast
    transcript_id = await _create_transcript(host, port, "broadcast_event DAG test")
    ws_url = f"http://{host}:{port}/v1/transcripts/{transcript_id}/events"
    async with aconnect_ws(ws_url) as ws:
        # Flush the historical replay so the next frame is our broadcast.
        await _drain_historical_events(ws)
        from reflector.db.transcripts import TranscriptEvent
        from reflector.hatchet.broadcast import broadcast_event
        from reflector.logger import logger

        log = logger.bind(transcript_id=transcript_id)
        # Minimal DAG payload: a single still-running task with no timing.
        event = TranscriptEvent(
            event="DAG_STATUS",
            data={
                "workflow_run_id": "test-run-456",
                "tasks": [
                    {
                        "name": "get_recording",
                        "status": "running",
                        "started_at": None,
                        "finished_at": None,
                        "duration_seconds": None,
                        "parents": [],
                        "error": None,
                        "children_total": None,
                        "children_completed": None,
                        "progress_pct": None,
                    }
                ],
            },
        )
        await broadcast_event(transcript_id, event, logger=log)
        msg = await asyncio.wait_for(ws.receive_json(), timeout=5.0)
        assert msg["event"] == "DAG_STATUS"
        assert msg["data"]["tasks"][0]["name"] == "get_recording"
# ---------------------------------------------------------------------------
# Test 5: Multiple rapid events arrive in order
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_multiple_events_arrive_in_order(appserver_ws_broadcast):
    """Multiple STATUS then DAG_STATUS events arrive in correct order."""
    host, port = appserver_ws_broadcast
    transcript_id = await _create_transcript(host, port, "ordering test")
    ws_url = f"http://{host}:{port}/v1/transcripts/{transcript_id}/events"
    async with aconnect_ws(ws_url) as ws:
        # Flush the historical replay so only our four events follow.
        await _drain_historical_events(ws)
        import reflector.ws_manager as ws_mod

        ws_manager = ws_mod.get_ws_manager()
        # Publish four events back-to-back: STATUS, empty DAG_STATUS,
        # DAG_STATUS with one task, final STATUS.
        await ws_manager.send_json(
            room_id=f"ts:{transcript_id}",
            message={"event": "STATUS", "data": {"value": "processing"}},
        )
        await ws_manager.send_json(
            room_id=f"ts:{transcript_id}",
            message={
                "event": "DAG_STATUS",
                "data": {"workflow_run_id": "r1", "tasks": []},
            },
        )
        await ws_manager.send_json(
            room_id=f"ts:{transcript_id}",
            message={
                "event": "DAG_STATUS",
                "data": {
                    "workflow_run_id": "r1",
                    "tasks": [{"name": "a", "status": "running"}],
                },
            },
        )
        await ws_manager.send_json(
            room_id=f"ts:{transcript_id}",
            message={"event": "STATUS", "data": {"value": "ended"}},
        )
        # Collect exactly four frames, then check publish order is preserved.
        msgs = []
        for _ in range(4):
            msg = await asyncio.wait_for(ws.receive_json(), timeout=5.0)
            msgs.append(msg)
        assert msgs[0]["event"] == "STATUS"
        assert msgs[0]["data"]["value"] == "processing"
        assert msgs[1]["event"] == "DAG_STATUS"
        assert msgs[1]["data"]["tasks"] == []
        assert msgs[2]["event"] == "DAG_STATUS"
        assert len(msgs[2]["data"]["tasks"]) == 1
        assert msgs[3]["event"] == "STATUS"
        assert msgs[3]["data"]["value"] == "ended"

View File

@@ -0,0 +1,61 @@
import React from "react";
import { Box, Flex } from "@chakra-ui/react";
import type { DagTask } from "../../../lib/UserEventsProvider";
// CSS keyframes injected via a <style> tag by the component below; drives the
// opacity pulse applied to dots whose task is currently "running".
const pulseKeyframes = `
@keyframes dagDotPulse {
  0%, 100% { opacity: 1; }
  50% { opacity: 0.3; }
}
`;
/** Turn a snake_case task name into Title Case words ("get_recording" -> "Get Recording"). */
function humanizeTaskName(name: string): string {
  const words = name.split("_");
  const capitalized = words.map(
    (word) => word.slice(0, 1).toUpperCase() + word.slice(1),
  );
  return capitalized.join(" ");
}
/**
 * Chakra props for a single status dot. Terminal/active states get a solid
 * fill; queued (and any unrecognized status) renders as a hollow outline.
 */
function dotProps(status: DagTask["status"]): Record<string, unknown> {
  if (status === "completed") {
    return { bg: "green.500" };
  }
  if (status === "running") {
    // Inline style because the keyframes are injected via a raw <style> tag.
    return {
      bg: "blue.500",
      style: { animation: "dagDotPulse 1.5s ease-in-out infinite" },
    };
  }
  if (status === "failed") {
    return { bg: "red.500" };
  }
  if (status === "cancelled") {
    return { bg: "gray.400" };
  }
  // "queued" and anything unknown: hollow dot.
  return {
    bg: "transparent",
    border: "1px solid",
    borderColor: "gray.400",
  };
}
/**
 * Compact one-dot-per-task progress strip for a transcript's processing DAG.
 * Dot color encodes task status (see dotProps); hovering a dot shows the
 * humanized task name via the native `title` tooltip.
 */
export default function DagProgressDots({ tasks }: { tasks: DagTask[] }) {
  return (
    <>
      {/* Inject the pulse keyframes alongside the dots. */}
      <style>{pulseKeyframes}</style>
      <Flex gap="2px" alignItems="center" flexWrap="wrap">
        {tasks.map((task) => (
          <Box
            key={task.name}
            w="4px"
            h="4px"
            borderRadius="full"
            flexShrink={0}
            title={humanizeTaskName(task.name)}
            {...dotProps(task.status)}
          />
        ))}
      </Flex>
    </>
  );
}

View File

@@ -19,6 +19,7 @@ import {
generateTextFragment,
} from "../../../lib/textHighlight";
import type { components } from "../../../reflector-api";
import type { DagTask } from "../../../lib/UserEventsProvider";
type SearchResult = components["schemas"]["SearchResult"];
type SourceKind = components["schemas"]["SourceKind"];
@@ -29,6 +30,7 @@ interface TranscriptCardsProps {
isLoading?: boolean;
onDelete: (transcriptId: string) => void;
onReprocess: (transcriptId: string) => void;
dagStatusMap?: Map<string, DagTask[]>;
}
function highlightText(text: string, query: string): React.ReactNode {
@@ -102,11 +104,13 @@ function TranscriptCard({
query,
onDelete,
onReprocess,
dagStatusMap,
}: {
result: SearchResult;
query: string;
onDelete: (transcriptId: string) => void;
onReprocess: (transcriptId: string) => void;
dagStatusMap?: Map<string, DagTask[]>;
}) {
const [isExpanded, setIsExpanded] = useState(false);
@@ -137,7 +141,16 @@ function TranscriptCard({
<Box borderWidth={1} p={4} borderRadius="md" fontSize="sm">
<Flex justify="space-between" alignItems="flex-start" gap="2">
<Box>
<TranscriptStatusIcon status={result.status} />
<TranscriptStatusIcon
status={result.status}
dagStatus={
dagStatusMap?.get(result.id) ??
((result as Record<string, unknown>).dag_status as
| DagTask[]
| null) ??
null
}
/>
</Box>
<Box flex="1">
{/* Title with highlighting and text fragment for deep linking */}
@@ -284,6 +297,7 @@ export default function TranscriptCards({
isLoading,
onDelete,
onReprocess,
dagStatusMap,
}: TranscriptCardsProps) {
return (
<Box position="relative">
@@ -315,6 +329,7 @@ export default function TranscriptCards({
query={query}
onDelete={onDelete}
onReprocess={onReprocess}
dagStatusMap={dagStatusMap}
/>
))}
</Stack>

View File

@@ -8,13 +8,17 @@ import {
FaGear,
} from "react-icons/fa6";
import { TranscriptStatus } from "../../../lib/transcript";
import type { DagTask } from "../../../lib/UserEventsProvider";
import DagProgressDots from "./DagProgressDots";
interface TranscriptStatusIconProps {
status: TranscriptStatus;
dagStatus?: DagTask[] | null;
}
export default function TranscriptStatusIcon({
status,
dagStatus,
}: TranscriptStatusIconProps) {
switch (status) {
case "ended":
@@ -36,6 +40,9 @@ export default function TranscriptStatusIcon({
</Box>
);
case "processing":
if (dagStatus && dagStatus.length > 0) {
return <DagProgressDots tasks={dagStatus} />;
}
return (
<Box as="span" title="Processing in progress">
<Icon color="gray.500" as={FaGear} />

View File

@@ -43,6 +43,7 @@ import DeleteTranscriptDialog from "./_components/DeleteTranscriptDialog";
import { formatLocalDate } from "../../lib/time";
import { RECORD_A_MEETING_URL } from "../../api/urls";
import { useUserName } from "../../lib/useUserName";
import { useDagStatusMap } from "../../lib/UserEventsProvider";
const SEARCH_FORM_QUERY_INPUT_NAME = "query" as const;
@@ -273,6 +274,7 @@ export default function TranscriptBrowser() {
}, [JSON.stringify(searchFilters)]);
const userName = useUserName();
const dagStatusMap = useDagStatusMap();
const [deletionLoading, setDeletionLoading] = useState(false);
const cancelRef = React.useRef(null);
const [transcriptToDeleteId, setTranscriptToDeleteId] =
@@ -408,6 +410,7 @@ export default function TranscriptBrowser() {
isLoading={searchLoading}
onDelete={setTranscriptToDeleteId}
onReprocess={handleProcessTranscript}
dagStatusMap={dagStatusMap}
/>
{!searchLoading && results.length === 0 && (

View File

@@ -0,0 +1,190 @@
"use client";
import { useEffect, useState } from "react";
import { Table, Box, Icon, Spinner, Text, Badge } from "@chakra-ui/react";
import { FaCheck, FaXmark, FaClock, FaMinus } from "react-icons/fa6";
import type { DagTask, DagTaskStatus } from "../../useWebSockets";
/** Turn a snake_case task name into Title Case words ("get_recording" -> "Get Recording"). */
function humanizeTaskName(name: string): string {
  const parts: string[] = [];
  for (const word of name.split("_")) {
    parts.push(word.charAt(0).toUpperCase() + word.slice(1));
  }
  return parts.join(" ");
}
/**
 * Format a duration in seconds as "Ns" under a minute, else "Mm Ss".
 *
 * Rounds once up front so the seconds remainder can never display as "60s":
 * the previous `Math.round(seconds % 60)` produced e.g. "1m 60s" for 119.6s
 * (now "2m 0s").
 */
function formatDuration(seconds: number): string {
  const total = Math.round(seconds);
  if (total < 60) {
    return `${total}s`;
  }
  const minutes = Math.floor(total / 60);
  return `${minutes}m ${total % 60}s`;
}
/**
 * Status-column icon for a single DAG task. Completed/failed/queued/cancelled
 * get a colored icon with a tooltip; running gets a spinner; any unknown
 * status renders nothing rather than a misleading icon.
 */
function StatusIcon({ status }: { status: DagTaskStatus }) {
  switch (status) {
    case "completed":
      return (
        <Box as="span" title="Completed">
          <Icon color="green.500" as={FaCheck} />
        </Box>
      );
    case "running":
      return <Spinner size="sm" color="blue.500" />;
    case "failed":
      return (
        <Box as="span" title="Failed">
          <Icon color="red.500" as={FaXmark} />
        </Box>
      );
    case "queued":
      return (
        <Box as="span" title="Queued">
          <Icon color="gray.400" as={FaClock} />
        </Box>
      );
    case "cancelled":
      return (
        <Box as="span" title="Cancelled">
          <Icon color="gray.400" as={FaMinus} />
        </Box>
      );
    default:
      return null;
  }
}
/**
 * Live elapsed-time readout for a running task, re-rendered every second.
 *
 * Recomputes elapsed seconds from the wall clock each tick (instead of
 * incrementing a counter), so the display stays accurate even if renders are
 * delayed. NOTE(review): assumes `startedAt` is an ISO timestamp parseable by
 * `new Date()` with a timezone — confirm against the backend payload.
 */
function ElapsedTimer({ startedAt }: { startedAt: string }) {
  const [elapsed, setElapsed] = useState<number>(() => {
    return (Date.now() - new Date(startedAt).getTime()) / 1000;
  });
  useEffect(() => {
    const interval = setInterval(() => {
      setElapsed((Date.now() - new Date(startedAt).getTime()) / 1000);
    }, 1000);
    return () => clearInterval(interval);
  }, [startedAt]);
  return <Text fontSize="sm">{formatDuration(elapsed)}</Text>;
}
/**
 * Duration column: final duration for completed tasks, a live elapsed timer
 * for running tasks, and a dimmed "--" placeholder otherwise (queued, failed,
 * cancelled, or missing timing data).
 */
function DurationCell({ task }: { task: DagTask }) {
  if (task.status === "completed" && task.duration_seconds !== null) {
    return <Text fontSize="sm">{formatDuration(task.duration_seconds)}</Text>;
  }
  if (task.status === "running" && task.started_at) {
    return <ElapsedTimer startedAt={task.started_at} />;
  }
  return (
    <Text fontSize="sm" color="gray.400">
      --
    </Text>
  );
}
/**
 * Progress column: an optional percentage bar plus an optional
 * "completed/total" children badge. Renders nothing when the task reports
 * neither kind of progress.
 */
function ProgressCell({ task }: { task: DagTask }) {
  if (task.progress_pct === null && task.children_total === null) {
    return null;
  }
  return (
    <Box>
      {task.progress_pct !== null && (
        <Box
          w="100%"
          h="6px"
          bg="gray.200"
          borderRadius="full"
          overflow="hidden"
        >
          {/* Clamp to [0, 100] so a bad percentage can't overflow the track. */}
          <Box
            h="100%"
            w={`${Math.min(100, Math.max(0, task.progress_pct))}%`}
            bg={task.status === "failed" ? "red.400" : "blue.400"}
            borderRadius="full"
            transition="width 0.3s ease"
          />
        </Box>
      )}
      {task.children_total !== null && (
        <Badge
          size="sm"
          colorPalette="gray"
          mt={task.progress_pct !== null ? 1 : 0}
        >
          {task.children_completed ?? 0}/{task.children_total}
        </Badge>
      )}
    </Box>
  );
}
/**
 * One table row per DAG task; failed tasks with an error message become
 * clickable and toggle a second, full-width row revealing the error text.
 */
function TaskRow({ task }: { task: DagTask }) {
  const [expanded, setExpanded] = useState(false);
  // Truthy (the error string) only when the task failed AND has an error.
  const hasFailed = task.status === "failed" && task.error;
  return (
    <>
      <Table.Row
        cursor={hasFailed ? "pointer" : "default"}
        onClick={hasFailed ? () => setExpanded((prev) => !prev) : undefined}
        _hover={hasFailed ? { bg: "gray.50" } : undefined}
      >
        <Table.Cell>
          <Text fontSize="sm" fontWeight="medium">
            {humanizeTaskName(task.name)}
          </Text>
        </Table.Cell>
        <Table.Cell>
          <StatusIcon status={task.status} />
        </Table.Cell>
        <Table.Cell>
          <DurationCell task={task} />
        </Table.Cell>
        <Table.Cell>
          <ProgressCell task={task} />
        </Table.Cell>
      </Table.Row>
      {hasFailed && expanded && (
        <Table.Row>
          <Table.Cell colSpan={4}>
            <Box bg="red.50" p={3} borderRadius="md">
              <Text fontSize="xs" color="red.700" whiteSpace="pre-wrap">
                {task.error}
              </Text>
            </Box>
          </Table.Cell>
        </Table.Row>
      )}
    </>
  );
}
/**
 * Tabular view of a transcript-processing DAG: one TaskRow per task with
 * name, status icon, duration, and progress columns.
 */
export default function DagProgressTable({ tasks }: { tasks: DagTask[] }) {
  return (
    <Box w="100%" overflowX="auto">
      <Table.Root size="sm">
        <Table.Header>
          <Table.Row>
            <Table.ColumnHeader fontWeight="600">Task</Table.ColumnHeader>
            <Table.ColumnHeader fontWeight="600" width="80px">
              Status
            </Table.ColumnHeader>
            <Table.ColumnHeader fontWeight="600" width="100px">
              Duration
            </Table.ColumnHeader>
            <Table.ColumnHeader fontWeight="600" width="140px">
              Progress
            </Table.ColumnHeader>
          </Table.Row>
        </Table.Header>
        <Table.Body>
          {tasks.map((task) => (
            <TaskRow key={task.name} task={task} />
          ))}
        </Table.Body>
      </Table.Root>
    </Box>
  );
}

View File

@@ -11,6 +11,10 @@ import {
import { useRouter } from "next/navigation";
import { useTranscriptGet } from "../../../../lib/apiHooks";
import { parseNonEmptyString } from "../../../../lib/utils";
import { useWebSockets } from "../../useWebSockets";
import type { DagTask } from "../../useWebSockets";
import { useDagStatusMap } from "../../../../lib/UserEventsProvider";
import DagProgressTable from "./DagProgressTable";
type TranscriptProcessing = {
params: Promise<{
@@ -24,9 +28,21 @@ export default function TranscriptProcessing(details: TranscriptProcessing) {
const router = useRouter();
const transcript = useTranscriptGet(transcriptId);
const { status: wsStatus, dagStatus: wsDagStatus } =
useWebSockets(transcriptId);
const userDagStatusMap = useDagStatusMap();
const userDagStatus = userDagStatusMap.get(transcriptId) ?? null;
const restDagStatus: DagTask[] | null =
((transcript.data as Record<string, unknown>)?.dag_status as
| DagTask[]
| null) ?? null;
// Prefer transcript room WS (most granular), then user room WS, then REST
const dagStatus = wsDagStatus ?? userDagStatus ?? restDagStatus;
useEffect(() => {
const status = transcript.data?.status;
const status = wsStatus?.value ?? transcript.data?.status;
if (!status) return;
if (status === "ended" || status === "error") {
@@ -41,6 +57,7 @@ export default function TranscriptProcessing(details: TranscriptProcessing) {
router.replace(dest);
}
}, [
wsStatus?.value,
transcript.data?.status,
transcript.data?.source_kind,
router,
@@ -74,11 +91,29 @@ export default function TranscriptProcessing(details: TranscriptProcessing) {
w={{ base: "full", md: "container.xl" }}
>
<Center h={"full"} w="full">
<VStack gap={10} bg="gray.100" p={10} borderRadius="md" maxW="500px">
<Spinner size="xl" color="blue.500" />
<Heading size={"md"} textAlign="center">
Processing recording
</Heading>
<VStack
gap={10}
bg="gray.100"
p={10}
borderRadius="md"
maxW="600px"
w="full"
>
{dagStatus ? (
<>
<Heading size={"md"} textAlign="center">
Processing recording
</Heading>
<DagProgressTable tasks={dagStatus} />
</>
) : (
<>
<Spinner size="xl" color="blue.500" />
<Heading size={"md"} textAlign="center">
Processing recording
</Heading>
</>
)}
<Text color="gray.600" textAlign="center">
You can safely return to the library while your recording is being
processed.

View File

@@ -14,6 +14,9 @@ import {
} from "../../lib/apiHooks";
import { NonEmptyString } from "../../lib/utils";
import type { DagTask } from "../../lib/dagTypes";
export type { DagTask, DagTaskStatus } from "../../lib/dagTypes";
export type UseWebSockets = {
transcriptTextLive: string;
translateText: string;
@@ -24,6 +27,7 @@ export type UseWebSockets = {
status: Status | null;
waveform: AudioWaveform | null;
duration: number | null;
dagStatus: DagTask[] | null;
};
export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
@@ -40,6 +44,7 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
summary: "",
});
const [status, setStatus] = useState<Status | null>(null);
const [dagStatus, setDagStatus] = useState<DagTask[] | null>(null);
const { setError } = useError();
const queryClient = useQueryClient();
@@ -431,11 +436,31 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
);
}
setStatus(message.data);
invalidateTranscript(queryClient, transcriptId as NonEmptyString);
if (message.data.value === "ended") {
ws.close();
}
break;
case "DAG_STATUS":
if (message.data?.tasks) {
setDagStatus(message.data.tasks);
}
break;
case "DAG_TASK_PROGRESS":
if (message.data) {
setDagStatus(
(prev) =>
prev?.map((t) =>
t.name === message.data.task_name
? { ...t, progress_pct: message.data.progress_pct }
: t,
) ?? null,
);
}
break;
default:
setError(
new Error(`Received unknown WebSocket event: ${message.event}`),
@@ -493,5 +518,6 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
status,
waveform,
duration,
dagStatus,
};
};

View File

@@ -1,11 +1,25 @@
"use client";
import React, { useEffect, useRef } from "react";
import React, { useEffect, useRef, useState } from "react";
import { useQueryClient } from "@tanstack/react-query";
import { WEBSOCKET_URL } from "./apiClient";
import { useAuth } from "./AuthProvider";
import { z } from "zod";
import { invalidateTranscriptLists, TRANSCRIPT_SEARCH_URL } from "./apiHooks";
import {
invalidateTranscript,
invalidateTranscriptLists,
TRANSCRIPT_SEARCH_URL,
} from "./apiHooks";
import type { NonEmptyString } from "./utils";
import type { DagTask } from "./dagTypes";
export type { DagTask, DagTaskStatus } from "./dagTypes";
const DagStatusContext = React.createContext<Map<string, DagTask[]>>(new Map());
export function useDagStatusMap() {
return React.useContext(DagStatusContext);
}
const UserEvent = z.object({
event: z.string(),
@@ -95,6 +109,9 @@ export function UserEventsProvider({
const queryClient = useQueryClient();
const tokenRef = useRef<string | null>(null);
const detachRef = useRef<(() => void) | null>(null);
const [dagStatusMap, setDagStatusMap] = useState<Map<string, DagTask[]>>(
new Map(),
);
useEffect(() => {
// Only tear down when the user is truly unauthenticated
@@ -133,20 +150,52 @@ export function UserEventsProvider({
if (!detachRef.current) {
const onMessage = (event: MessageEvent) => {
try {
const msg = UserEvent.parse(JSON.parse(event.data));
const fullMsg = JSON.parse(event.data);
const msg = UserEvent.parse(fullMsg);
const eventName = msg.event;
const invalidateList = () => invalidateTranscriptLists(queryClient);
switch (eventName) {
case "TRANSCRIPT_CREATED":
case "TRANSCRIPT_DELETED":
case "TRANSCRIPT_STATUS":
case "TRANSCRIPT_FINAL_TITLE":
case "TRANSCRIPT_DURATION":
invalidateList().then(() => {});
break;
case "TRANSCRIPT_STATUS": {
invalidateList().then(() => {});
const transcriptId = fullMsg.data?.id as string | undefined;
if (transcriptId) {
invalidateTranscript(
queryClient,
transcriptId as NonEmptyString,
).then(() => {});
}
const status = fullMsg.data?.value as string | undefined;
if (transcriptId && status && status !== "processing") {
setDagStatusMap((prev) => {
const next = new Map(prev);
next.delete(transcriptId);
return next;
});
}
break;
}
case "TRANSCRIPT_DAG_STATUS": {
const transcriptId = fullMsg.data?.id as string | undefined;
const tasks = fullMsg.data?.tasks as DagTask[] | undefined;
if (transcriptId && tasks) {
setDagStatusMap((prev) => {
const next = new Map(prev);
next.set(transcriptId, tasks);
return next;
});
}
break;
}
default:
// Ignore other content events for list updates
break;
@@ -176,5 +225,9 @@ export function UserEventsProvider({
};
}, []);
return <>{children}</>;
return (
<DagStatusContext.Provider value={dagStatusMap}>
{children}
</DagStatusContext.Provider>
);
}

19
www/app/lib/dagTypes.ts Normal file
View File

@@ -0,0 +1,19 @@
/** Lifecycle states for a single task in the transcript-processing DAG. */
export type DagTaskStatus =
  | "queued"
  | "running"
  | "completed"
  | "failed"
  | "cancelled";
/**
 * One task entry in a DAG status payload. Field names are snake_case because
 * this mirrors the backend wire format (DAG_STATUS WebSocket events and the
 * REST `dag_status` enrichment).
 */
export type DagTask = {
  // Task name within the DAG, e.g. "get_recording" (also used as React key).
  name: string;
  status: DagTaskStatus;
  // Timestamps as strings (parsed with `new Date()` by consumers); null
  // until the task starts/finishes.
  started_at: string | null;
  finished_at: string | null;
  duration_seconds: number | null;
  // Names of upstream tasks this task depends on.
  parents: string[];
  error: string | null;
  // Fan-out child counters; null when the task has no children.
  children_total: number | null;
  children_completed: number | null;
  // Progress percentage (consumers clamp to 0-100); null when unknown.
  progress_pct: number | null;
};