Mirror of https://github.com/Monadical-SAS/reflector.git (synced 2026-03-22 15:16:46 +00:00)
Compare commits: feat/dag-p...feat/llm-e (1 commit)
Commit 8373874cbd
@@ -1,7 +1,6 @@
"""Search functionality for transcripts and other entities."""
|
||||
|
||||
import itertools
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from io import StringIO
|
||||
@@ -173,9 +172,6 @@ class SearchResult(BaseModel):
|
||||
total_match_count: NonNegativeInt = Field(
|
||||
default=0, description="Total number of matches found in the transcript"
|
||||
)
|
||||
dag_status: list[dict] | None = Field(
|
||||
default=None, description="Latest DAG task status for processing transcripts"
|
||||
)
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def serialize_datetime(self, dt: datetime) -> str:
|
||||
@@ -332,42 +328,6 @@ class SnippetGenerator:
|
||||
return summary_snippets + webvtt_snippets, total_matches
|
||||
|
||||
|
||||
async def _fetch_dag_statuses(transcript_ids: list[str]) -> dict[str, list[dict]]:
|
||||
"""Fetch latest DAG_STATUS event data for given transcript IDs.
|
||||
|
||||
Returns dict mapping transcript_id -> tasks list from the last DAG_STATUS event.
|
||||
"""
|
||||
if not transcript_ids:
|
||||
return {}
|
||||
|
||||
db = get_database()
|
||||
query = sqlalchemy.select(
|
||||
[
|
||||
transcripts.c.id,
|
||||
transcripts.c.events,
|
||||
]
|
||||
).where(transcripts.c.id.in_(transcript_ids))
|
||||
|
||||
rows = await db.fetch_all(query)
|
||||
result: dict[str, list[dict]] = {}
|
||||
|
||||
for row in rows:
|
||||
events_raw = row["events"]
|
||||
if not events_raw:
|
||||
continue
|
||||
# the events column is stored as a JSON list
|
||||
events = events_raw if isinstance(events_raw, list) else json.loads(events_raw)
|
||||
# Find last DAG_STATUS event
|
||||
for ev in reversed(events):
|
||||
if isinstance(ev, dict) and ev.get("event") == "DAG_STATUS":
|
||||
tasks = ev.get("data", {}).get("tasks")
|
||||
if tasks:
|
||||
result[row["id"]] = tasks
|
||||
break
|
||||
|
||||
return result
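# --- Illustrative sketch (editor's note, not part of the diff) ---
# A minimal example of the stored `events` JSON that the helper above scans;
# the payload is hypothetical but matches the shape the test suite exercises.
_example_events = [
    {"event": "STATUS", "data": {"value": "processing"}},
    {
        "event": "DAG_STATUS",
        "data": {
            "workflow_run_id": "wf-123",
            "tasks": [{"name": "get_recording", "status": "completed"}],
        },
    },
]
# Only the `tasks` list from the last DAG_STATUS entry is kept in the result.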
|
||||
|
||||
|
||||
class SearchController:
|
||||
"""Controller for search operations across different entities."""
|
||||
|
||||
@@ -510,14 +470,6 @@ class SearchController:
|
||||
logger.error(f"Error processing search results: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
# Enrich processing transcripts with DAG status
|
||||
processing_ids = [r.id for r in results if r.status == "processing"]
|
||||
if processing_ids:
|
||||
dag_statuses = await _fetch_dag_statuses(processing_ids)
|
||||
for r in results:
|
||||
if r.id in dag_statuses:
|
||||
r.dag_status = dag_statuses[r.id]
|
||||
|
||||
return results, total
|
||||
|
||||
|
||||
|
||||
@@ -234,7 +234,7 @@ class Transcript(BaseModel):
|
||||
return dt.isoformat()
|
||||
|
||||
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
|
||||
ev = TranscriptEvent(event=event, data=data.model_dump(mode="json"))
|
||||
ev = TranscriptEvent(event=event, data=data.model_dump())
|
||||
self.events.append(ev)
|
||||
return ev
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ from reflector.utils.string import NonEmptyString
|
||||
from reflector.ws_manager import get_ws_manager
|
||||
|
||||
# Events that should also be sent to user room (matches Celery behavior)
|
||||
USER_ROOM_EVENTS = {"STATUS", "FINAL_TITLE", "DURATION", "DAG_STATUS"}
|
||||
USER_ROOM_EVENTS = {"STATUS", "FINAL_TITLE", "DURATION"}
|
||||
|
||||
|
||||
async def broadcast_event(
|
||||
|
||||
@@ -1,230 +0,0 @@
|
||||
"""
|
||||
DAG Progress Reporting — models and transform.
|
||||
|
||||
Converts Hatchet V1WorkflowRunDetails into structured DagTask list
|
||||
for frontend WebSocket/REST consumption.
|
||||
|
||||
Ported from render_hatchet_run.py (feat-dag-zulip) which renders markdown;
|
||||
this module produces structured Pydantic models instead.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
|
||||
from hatchet_sdk.clients.rest.models import (
|
||||
V1TaskStatus,
|
||||
V1WorkflowRunDetails,
|
||||
WorkflowRunShapeItemForWorkflowRunDetails,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class DagTaskStatus(StrEnum):
|
||||
QUEUED = "queued"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
_HATCHET_TO_DAG_STATUS: dict[V1TaskStatus, DagTaskStatus] = {
|
||||
V1TaskStatus.QUEUED: DagTaskStatus.QUEUED,
|
||||
V1TaskStatus.RUNNING: DagTaskStatus.RUNNING,
|
||||
V1TaskStatus.COMPLETED: DagTaskStatus.COMPLETED,
|
||||
V1TaskStatus.FAILED: DagTaskStatus.FAILED,
|
||||
V1TaskStatus.CANCELLED: DagTaskStatus.CANCELLED,
|
||||
}
|
||||
|
||||
|
||||
class DagTask(BaseModel):
|
||||
name: str
|
||||
status: DagTaskStatus
|
||||
started_at: datetime | None
|
||||
finished_at: datetime | None
|
||||
duration_seconds: float | None
|
||||
parents: list[str]
|
||||
error: str | None
|
||||
children_total: int | None
|
||||
children_completed: int | None
|
||||
progress_pct: float | None
|
||||
|
||||
|
||||
class DagStatusData(BaseModel):
|
||||
workflow_run_id: str
|
||||
tasks: list[DagTask]
|
||||
|
||||
|
||||
def _topo_sort(
|
||||
shape: list[WorkflowRunShapeItemForWorkflowRunDetails],
|
||||
) -> list[str]:
|
||||
"""Topological sort of step_ids from shape DAG (Kahn's algorithm).
|
||||
|
||||
Ported from render_hatchet_run.py.
|
||||
"""
|
||||
step_ids = {s.step_id for s in shape}
|
||||
children_map: dict[str, list[str]] = {}
|
||||
in_degree: dict[str, int] = {sid: 0 for sid in step_ids}
|
||||
|
||||
for s in shape:
|
||||
children = [c for c in (s.children_step_ids or []) if c in step_ids]
|
||||
children_map[s.step_id] = children
|
||||
for c in children:
|
||||
in_degree[c] += 1
|
||||
|
||||
queue = sorted(sid for sid, deg in in_degree.items() if deg == 0)
|
||||
result: list[str] = []
|
||||
while queue:
|
||||
node = queue.pop(0)
|
||||
result.append(node)
|
||||
for c in children_map.get(node, []):
|
||||
in_degree[c] -= 1
|
||||
if in_degree[c] == 0:
|
||||
queue.append(c)
|
||||
queue.sort()
|
||||
|
||||
return result
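# --- Illustrative sketch (editor's note, not part of the diff) ---
# Tiny worked example of the sort above on a diamond DAG; SimpleNamespace is a
# stand-in for the Hatchet shape items the real function receives.
from types import SimpleNamespace

_diamond = [
    SimpleNamespace(step_id="a", children_step_ids=["b", "c"]),
    SimpleNamespace(step_id="b", children_step_ids=["d"]),
    SimpleNamespace(step_id="c", children_step_ids=["d"]),
    SimpleNamespace(step_id="d", children_step_ids=[]),
]
# _topo_sort(_diamond) == ["a", "b", "c", "d"]  (ties broken by sorted step_id)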
|
||||
|
||||
|
||||
def _extract_error_summary(error_message: str | None) -> str | None:
|
||||
"""Extract first meaningful line from error message, skipping traceback frames."""
|
||||
if not error_message or not error_message.strip():
|
||||
return None
|
||||
|
||||
err_lines = error_message.strip().split("\n")
|
||||
err_summary = err_lines[0]
|
||||
for line in err_lines:
|
||||
stripped = line.strip()
|
||||
if stripped and not stripped.startswith(("Traceback", "File ", "{", ")")):
|
||||
err_summary = stripped
|
||||
return err_summary
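# --- Illustrative sketch (editor's note, not part of the diff) ---
# Example input and output for the extraction above, mirroring the test suite:
_example_error = (
    "Traceback (most recent call last):\n"
    '  File "/app/something.py", line 42\n'
    "RuntimeError: out of memory"
)
# _extract_error_summary(_example_error) == "RuntimeError: out of memory"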
|
||||
|
||||
|
||||
def extract_dag_tasks(details: V1WorkflowRunDetails) -> list[DagTask]:
|
||||
"""Extract structured DagTask list from Hatchet workflow run details.
|
||||
|
||||
Returns tasks in topological order with status, timestamps, parents,
|
||||
error summaries, and fan-out children counts.
|
||||
"""
|
||||
shape = details.shape or []
|
||||
tasks = details.tasks or []
|
||||
|
||||
if not shape:
|
||||
return []
|
||||
|
||||
# Build lookups
|
||||
step_to_shape: dict[str, WorkflowRunShapeItemForWorkflowRunDetails] = {
|
||||
s.step_id: s for s in shape
|
||||
}
|
||||
step_to_name: dict[str, str] = {s.step_id: s.task_name for s in shape}
|
||||
|
||||
# Reverse edges: child -> parent names
|
||||
parents_by_step: dict[str, list[str]] = {s.step_id: [] for s in shape}
|
||||
for s in shape:
|
||||
for child_id in s.children_step_ids or []:
|
||||
if child_id in parents_by_step:
|
||||
parents_by_step[child_id].append(step_to_name[s.step_id])
|
||||
|
||||
# Join tasks by step_id
|
||||
from hatchet_sdk.clients.rest.models import V1TaskSummary # noqa: PLC0415
|
||||
|
||||
task_by_step: dict[str, V1TaskSummary] = {}
|
||||
for t in tasks:
|
||||
if t.step_id and t.step_id in step_to_name:
|
||||
task_by_step[t.step_id] = t
|
||||
|
||||
ordered = _topo_sort(shape)
|
||||
|
||||
result: list[DagTask] = []
|
||||
for step_id in ordered:
|
||||
name = step_to_name[step_id]
|
||||
t = task_by_step.get(step_id)
|
||||
|
||||
if not t:
|
||||
result.append(
|
||||
DagTask(
|
||||
name=name,
|
||||
status=DagTaskStatus.QUEUED,
|
||||
started_at=None,
|
||||
finished_at=None,
|
||||
duration_seconds=None,
|
||||
parents=parents_by_step.get(step_id, []),
|
||||
error=None,
|
||||
children_total=None,
|
||||
children_completed=None,
|
||||
progress_pct=None,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
status = _HATCHET_TO_DAG_STATUS.get(t.status, DagTaskStatus.QUEUED)
|
||||
|
||||
duration_seconds: float | None = None
|
||||
if t.duration is not None:
|
||||
duration_seconds = t.duration / 1000.0
|
||||
|
||||
# Fan-out children
|
||||
children_total: int | None = None
|
||||
children_completed: int | None = None
|
||||
if t.num_spawned_children and t.num_spawned_children > 0:
|
||||
children_total = t.num_spawned_children
|
||||
children_completed = sum(
|
||||
1 for c in (t.children or []) if c.status == V1TaskStatus.COMPLETED
|
||||
)
|
||||
|
||||
result.append(
|
||||
DagTask(
|
||||
name=name,
|
||||
status=status,
|
||||
started_at=t.started_at,
|
||||
finished_at=t.finished_at,
|
||||
duration_seconds=duration_seconds,
|
||||
parents=parents_by_step.get(step_id, []),
|
||||
error=_extract_error_summary(t.error_message),
|
||||
children_total=children_total,
|
||||
children_completed=children_completed,
|
||||
progress_pct=None,
|
||||
)
|
||||
)
|
||||
|
||||
return result
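# --- Illustrative sketch (editor's note, not part of the diff) ---
# A consumer could derive a coarse progress figure from the list produced
# above; this helper is hypothetical and not part of the pipeline.
def _completed_fraction(tasks: list[DagTask]) -> float:
    """Fraction of DAG tasks that have completed (0.0 when the list is empty)."""
    if not tasks:
        return 0.0
    done = sum(1 for t in tasks if t.status == DagTaskStatus.COMPLETED)
    return done / len(tasks)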
|
||||
|
||||
|
||||
async def broadcast_dag_status(transcript_id: str, workflow_run_id: str) -> None:
|
||||
"""Fetch current DAG state from Hatchet and broadcast via WebSocket.
|
||||
|
||||
Fire-and-forget: exceptions are logged but never raised.
|
||||
All imports are deferred for fork-safety (Hatchet workers fork processes).
|
||||
"""
|
||||
try:
|
||||
from reflector.db.transcripts import transcripts_controller # noqa: I001, PLC0415
|
||||
from reflector.hatchet.broadcast import append_event_and_broadcast # noqa: PLC0415
|
||||
from reflector.hatchet.client import HatchetClientManager # noqa: PLC0415
|
||||
from reflector.hatchet.workflows.daily_multitrack_pipeline import ( # noqa: PLC0415
|
||||
fresh_db_connection,
|
||||
)
|
||||
from reflector.logger import logger # noqa: PLC0415
|
||||
|
||||
async with fresh_db_connection():
|
||||
client = HatchetClientManager.get_client()
|
||||
details = await client.runs.aio_get(workflow_run_id)
|
||||
dag_tasks = extract_dag_tasks(details)
|
||||
dag_status = DagStatusData(workflow_run_id=workflow_run_id, tasks=dag_tasks)
|
||||
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
if transcript:
|
||||
await append_event_and_broadcast(
|
||||
transcript_id,
|
||||
transcript,
|
||||
"DAG_STATUS",
|
||||
dag_status,
|
||||
logger,
|
||||
)
|
||||
except Exception:
|
||||
from reflector.logger import logger # noqa: PLC0415
|
||||
|
||||
logger.warning(
|
||||
"[DAG Progress] Failed to broadcast DAG status",
|
||||
transcript_id=transcript_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
exc_info=True,
|
||||
)
|
||||
@@ -184,10 +184,7 @@ class Loggable(Protocol):
|
||||
|
||||
|
||||
def make_audio_progress_logger(
|
||||
ctx: Loggable,
|
||||
task_name: TaskName,
|
||||
interval: float = 5.0,
|
||||
transcript_id: str | None = None,
|
||||
ctx: Loggable, task_name: TaskName, interval: float = 5.0
|
||||
) -> Callable[[float | None, float], None]:
|
||||
"""Create a throttled progress logger callback for audio processing.
|
||||
|
||||
@@ -195,7 +192,6 @@ def make_audio_progress_logger(
|
||||
ctx: Object with .log() method (e.g., Hatchet Context).
|
||||
task_name: Name to prefix in log messages.
|
||||
interval: Minimum seconds between log messages.
|
||||
transcript_id: If provided, broadcasts transient DAG_TASK_PROGRESS events.
|
||||
|
||||
Returns:
|
||||
Callback(progress_pct, audio_position) that logs at most every `interval` seconds.
|
||||
@@ -217,27 +213,6 @@ def make_audio_progress_logger(
|
||||
)
|
||||
last_log_time[0] = now
|
||||
|
||||
if transcript_id and progress_pct is not None:
|
||||
try:
|
||||
import asyncio # noqa: PLC0415
|
||||
|
||||
from reflector.db.transcripts import TranscriptEvent # noqa: PLC0415
|
||||
from reflector.hatchet.broadcast import broadcast_event # noqa: PLC0415
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.create_task(
|
||||
broadcast_event(
|
||||
transcript_id,
|
||||
TranscriptEvent(
|
||||
event="DAG_TASK_PROGRESS",
|
||||
data={"task_name": task_name, "progress_pct": progress_pct},
|
||||
),
|
||||
logger=logger,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
pass # transient, never fail the callback
|
||||
|
||||
return callback
|
||||
|
||||
|
||||
@@ -262,15 +237,8 @@ def with_error_handling(
|
||||
) -> Callable[[PipelineInput, Context], Coroutine[Any, Any, R]]:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(input: PipelineInput, ctx: Context) -> R:
|
||||
from reflector.hatchet.dag_progress import broadcast_dag_status # noqa: I001, PLC0415
|
||||
|
||||
try:
|
||||
result = await func(input, ctx)
|
||||
try:
|
||||
await broadcast_dag_status(input.transcript_id, ctx.workflow_run_id)
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
return await func(input, ctx)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[Hatchet] {step_name} failed",
|
||||
@@ -278,10 +246,6 @@ def with_error_handling(
|
||||
error=str(e),
|
||||
exc_info=True,
|
||||
)
|
||||
try:
|
||||
await broadcast_dag_status(input.transcript_id, ctx.workflow_run_id)
|
||||
except Exception:
|
||||
pass
|
||||
if set_error_status:
|
||||
await set_workflow_error_status(input.transcript_id)
|
||||
raise
|
||||
@@ -596,9 +560,7 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
|
||||
target_sample_rate,
|
||||
offsets_seconds=None,
|
||||
logger=logger,
|
||||
progress_callback=make_audio_progress_logger(
|
||||
ctx, TaskName.MIXDOWN_TRACKS, transcript_id=input.transcript_id
|
||||
),
|
||||
progress_callback=make_audio_progress_logger(ctx, TaskName.MIXDOWN_TRACKS),
|
||||
expected_duration_sec=recording_duration if recording_duration > 0 else None,
|
||||
)
|
||||
await writer.flush()
|
||||
|
||||
@@ -206,6 +206,12 @@ class LLM:
|
||||
"""Configure llamaindex Settings with OpenAILike LLM"""
|
||||
session_id = llm_session_id.get() or f"fallback-{uuid4().hex}"
|
||||
|
||||
extra_body: dict = {"litellm_session_id": session_id}
|
||||
# Only send enable_thinking when explicitly set (not None/unset).
|
||||
# Models that don't support it will ignore the param.
|
||||
if self.settings_obj.LLM_ENABLE_THINKING is not None:
|
||||
extra_body["enable_thinking"] = self.settings_obj.LLM_ENABLE_THINKING
|
||||
|
||||
Settings.llm = OpenAILike(
|
||||
model=self.model_name,
|
||||
api_base=self.url,
|
||||
@@ -215,7 +221,7 @@ class LLM:
|
||||
is_function_calling_model=False,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
additional_kwargs={"extra_body": {"litellm_session_id": session_id}},
|
||||
additional_kwargs={"extra_body": extra_body},
|
||||
)
|
||||
|
||||
async def get_response(
|
||||
|
||||
@@ -267,19 +267,6 @@ async def dispatch_transcript_processing(
|
||||
)
|
||||
|
||||
logger.info("Hatchet workflow dispatched", workflow_id=workflow_id)
|
||||
|
||||
try:
|
||||
from reflector.hatchet.dag_progress import broadcast_dag_status # noqa: I001, PLC0415
|
||||
|
||||
await broadcast_dag_status(config.transcript_id, workflow_id)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"[DAG Progress] Failed initial broadcast",
|
||||
transcript_id=config.transcript_id,
|
||||
workflow_id=workflow_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
elif isinstance(config, FileProcessingConfig):
|
||||
|
||||
@@ -75,6 +75,7 @@ class Settings(BaseSettings):
|
||||
LLM_URL: str | None = None
|
||||
LLM_API_KEY: str | None = None
|
||||
LLM_CONTEXT_WINDOW: int = 16000
|
||||
LLM_ENABLE_THINKING: bool | None = None
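# Editor's note (assumption, not part of the diff): as a pydantic BaseSettings
# field this would normally be configured via the environment, e.g.
#   LLM_ENABLE_THINKING=false
# Leaving it unset keeps the tri-state default of None, in which case the
# enable_thinking param is not sent to the model at all.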
|
||||
|
||||
LLM_PARSE_MAX_RETRIES: int = (
|
||||
3 # Max retries for JSON/validation errors (total attempts = retries + 1)
|
||||
|
||||
@@ -111,7 +111,6 @@ class GetTranscriptMinimal(BaseModel):
|
||||
room_id: str | None = None
|
||||
room_name: str | None = None
|
||||
audio_deleted: bool | None = None
|
||||
dag_status: list[dict] | None = None
|
||||
|
||||
|
||||
class TranscriptParticipantWithEmail(TranscriptParticipant):
|
||||
@@ -492,13 +491,6 @@ async def transcript_get(
|
||||
)
|
||||
)
|
||||
|
||||
dag_status = None
|
||||
if transcript.status == "processing" and transcript.events:
|
||||
for ev in reversed(transcript.events):
|
||||
if ev.event == "DAG_STATUS":
|
||||
dag_status = ev.data.get("tasks") if isinstance(ev.data, dict) else None
|
||||
break
|
||||
|
||||
base_data = {
|
||||
"id": transcript.id,
|
||||
"user_id": transcript.user_id,
|
||||
@@ -520,7 +512,6 @@ async def transcript_get(
|
||||
"room_id": transcript.room_id,
|
||||
"room_name": room_name,
|
||||
"audio_deleted": transcript.audio_deleted,
|
||||
"dag_status": dag_status,
|
||||
"participants": participants,
|
||||
}
|
||||
|
||||
|
||||
@@ -41,19 +41,13 @@ async def transcript_events_websocket(
|
||||
|
||||
try:
|
||||
# on first connection, send all events only to the current user
|
||||
# Find the last DAG_STATUS to send after other historical events
|
||||
last_dag_status = None
|
||||
for event in transcript.events:
|
||||
# for now, do not send TRANSCRIPT or STATUS events - these are live events
# that are not necessary to send to the client; but keep the rest
|
||||
name = event.event
|
||||
if name in ("TRANSCRIPT", "STATUS"):
|
||||
continue
|
||||
if name == "DAG_STATUS":
|
||||
last_dag_status = event
|
||||
continue
|
||||
await websocket.send_json(event.model_dump(mode="json"))
|
||||
# Send only the most recent DAG_STATUS so reconnecting clients get current state
|
||||
if last_dag_status is not None:
|
||||
await websocket.send_json(last_dag_status.model_dump(mode="json"))
|
||||
|
||||
# XXX if transcript is final (locked=True and status=ended)
|
||||
# XXX send a final event to the client and close the connection
|
||||
|
||||
@@ -1,959 +0,0 @@
|
||||
"""Tests for DAG progress models and transform function.
|
||||
|
||||
Tests the extract_dag_tasks function that converts Hatchet V1WorkflowRunDetails
|
||||
into structured DagTask list for frontend consumption.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from reflector.hatchet.constants import TaskName
|
||||
from reflector.hatchet.dag_progress import (
|
||||
DagStatusData,
|
||||
DagTask,
|
||||
DagTaskStatus,
|
||||
extract_dag_tasks,
|
||||
)
|
||||
|
||||
|
||||
def _make_shape_item(
|
||||
step_id: str,
|
||||
task_name: str,
|
||||
children_step_ids: list[str] | None = None,
|
||||
) -> MagicMock:
|
||||
"""Create a mock WorkflowRunShapeItemForWorkflowRunDetails."""
|
||||
item = MagicMock()
|
||||
item.step_id = step_id
|
||||
item.task_name = task_name
|
||||
item.children_step_ids = children_step_ids or []
|
||||
return item
|
||||
|
||||
|
||||
def _make_task_summary(
|
||||
step_id: str,
|
||||
status: str = "QUEUED",
|
||||
started_at: datetime | None = None,
|
||||
finished_at: datetime | None = None,
|
||||
duration: int | None = None,
|
||||
error_message: str | None = None,
|
||||
task_external_id: str | None = None,
|
||||
num_spawned_children: int | None = None,
|
||||
children: list | None = None,
|
||||
) -> MagicMock:
|
||||
"""Create a mock V1TaskSummary."""
|
||||
from hatchet_sdk.clients.rest.models import V1TaskStatus
|
||||
|
||||
task = MagicMock()
|
||||
task.step_id = step_id
|
||||
task.status = V1TaskStatus(status)
|
||||
task.started_at = started_at
|
||||
task.finished_at = finished_at
|
||||
task.duration = duration
|
||||
task.error_message = error_message
|
||||
task.task_external_id = task_external_id or f"ext-{step_id}"
|
||||
task.num_spawned_children = num_spawned_children
|
||||
task.children = children or []
|
||||
return task
|
||||
|
||||
|
||||
def _make_details(
|
||||
shape: list,
|
||||
tasks: list,
|
||||
run_id: str = "test-run-id",
|
||||
) -> MagicMock:
|
||||
"""Create a mock V1WorkflowRunDetails."""
|
||||
details = MagicMock()
|
||||
details.shape = shape
|
||||
details.tasks = tasks
|
||||
details.task_events = []
|
||||
details.run = MagicMock()
|
||||
details.run.metadata = MagicMock()
|
||||
details.run.metadata.id = run_id
|
||||
return details
|
||||
|
||||
|
||||
class TestExtractDagTasksBasic:
|
||||
"""Test basic extraction of DAG tasks from workflow run details."""
|
||||
|
||||
def test_empty_shape_returns_empty_list(self):
|
||||
details = _make_details(shape=[], tasks=[])
|
||||
result = extract_dag_tasks(details)
|
||||
assert result == []
|
||||
|
||||
def test_single_task_queued(self):
|
||||
shape = [_make_shape_item("s1", "get_recording")]
|
||||
tasks = [_make_task_summary("s1", status="QUEUED")]
|
||||
details = _make_details(shape, tasks)
|
||||
|
||||
result = extract_dag_tasks(details)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "get_recording"
|
||||
assert result[0].status == DagTaskStatus.QUEUED
|
||||
assert result[0].parents == []
|
||||
assert result[0].started_at is None
|
||||
assert result[0].finished_at is None
|
||||
assert result[0].duration_seconds is None
|
||||
assert result[0].error is None
|
||||
assert result[0].children_total is None
|
||||
assert result[0].children_completed is None
|
||||
assert result[0].progress_pct is None
|
||||
|
||||
def test_completed_task_with_duration(self):
|
||||
now = datetime.now(timezone.utc)
|
||||
shape = [_make_shape_item("s1", "get_recording")]
|
||||
tasks = [
|
||||
_make_task_summary(
|
||||
"s1",
|
||||
status="COMPLETED",
|
||||
started_at=now,
|
||||
finished_at=now,
|
||||
duration=1500, # milliseconds
|
||||
)
|
||||
]
|
||||
details = _make_details(shape, tasks)
|
||||
|
||||
result = extract_dag_tasks(details)
|
||||
|
||||
assert result[0].status == DagTaskStatus.COMPLETED
|
||||
assert result[0].duration_seconds == 1.5
|
||||
assert result[0].started_at == now
|
||||
assert result[0].finished_at == now
|
||||
|
||||
def test_failed_task_with_error(self):
|
||||
shape = [_make_shape_item("s1", "get_recording")]
|
||||
tasks = [
|
||||
_make_task_summary(
|
||||
"s1",
|
||||
status="FAILED",
|
||||
error_message="Traceback (most recent call last):\n File something\nConnectionError: connection refused",
|
||||
)
|
||||
]
|
||||
details = _make_details(shape, tasks)
|
||||
|
||||
result = extract_dag_tasks(details)
|
||||
|
||||
assert result[0].status == DagTaskStatus.FAILED
|
||||
assert result[0].error == "ConnectionError: connection refused"
|
||||
|
||||
def test_running_task(self):
|
||||
now = datetime.now(timezone.utc)
|
||||
shape = [_make_shape_item("s1", "mixdown_tracks")]
|
||||
tasks = [
|
||||
_make_task_summary(
|
||||
"s1",
|
||||
status="RUNNING",
|
||||
started_at=now,
|
||||
duration=5000,
|
||||
)
|
||||
]
|
||||
details = _make_details(shape, tasks)
|
||||
|
||||
result = extract_dag_tasks(details)
|
||||
|
||||
assert result[0].status == DagTaskStatus.RUNNING
|
||||
assert result[0].started_at == now
|
||||
assert result[0].duration_seconds == 5.0
|
||||
|
||||
def test_cancelled_task(self):
|
||||
shape = [_make_shape_item("s1", "post_zulip")]
|
||||
tasks = [_make_task_summary("s1", status="CANCELLED")]
|
||||
details = _make_details(shape, tasks)
|
||||
|
||||
result = extract_dag_tasks(details)
|
||||
|
||||
assert result[0].status == DagTaskStatus.CANCELLED
|
||||
|
||||
|
||||
class TestExtractDagTasksTopology:
|
||||
"""Test topological ordering and parent extraction."""
|
||||
|
||||
def test_linear_chain_parents(self):
|
||||
"""A -> B -> C should produce correct parents."""
|
||||
shape = [
|
||||
_make_shape_item("s1", "get_recording", children_step_ids=["s2"]),
|
||||
_make_shape_item("s2", "get_participants", children_step_ids=["s3"]),
|
||||
_make_shape_item("s3", "process_tracks"),
|
||||
]
|
||||
tasks = [
|
||||
_make_task_summary("s1", status="COMPLETED"),
|
||||
_make_task_summary("s2", status="COMPLETED"),
|
||||
_make_task_summary("s3", status="QUEUED"),
|
||||
]
|
||||
details = _make_details(shape, tasks)
|
||||
|
||||
result = extract_dag_tasks(details)
|
||||
|
||||
assert [t.name for t in result] == [
|
||||
"get_recording",
|
||||
"get_participants",
|
||||
"process_tracks",
|
||||
]
|
||||
assert result[0].parents == []
|
||||
assert result[1].parents == ["get_recording"]
|
||||
assert result[2].parents == ["get_participants"]
|
||||
|
||||
def test_diamond_dag(self):
|
||||
"""
|
||||
A -> B, A -> C, B -> D, C -> D
|
||||
D should have parents [B, C] (or [C, B] depending on sort).
|
||||
"""
|
||||
shape = [
|
||||
_make_shape_item("s1", "get_recording", children_step_ids=["s2", "s3"]),
|
||||
_make_shape_item("s2", "mixdown_tracks", children_step_ids=["s4"]),
|
||||
_make_shape_item("s3", "detect_topics", children_step_ids=["s4"]),
|
||||
_make_shape_item("s4", "finalize"),
|
||||
]
|
||||
tasks = [
|
||||
_make_task_summary("s1", status="COMPLETED"),
|
||||
_make_task_summary("s2", status="RUNNING"),
|
||||
_make_task_summary("s3", status="RUNNING"),
|
||||
_make_task_summary("s4", status="QUEUED"),
|
||||
]
|
||||
details = _make_details(shape, tasks)
|
||||
|
||||
result = extract_dag_tasks(details)
|
||||
|
||||
# Topological: s1 first, s2/s3 in some order, s4 last
|
||||
assert result[0].name == "get_recording"
|
||||
assert result[-1].name == "finalize"
|
||||
finalize = result[-1]
|
||||
assert set(finalize.parents) == {"mixdown_tracks", "detect_topics"}
|
||||
|
||||
def test_topological_order_is_stable(self):
|
||||
"""Verify deterministic ordering (sorted queue in Kahn's)."""
|
||||
shape = [
|
||||
_make_shape_item("s_c", "task_c"),
|
||||
_make_shape_item("s_a", "task_a", children_step_ids=["s_c"]),
|
||||
_make_shape_item("s_b", "task_b", children_step_ids=["s_c"]),
|
||||
]
|
||||
tasks = [
|
||||
_make_task_summary("s_c", status="QUEUED"),
|
||||
_make_task_summary("s_a", status="COMPLETED"),
|
||||
_make_task_summary("s_b", status="COMPLETED"),
|
||||
]
|
||||
details = _make_details(shape, tasks)
|
||||
|
||||
result = extract_dag_tasks(details)
|
||||
|
||||
# s_a and s_b both roots with in-degree 0; sorted alphabetically by step_id
|
||||
names = [t.name for t in result]
|
||||
assert names[-1] == "task_c"
|
||||
# First two should be task_a, task_b (sorted by step_id: s_a < s_b)
|
||||
assert names[0] == "task_a"
|
||||
assert names[1] == "task_b"
|
||||
|
||||
def test_production_dag_shape(self):
|
||||
"""Test the real 15-task pipeline topology with mixed statuses.
|
||||
|
||||
Simulates a mid-pipeline state where early tasks completed,
|
||||
middle tasks running, and later tasks still queued.
|
||||
"""
|
||||
# Production DAG edges (parent -> children):
|
||||
# get_recording -> get_participants
|
||||
# get_participants -> process_tracks
|
||||
# process_tracks -> mixdown_tracks, detect_topics, finalize
|
||||
# mixdown_tracks -> generate_waveform
|
||||
# detect_topics -> generate_title, extract_subjects
|
||||
# extract_subjects -> process_subjects, identify_action_items
|
||||
# process_subjects -> generate_recap
|
||||
# generate_title -> finalize
|
||||
# generate_recap -> finalize
|
||||
# identify_action_items -> finalize
|
||||
# finalize -> cleanup_consent
|
||||
# cleanup_consent -> post_zulip, send_webhook
|
||||
shape = [
|
||||
_make_shape_item(
|
||||
"s_get_recording", TaskName.GET_RECORDING, ["s_get_participants"]
|
||||
),
|
||||
_make_shape_item(
|
||||
"s_get_participants", TaskName.GET_PARTICIPANTS, ["s_process_tracks"]
|
||||
),
|
||||
_make_shape_item(
|
||||
"s_process_tracks",
|
||||
TaskName.PROCESS_TRACKS,
|
||||
["s_mixdown_tracks", "s_detect_topics", "s_finalize"],
|
||||
),
|
||||
_make_shape_item(
|
||||
"s_mixdown_tracks", TaskName.MIXDOWN_TRACKS, ["s_generate_waveform"]
|
||||
),
|
||||
_make_shape_item("s_generate_waveform", TaskName.GENERATE_WAVEFORM),
|
||||
_make_shape_item(
|
||||
"s_detect_topics",
|
||||
TaskName.DETECT_TOPICS,
|
||||
["s_generate_title", "s_extract_subjects"],
|
||||
),
|
||||
_make_shape_item(
|
||||
"s_generate_title", TaskName.GENERATE_TITLE, ["s_finalize"]
|
||||
),
|
||||
_make_shape_item(
|
||||
"s_extract_subjects",
|
||||
TaskName.EXTRACT_SUBJECTS,
|
||||
["s_process_subjects", "s_identify_action_items"],
|
||||
),
|
||||
_make_shape_item(
|
||||
"s_process_subjects", TaskName.PROCESS_SUBJECTS, ["s_generate_recap"]
|
||||
),
|
||||
_make_shape_item(
|
||||
"s_generate_recap", TaskName.GENERATE_RECAP, ["s_finalize"]
|
||||
),
|
||||
_make_shape_item(
|
||||
"s_identify_action_items",
|
||||
TaskName.IDENTIFY_ACTION_ITEMS,
|
||||
["s_finalize"],
|
||||
),
|
||||
_make_shape_item("s_finalize", TaskName.FINALIZE, ["s_cleanup_consent"]),
|
||||
_make_shape_item(
|
||||
"s_cleanup_consent",
|
||||
TaskName.CLEANUP_CONSENT,
|
||||
["s_post_zulip", "s_send_webhook"],
|
||||
),
|
||||
_make_shape_item("s_post_zulip", TaskName.POST_ZULIP),
|
||||
_make_shape_item("s_send_webhook", TaskName.SEND_WEBHOOK),
|
||||
]
|
||||
|
||||
# Mid-pipeline: early tasks done, middle running, later queued
|
||||
tasks = [
|
||||
_make_task_summary("s_get_recording", status="COMPLETED"),
|
||||
_make_task_summary("s_get_participants", status="COMPLETED"),
|
||||
_make_task_summary("s_process_tracks", status="COMPLETED"),
|
||||
_make_task_summary("s_mixdown_tracks", status="RUNNING"),
|
||||
_make_task_summary("s_generate_waveform", status="QUEUED"),
|
||||
_make_task_summary("s_detect_topics", status="RUNNING"),
|
||||
_make_task_summary("s_generate_title", status="QUEUED"),
|
||||
_make_task_summary("s_extract_subjects", status="QUEUED"),
|
||||
_make_task_summary("s_process_subjects", status="QUEUED"),
|
||||
_make_task_summary("s_generate_recap", status="QUEUED"),
|
||||
_make_task_summary("s_identify_action_items", status="QUEUED"),
|
||||
_make_task_summary("s_finalize", status="QUEUED"),
|
||||
_make_task_summary("s_cleanup_consent", status="QUEUED"),
|
||||
_make_task_summary("s_post_zulip", status="QUEUED"),
|
||||
_make_task_summary("s_send_webhook", status="QUEUED"),
|
||||
]
|
||||
details = _make_details(shape, tasks)
|
||||
|
||||
result = extract_dag_tasks(details)
|
||||
|
||||
# All 15 tasks present
|
||||
assert len(result) == 15
|
||||
result_names = [t.name for t in result]
|
||||
assert set(result_names) == {
|
||||
TaskName.GET_RECORDING,
|
||||
TaskName.GET_PARTICIPANTS,
|
||||
TaskName.PROCESS_TRACKS,
|
||||
TaskName.MIXDOWN_TRACKS,
|
||||
TaskName.GENERATE_WAVEFORM,
|
||||
TaskName.DETECT_TOPICS,
|
||||
TaskName.GENERATE_TITLE,
|
||||
TaskName.EXTRACT_SUBJECTS,
|
||||
TaskName.PROCESS_SUBJECTS,
|
||||
TaskName.GENERATE_RECAP,
|
||||
TaskName.IDENTIFY_ACTION_ITEMS,
|
||||
TaskName.FINALIZE,
|
||||
TaskName.CLEANUP_CONSENT,
|
||||
TaskName.POST_ZULIP,
|
||||
TaskName.SEND_WEBHOOK,
|
||||
}
|
||||
|
||||
# Topological order invariant: no task appears before its parents
|
||||
name_to_index = {t.name: i for i, t in enumerate(result)}
|
||||
for task in result:
|
||||
for parent_name in task.parents:
|
||||
assert name_to_index[parent_name] < name_to_index[task.name], (
|
||||
f"Parent {parent_name} (idx {name_to_index[parent_name]}) "
|
||||
f"must appear before {task.name} (idx {name_to_index[task.name]})"
|
||||
)
|
||||
|
||||
# finalize has exactly 4 parents
|
||||
finalize = next(t for t in result if t.name == TaskName.FINALIZE)
|
||||
assert set(finalize.parents) == {
|
||||
TaskName.PROCESS_TRACKS,
|
||||
TaskName.GENERATE_TITLE,
|
||||
TaskName.GENERATE_RECAP,
|
||||
TaskName.IDENTIFY_ACTION_ITEMS,
|
||||
}
|
||||
|
||||
# cleanup_consent has 1 parent (finalize)
|
||||
cleanup = next(t for t in result if t.name == TaskName.CLEANUP_CONSENT)
|
||||
assert cleanup.parents == [TaskName.FINALIZE]
|
||||
|
||||
# post_zulip and send_webhook both have cleanup_consent as parent
|
||||
post_zulip = next(t for t in result if t.name == TaskName.POST_ZULIP)
|
||||
send_webhook = next(t for t in result if t.name == TaskName.SEND_WEBHOOK)
|
||||
assert post_zulip.parents == [TaskName.CLEANUP_CONSENT]
|
||||
assert send_webhook.parents == [TaskName.CLEANUP_CONSENT]
|
||||
|
||||
# Verify statuses propagated correctly
|
||||
assert (
|
||||
next(t for t in result if t.name == TaskName.GET_RECORDING).status
|
||||
== DagTaskStatus.COMPLETED
|
||||
)
|
||||
assert (
|
||||
next(t for t in result if t.name == TaskName.MIXDOWN_TRACKS).status
|
||||
== DagTaskStatus.RUNNING
|
||||
)
|
||||
assert (
|
||||
next(t for t in result if t.name == TaskName.FINALIZE).status
|
||||
== DagTaskStatus.QUEUED
|
||||
)
|
||||
|
||||
def test_topological_sort_invariant_complex_dag(self):
|
||||
"""For a complex DAG, every task's parents appear earlier in the list.
|
||||
|
||||
Uses a wider branching/merging DAG than diamond to stress the invariant.
|
||||
"""
|
||||
# DAG: A -> B, A -> C, A -> D, B -> E, C -> E, C -> F, D -> F, E -> G, F -> G
|
||||
shape = [
|
||||
_make_shape_item("s_a", "task_a", ["s_b", "s_c", "s_d"]),
|
||||
_make_shape_item("s_b", "task_b", ["s_e"]),
|
||||
_make_shape_item("s_c", "task_c", ["s_e", "s_f"]),
|
||||
_make_shape_item("s_d", "task_d", ["s_f"]),
|
||||
_make_shape_item("s_e", "task_e", ["s_g"]),
|
||||
_make_shape_item("s_f", "task_f", ["s_g"]),
|
||||
_make_shape_item("s_g", "task_g"),
|
||||
]
|
||||
tasks = [
|
||||
_make_task_summary("s_a", status="COMPLETED"),
|
||||
_make_task_summary("s_b", status="COMPLETED"),
|
||||
_make_task_summary("s_c", status="RUNNING"),
|
||||
_make_task_summary("s_d", status="COMPLETED"),
|
||||
_make_task_summary("s_e", status="QUEUED"),
|
||||
_make_task_summary("s_f", status="QUEUED"),
|
||||
_make_task_summary("s_g", status="QUEUED"),
|
||||
]
|
||||
details = _make_details(shape, tasks)
|
||||
|
||||
result = extract_dag_tasks(details)
|
||||
|
||||
assert len(result) == 7
|
||||
name_to_index = {t.name: i for i, t in enumerate(result)}
|
||||
|
||||
# Verify invariant: every parent appears before its child
|
||||
for task in result:
|
||||
for parent_name in task.parents:
|
||||
assert name_to_index[parent_name] < name_to_index[task.name], (
|
||||
f"Parent {parent_name} (idx {name_to_index[parent_name]}) "
|
||||
f"must appear before {task.name} (idx {name_to_index[task.name]})"
|
||||
)
|
||||
|
||||
# task_g has 2 parents
|
||||
task_g = next(t for t in result if t.name == "task_g")
|
||||
assert set(task_g.parents) == {"task_e", "task_f"}
|
||||
|
||||
# task_e has 2 parents
|
||||
task_e = next(t for t in result if t.name == "task_e")
|
||||
assert set(task_e.parents) == {"task_b", "task_c"}
|
||||
|
||||
# task_a is root (first in topological order)
|
||||
assert result[0].name == "task_a"
|
||||
assert result[0].parents == []
|
||||
|
||||
|
||||
class TestExtractDagTasksFanOut:
|
||||
"""Test fan-out tasks with spawned children."""
|
||||
|
||||
def test_fan_out_children_counts(self):
|
||||
from hatchet_sdk.clients.rest.models import V1TaskStatus
|
||||
|
||||
child_mocks = []
|
||||
for status in ["COMPLETED", "COMPLETED", "RUNNING", "QUEUED"]:
|
||||
child = MagicMock()
|
||||
child.status = V1TaskStatus(status)
|
||||
child_mocks.append(child)
|
||||
|
||||
shape = [_make_shape_item("s1", "process_tracks")]
|
||||
tasks = [
|
||||
_make_task_summary(
|
||||
"s1",
|
||||
status="RUNNING",
|
||||
num_spawned_children=4,
|
||||
children=child_mocks,
|
||||
)
|
||||
]
|
||||
details = _make_details(shape, tasks)
|
||||
|
||||
result = extract_dag_tasks(details)
|
||||
|
||||
assert result[0].children_total == 4
|
||||
assert result[0].children_completed == 2
|
||||
|
||||
def test_no_children_when_no_spawn(self):
|
||||
shape = [_make_shape_item("s1", "get_recording")]
|
||||
tasks = [
|
||||
_make_task_summary("s1", status="COMPLETED", num_spawned_children=None)
|
||||
]
|
||||
details = _make_details(shape, tasks)
|
||||
|
||||
result = extract_dag_tasks(details)
|
||||
|
||||
assert result[0].children_total is None
|
||||
assert result[0].children_completed is None
|
||||
|
||||
def test_zero_spawned_children(self):
|
||||
shape = [_make_shape_item("s1", "process_tracks")]
|
||||
tasks = [_make_task_summary("s1", status="COMPLETED", num_spawned_children=0)]
|
||||
details = _make_details(shape, tasks)
|
||||
|
||||
result = extract_dag_tasks(details)
|
||||
|
||||
assert result[0].children_total is None
|
||||
assert result[0].children_completed is None
|
||||
|
||||
|
||||
class TestExtractDagTasksErrorExtraction:
|
||||
"""Test error message extraction logic."""
|
||||
|
||||
def test_simple_error(self):
|
||||
shape = [_make_shape_item("s1", "mixdown_tracks")]
|
||||
tasks = [
|
||||
_make_task_summary(
|
||||
"s1", status="FAILED", error_message="ValueError: no tracks"
|
||||
)
|
||||
]
|
||||
details = _make_details(shape, tasks)
|
||||
|
||||
result = extract_dag_tasks(details)
|
||||
assert result[0].error == "ValueError: no tracks"
|
||||
|
||||
def test_traceback_extracts_meaningful_line(self):
|
||||
error = (
|
||||
"Traceback (most recent call last):\n"
|
||||
' File "/app/something.py", line 42\n'
|
||||
"RuntimeError: out of memory"
|
||||
)
|
||||
shape = [_make_shape_item("s1", "mixdown_tracks")]
|
||||
tasks = [_make_task_summary("s1", status="FAILED", error_message=error)]
|
||||
details = _make_details(shape, tasks)
|
||||
|
||||
result = extract_dag_tasks(details)
|
||||
assert result[0].error == "RuntimeError: out of memory"
|
||||
|
||||
def test_no_error_when_none(self):
|
||||
shape = [_make_shape_item("s1", "get_recording")]
|
||||
tasks = [_make_task_summary("s1", status="COMPLETED", error_message=None)]
|
||||
details = _make_details(shape, tasks)
|
||||
|
||||
result = extract_dag_tasks(details)
|
||||
assert result[0].error is None
|
||||
|
||||
def test_empty_error_message(self):
|
||||
shape = [_make_shape_item("s1", "get_recording")]
|
||||
tasks = [_make_task_summary("s1", status="FAILED", error_message="")]
|
||||
details = _make_details(shape, tasks)
|
||||
|
||||
result = extract_dag_tasks(details)
|
||||
assert result[0].error is None
|
||||
|
||||
|
||||
class TestExtractDagTasksMissingData:
|
||||
"""Test edge cases with missing task data."""
|
||||
|
||||
def test_shape_without_matching_task(self):
|
||||
"""Shape has a step but tasks list doesn't contain it."""
|
||||
shape = [_make_shape_item("s1", "get_recording")]
|
||||
tasks = [] # No matching task
|
||||
details = _make_details(shape, tasks)
|
||||
|
||||
result = extract_dag_tasks(details)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "get_recording"
|
||||
assert result[0].status == DagTaskStatus.QUEUED # default when no task data
|
||||
assert result[0].started_at is None
|
||||
|
||||
def test_none_shape_returns_empty(self):
|
||||
details = _make_details(shape=[], tasks=[])
|
||||
details.shape = None
|
||||
|
||||
result = extract_dag_tasks(details)
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestDagStatusData:
|
||||
"""Test DagStatusData model serialization."""
|
||||
|
||||
def test_serialization(self):
|
||||
task = DagTask(
|
||||
name="get_recording",
|
||||
status=DagTaskStatus.COMPLETED,
|
||||
started_at=datetime(2025, 1, 1, tzinfo=timezone.utc),
|
||||
finished_at=datetime(2025, 1, 1, 0, 0, 1, tzinfo=timezone.utc),
|
||||
duration_seconds=1.0,
|
||||
parents=[],
|
||||
error=None,
|
||||
children_total=None,
|
||||
children_completed=None,
|
||||
progress_pct=None,
|
||||
)
|
||||
data = DagStatusData(workflow_run_id="test-123", tasks=[task])
|
||||
dumped = data.model_dump(mode="json")
|
||||
|
||||
assert dumped["workflow_run_id"] == "test-123"
|
||||
assert len(dumped["tasks"]) == 1
|
||||
assert dumped["tasks"][0]["name"] == "get_recording"
|
||||
assert dumped["tasks"][0]["status"] == "completed"
|
||||
assert dumped["tasks"][0]["duration_seconds"] == 1.0
|
||||
|
||||
|
||||
class AsyncContextManager:
|
||||
"""No-op async context manager for mocking fresh_db_connection."""
|
||||
|
||||
async def __aenter__(self):
|
||||
return None
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
return None
|
||||
|
||||
|
||||
class TestBroadcastDagStatus:
|
||||
"""Test broadcast_dag_status function.
|
||||
|
||||
broadcast_dag_status uses deferred imports inside its function body.
|
||||
We mock the source modules/objects before calling the function.
|
||||
Importing daily_multitrack_pipeline triggers a cascade
|
||||
(subject_processing -> HatchetClientManager.get_client at module level),
|
||||
so we set _instance before the import to prevent real SDK init.
|
||||
"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _setup_hatchet_mock(self):
|
||||
"""Set HatchetClientManager._instance to a mock to prevent real SDK init.
|
||||
|
||||
Module-level code in workflow files calls get_client() during import.
|
||||
Setting _instance before import avoids ClientConfig validation.
|
||||
"""
|
||||
from reflector.hatchet.client import HatchetClientManager
|
||||
|
||||
original = HatchetClientManager._instance
|
||||
HatchetClientManager._instance = MagicMock()
|
||||
yield
|
||||
HatchetClientManager._instance = original
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcasts_dag_status(self):
|
||||
"""broadcast_dag_status fetches run, transforms, and broadcasts."""
|
||||
mock_transcript = MagicMock()
|
||||
mock_transcript.id = "t-123"
|
||||
|
||||
mock_details = _make_details(
|
||||
shape=[_make_shape_item("s1", "get_recording")],
|
||||
tasks=[_make_task_summary("s1", status="COMPLETED")],
|
||||
run_id="wf-abc",
|
||||
)
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.runs.aio_get = AsyncMock(return_value=mock_details)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"reflector.hatchet.client.HatchetClientManager.get_client",
|
||||
return_value=mock_client,
|
||||
),
|
||||
patch(
|
||||
"reflector.hatchet.broadcast.append_event_and_broadcast",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_broadcast,
|
||||
patch(
|
||||
"reflector.db.transcripts.transcripts_controller.get_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_transcript,
|
||||
),
|
||||
patch(
|
||||
"reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection",
|
||||
return_value=AsyncContextManager(),
|
||||
),
|
||||
):
|
||||
from reflector.hatchet.dag_progress import broadcast_dag_status
|
||||
|
||||
await broadcast_dag_status("t-123", "wf-abc")
|
||||
|
||||
mock_client.runs.aio_get.assert_called_once_with("wf-abc")
|
||||
mock_broadcast.assert_called_once()
|
||||
call_args = mock_broadcast.call_args
|
||||
assert call_args[0][0] == "t-123" # transcript_id
|
||||
assert call_args[0][1] is mock_transcript # transcript
|
||||
assert call_args[0][2] == "DAG_STATUS" # event_name
|
||||
data = call_args[0][3]
|
||||
assert isinstance(data, DagStatusData)
|
||||
assert data.workflow_run_id == "wf-abc"
|
||||
assert len(data.tasks) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_swallows_exceptions(self):
|
||||
"""broadcast_dag_status never raises even when internals fail."""
|
||||
from reflector.hatchet.dag_progress import broadcast_dag_status
|
||||
|
||||
with patch(
|
||||
"reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection",
|
||||
side_effect=RuntimeError("db exploded"),
|
||||
):
|
||||
# Should not raise
|
||||
await broadcast_dag_status("t-123", "wf-abc")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_broadcast_when_transcript_not_found(self):
|
||||
"""broadcast_dag_status does not broadcast if transcript is None."""
|
||||
mock_details = _make_details(
|
||||
shape=[_make_shape_item("s1", "get_recording")],
|
||||
tasks=[_make_task_summary("s1", status="COMPLETED")],
|
||||
)
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.runs.aio_get = AsyncMock(return_value=mock_details)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"reflector.hatchet.client.HatchetClientManager.get_client",
|
||||
return_value=mock_client,
|
||||
),
|
||||
patch(
|
||||
"reflector.hatchet.workflows.daily_multitrack_pipeline.fresh_db_connection",
|
||||
return_value=AsyncContextManager(),
|
||||
),
|
||||
patch(
|
||||
"reflector.db.transcripts.transcripts_controller.get_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"reflector.hatchet.broadcast.append_event_and_broadcast",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_broadcast,
|
||||
):
|
||||
from reflector.hatchet.dag_progress import broadcast_dag_status
|
||||
|
||||
await broadcast_dag_status("t-123", "wf-abc")
|
||||
|
||||
mock_broadcast.assert_not_called()
|
||||
|
||||
|
||||
class TestMakeAudioProgressLoggerWithBroadcast:
|
||||
"""Test make_audio_progress_logger with transcript_id for transient broadcasts."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _setup_hatchet_mock(self):
|
||||
"""Set HatchetClientManager._instance to prevent real SDK init on import."""
|
||||
from reflector.hatchet.client import HatchetClientManager
|
||||
|
||||
original = HatchetClientManager._instance
|
||||
if original is None:
|
||||
HatchetClientManager._instance = MagicMock()
|
||||
yield
|
||||
HatchetClientManager._instance = original
|
||||
|
||||
def test_broadcasts_transient_progress_event(self):
|
||||
"""When transcript_id provided and progress_pct not None, broadcasts event."""
|
||||
import asyncio
|
||||
|
||||
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||
make_audio_progress_logger,
|
||||
)
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.log = MagicMock()
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
mock_broadcast = AsyncMock()
|
||||
tasks_created = []
|
||||
|
||||
original_create_task = loop.create_task
|
||||
|
||||
def capture_create_task(coro):
|
||||
task = original_create_task(coro)
|
||||
tasks_created.append(task)
|
||||
return task
|
||||
|
||||
try:
|
||||
with (
|
||||
patch(
|
||||
"reflector.hatchet.broadcast.broadcast_event",
|
||||
mock_broadcast,
|
||||
),
|
||||
patch.object(loop, "create_task", side_effect=capture_create_task),
|
||||
):
|
||||
callback = make_audio_progress_logger(
|
||||
ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id="t-123"
|
||||
)
|
||||
callback(50.0, 100.0)
|
||||
|
||||
# Run pending tasks
|
||||
if tasks_created:
|
||||
loop.run_until_complete(asyncio.gather(*tasks_created))
|
||||
|
||||
mock_broadcast.assert_called_once()
|
||||
event_arg = mock_broadcast.call_args[0][1]
|
||||
assert event_arg.event == "DAG_TASK_PROGRESS"
|
||||
assert event_arg.data["task_name"] == TaskName.MIXDOWN_TRACKS
|
||||
assert event_arg.data["progress_pct"] == 50.0
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
def test_no_broadcast_without_transcript_id(self):
|
||||
"""When transcript_id is None, no broadcast happens."""
|
||||
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||
make_audio_progress_logger,
|
||||
)
|
||||
|
||||
ctx = MagicMock()
|
||||
|
||||
with patch(
|
||||
"reflector.hatchet.broadcast.broadcast_event",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_broadcast:
|
||||
callback = make_audio_progress_logger(
|
||||
ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id=None
|
||||
)
|
||||
callback(50.0, 100.0)
|
||||
mock_broadcast.assert_not_called()
|
||||
|
||||
def test_no_broadcast_when_progress_pct_is_none(self):
|
||||
"""When progress_pct is None, no broadcast happens even with transcript_id."""
|
||||
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||
make_audio_progress_logger,
|
||||
)
|
||||
|
||||
ctx = MagicMock()
|
||||
|
||||
with patch(
|
||||
"reflector.hatchet.broadcast.broadcast_event",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_broadcast:
|
||||
callback = make_audio_progress_logger(
|
||||
ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id="t-123"
|
||||
)
|
||||
callback(None, 100.0)
|
||||
mock_broadcast.assert_not_called()
|
||||
|
||||
def test_logging_throttled_by_interval(self):
|
||||
"""With interval=5.0, rapid calls only log once until interval elapses.
|
||||
|
||||
The throttle applies to ctx.log() calls. Broadcasts (fire-and-forget)
|
||||
are not throttled — they occur every call when transcript_id + progress_pct set.
|
||||
"""
|
||||
import asyncio
|
||||
import time as time_mod
|
||||
|
||||
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||
make_audio_progress_logger,
|
||||
)
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.log = MagicMock()
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
mock_broadcast = AsyncMock()
|
||||
tasks_created = []
|
||||
original_create_task = loop.create_task
|
||||
|
||||
def capture_create_task(coro):
|
||||
task = original_create_task(coro)
|
||||
tasks_created.append(task)
|
||||
return task
|
||||
|
||||
# Controlled monotonic values for the 4 calls from make_audio_progress_logger:
|
||||
# init (start_time, last_log_time), call1 (now), call2 (now), call3 (now)
|
||||
# After those, fall back to real time.monotonic() for asyncio internals.
|
||||
controlled_values = [100.0, 100.0, 101.0, 106.0]
|
||||
call_index = [0]
|
||||
real_monotonic = time_mod.monotonic
|
||||
|
||||
def mock_monotonic():
|
||||
if call_index[0] < len(controlled_values):
|
||||
val = controlled_values[call_index[0]]
|
||||
call_index[0] += 1
|
||||
return val
|
||||
return real_monotonic()
|
||||
|
||||
try:
|
||||
with (
|
||||
patch(
|
||||
"reflector.hatchet.workflows.daily_multitrack_pipeline.time.monotonic",
|
||||
side_effect=mock_monotonic,
|
||||
),
|
||||
patch(
|
||||
"reflector.hatchet.broadcast.broadcast_event",
|
||||
mock_broadcast,
|
||||
),
|
||||
patch.object(loop, "create_task", side_effect=capture_create_task),
|
||||
):
|
||||
callback = make_audio_progress_logger(
|
||||
ctx, TaskName.MIXDOWN_TRACKS, interval=5.0, transcript_id="t-123"
|
||||
)
|
||||
|
||||
# Call 1 at t=100.0: 100.0 - 100.0 = 0.0 < 5.0 => no log
|
||||
callback(25.0, 50.0)
|
||||
assert ctx.log.call_count == 0
|
||||
|
||||
# Call 2 at t=101.0: 101.0 - 100.0 = 1.0 < 5.0 => no log
|
||||
callback(50.0, 100.0)
|
||||
assert ctx.log.call_count == 0
|
||||
|
||||
# Call 3 at t=106.0: 106.0 - 100.0 = 6.0 >= 5.0 => logs
|
||||
callback(75.0, 150.0)
|
||||
assert ctx.log.call_count == 1
|
||||
|
||||
# Run pending broadcast tasks
|
||||
if tasks_created:
|
||||
loop.run_until_complete(asyncio.gather(*tasks_created))
|
||||
|
||||
# Broadcasts happen on every call (not throttled) — 3 calls total
|
||||
assert mock_broadcast.call_count == 3
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
def test_uses_broadcast_event_not_append_event_and_broadcast(self):
|
||||
"""Progress events use broadcast_event (transient), not append_event_and_broadcast (persisted)."""
|
||||
import asyncio
|
||||
|
||||
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||
make_audio_progress_logger,
|
||||
)
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.log = MagicMock()
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
mock_broadcast_event = AsyncMock()
|
||||
mock_append = AsyncMock()
|
||||
tasks_created = []
|
||||
original_create_task = loop.create_task
|
||||
|
||||
def capture_create_task(coro):
|
||||
task = original_create_task(coro)
|
||||
tasks_created.append(task)
|
||||
return task
|
||||
|
||||
try:
|
||||
with (
|
||||
patch(
|
||||
"reflector.hatchet.broadcast.broadcast_event",
|
||||
mock_broadcast_event,
|
||||
),
|
||||
patch(
|
||||
"reflector.hatchet.broadcast.append_event_and_broadcast",
|
||||
mock_append,
|
||||
),
|
||||
patch.object(loop, "create_task", side_effect=capture_create_task),
|
||||
):
|
||||
callback = make_audio_progress_logger(
|
||||
ctx, TaskName.MIXDOWN_TRACKS, interval=0.0, transcript_id="t-123"
|
||||
)
|
||||
callback(50.0, 100.0)
|
||||
|
||||
if tasks_created:
|
||||
loop.run_until_complete(asyncio.gather(*tasks_created))
|
||||
|
||||
# broadcast_event (transient) IS called
|
||||
mock_broadcast_event.assert_called_once()
|
||||
# append_event_and_broadcast (persisted) is NOT called
|
||||
mock_append.assert_not_called()
|
||||
finally:
|
||||
loop.close()
|
||||
@@ -1,181 +0,0 @@
|
||||
"""Tests for with_error_handling decorator integration with broadcast_dag_status.
|
||||
|
||||
The decorator wraps each pipeline task and calls broadcast_dag_status on both
|
||||
success and failure paths. These tests verify that integration rather than
|
||||
testing broadcast_dag_status in isolation (which test_dag_progress.py covers).
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from reflector.hatchet.constants import TaskName
|
||||
|
||||
|
||||
class TestWithErrorHandlingBroadcast:
|
||||
"""Test with_error_handling decorator's integration with broadcast_dag_status."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
    def _setup_hatchet_mock(self):
        """Set HatchetClientManager._instance to a mock to prevent real SDK init.

        Module-level code in workflow files calls get_client() during import.
        Setting _instance before import avoids ClientConfig validation.
        """
        from reflector.hatchet.client import HatchetClientManager

        original = HatchetClientManager._instance
        HatchetClientManager._instance = MagicMock()
        yield
        HatchetClientManager._instance = original

    def _make_input(self, transcript_id: str = "t-123") -> MagicMock:
        """Create a mock PipelineInput with transcript_id."""
        inp = MagicMock()
        inp.transcript_id = transcript_id
        return inp

    def _make_ctx(self, workflow_run_id: str = "wf-abc") -> MagicMock:
        """Create a mock Context with workflow_run_id."""
        ctx = MagicMock()
        ctx.workflow_run_id = workflow_run_id
        return ctx

    @pytest.mark.asyncio
    async def test_calls_broadcast_on_success(self):
        """Decorator calls broadcast_dag_status once when task succeeds."""
        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            with_error_handling,
        )

        inner = AsyncMock(return_value="ok")
        wrapped = with_error_handling(TaskName.GET_RECORDING)(inner)

        with patch(
            "reflector.hatchet.dag_progress.broadcast_dag_status",
            new_callable=AsyncMock,
        ) as mock_broadcast:
            result = await wrapped(self._make_input(), self._make_ctx())

        assert result == "ok"
        mock_broadcast.assert_called_once_with("t-123", "wf-abc")

    @pytest.mark.asyncio
    async def test_calls_broadcast_on_failure(self):
        """Decorator calls broadcast_dag_status once when task raises."""
        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            with_error_handling,
        )

        inner = AsyncMock(side_effect=RuntimeError("boom"))
        wrapped = with_error_handling(TaskName.GET_RECORDING)(inner)

        with (
            patch(
                "reflector.hatchet.dag_progress.broadcast_dag_status",
                new_callable=AsyncMock,
            ) as mock_broadcast,
            patch(
                "reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
                new_callable=AsyncMock,
            ),
        ):
            with pytest.raises(RuntimeError, match="boom"):
                await wrapped(self._make_input(), self._make_ctx())

        mock_broadcast.assert_called_once_with("t-123", "wf-abc")

    @pytest.mark.asyncio
    async def test_swallows_broadcast_exception_on_success(self):
        """Broadcast failure does not crash the task on the success path."""
        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            with_error_handling,
        )

        inner = AsyncMock(return_value="ok")
        wrapped = with_error_handling(TaskName.GET_RECORDING)(inner)

        with patch(
            "reflector.hatchet.dag_progress.broadcast_dag_status",
            new_callable=AsyncMock,
            side_effect=RuntimeError("broadcast exploded"),
        ):
            result = await wrapped(self._make_input(), self._make_ctx())

        assert result == "ok"

    @pytest.mark.asyncio
    async def test_swallows_broadcast_exception_on_failure(self):
        """Original task exception propagates even when broadcast also fails."""
        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            with_error_handling,
        )

        inner = AsyncMock(side_effect=ValueError("original error"))
        wrapped = with_error_handling(TaskName.GET_RECORDING)(inner)

        with (
            patch(
                "reflector.hatchet.dag_progress.broadcast_dag_status",
                new_callable=AsyncMock,
                side_effect=RuntimeError("broadcast exploded"),
            ),
            patch(
                "reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
                new_callable=AsyncMock,
            ),
        ):
            with pytest.raises(ValueError, match="original error"):
                await wrapped(self._make_input(), self._make_ctx())

    @pytest.mark.asyncio
    async def test_calls_set_workflow_error_status_on_failure(self):
        """On task failure with set_error_status=True (default), calls set_workflow_error_status."""
        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            with_error_handling,
        )

        inner = AsyncMock(side_effect=RuntimeError("boom"))
        wrapped = with_error_handling(TaskName.GET_RECORDING)(inner)

        with (
            patch(
                "reflector.hatchet.dag_progress.broadcast_dag_status",
                new_callable=AsyncMock,
            ),
            patch(
                "reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
                new_callable=AsyncMock,
            ) as mock_set_error,
        ):
            with pytest.raises(RuntimeError, match="boom"):
                await wrapped(self._make_input(), self._make_ctx())

        mock_set_error.assert_called_once_with("t-123")

    @pytest.mark.asyncio
    async def test_no_set_workflow_error_status_when_disabled(self):
        """With set_error_status=False, set_workflow_error_status is NOT called on failure."""
        from reflector.hatchet.workflows.daily_multitrack_pipeline import (
            with_error_handling,
        )

        inner = AsyncMock(side_effect=RuntimeError("boom"))
        wrapped = with_error_handling(TaskName.GET_RECORDING, set_error_status=False)(
            inner
        )

        with (
            patch(
                "reflector.hatchet.dag_progress.broadcast_dag_status",
                new_callable=AsyncMock,
            ),
            patch(
                "reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
                new_callable=AsyncMock,
            ) as mock_set_error,
        ):
            with pytest.raises(RuntimeError, match="boom"):
                await wrapped(self._make_input(), self._make_ctx())

        mock_set_error.assert_not_called()
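
with_error_handling itself is defined outside this hunk. Purely as a hedged sketch of the contract the tests above pin down — using the names that appear as patch targets, not the project's actual implementation — it looks roughly like this:

# Illustrative sketch only — not part of this diff.
import functools

from reflector.hatchet import dag_progress  # module import, so tests can patch it


async def set_workflow_error_status(transcript_id: str) -> None:
    """Stand-in for the real helper the tests patch; body omitted here."""


def with_error_handling(task_name, set_error_status: bool = True):
    def decorator(fn):
        @functools.wraps(fn)
        async def wrapper(input, ctx):
            try:
                return await fn(input, ctx)
            except Exception:
                if set_error_status:
                    await set_workflow_error_status(input.transcript_id)
                raise
            finally:
                # Broadcast on success and failure alike; a failing broadcast
                # must never mask the task result or the original exception.
                try:
                    await dag_progress.broadcast_dag_status(
                        input.transcript_id, ctx.workflow_run_id
                    )
                except Exception:
                    pass

        return wrapper

    return decorator
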
@@ -1,421 +0,0 @@
|
||||
"""Tests for DAG status REST enrichment on search and transcript GET endpoints."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import reflector.db.search as search_module
|
||||
from reflector.db.search import SearchResult, _fetch_dag_statuses
|
||||
from reflector.db.transcripts import TranscriptEvent
|
||||
|
||||
|
||||
class TestFetchDagStatuses:
|
||||
"""Test the _fetch_dag_statuses helper."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_empty_for_empty_ids(self):
|
||||
result = await _fetch_dag_statuses([])
|
||||
assert result == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extracts_last_dag_status(self):
|
||||
events = [
|
||||
{"event": "STATUS", "data": {"value": "processing"}},
|
||||
{
|
||||
"event": "DAG_STATUS",
|
||||
"data": {
|
||||
"workflow_run_id": "r1",
|
||||
"tasks": [{"name": "get_recording", "status": "completed"}],
|
||||
},
|
||||
},
|
||||
{
|
||||
"event": "DAG_STATUS",
|
||||
"data": {
|
||||
"workflow_run_id": "r1",
|
||||
"tasks": [
|
||||
{"name": "get_recording", "status": "completed"},
|
||||
{"name": "process_tracks", "status": "running"},
|
||||
],
|
||||
},
|
||||
},
|
||||
]
|
||||
mock_row = {"id": "t1", "events": events}
|
||||
|
||||
with patch("reflector.db.search.get_database") as mock_db:
|
||||
mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row])
|
||||
result = await _fetch_dag_statuses(["t1"])
|
||||
|
||||
assert "t1" in result
|
||||
assert len(result["t1"]) == 2 # Last DAG_STATUS had 2 tasks
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_transcripts_without_events(self):
|
||||
mock_row = {"id": "t1", "events": None}
|
||||
|
||||
with patch("reflector.db.search.get_database") as mock_db:
|
||||
mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row])
|
||||
result = await _fetch_dag_statuses(["t1"])
|
||||
|
||||
assert result == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_transcripts_without_dag_status(self):
|
||||
events = [
|
||||
{"event": "STATUS", "data": {"value": "processing"}},
|
||||
{"event": "DURATION", "data": {"duration": 1000}},
|
||||
]
|
||||
mock_row = {"id": "t1", "events": events}
|
||||
|
||||
with patch("reflector.db.search.get_database") as mock_db:
|
||||
mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row])
|
||||
result = await _fetch_dag_statuses(["t1"])
|
||||
|
||||
assert result == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_json_string_events(self):
|
||||
"""Events stored as JSON string rather than already-parsed list."""
|
||||
import json
|
||||
|
||||
events = [
|
||||
{
|
||||
"event": "DAG_STATUS",
|
||||
"data": {
|
||||
"workflow_run_id": "r1",
|
||||
"tasks": [{"name": "transcribe", "status": "running"}],
|
||||
},
|
||||
},
|
||||
]
|
||||
mock_row = {"id": "t1", "events": json.dumps(events)}
|
||||
|
||||
with patch("reflector.db.search.get_database") as mock_db:
|
||||
mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row])
|
||||
result = await _fetch_dag_statuses(["t1"])
|
||||
|
||||
assert "t1" in result
|
||||
assert len(result["t1"]) == 1
|
||||
assert result["t1"][0]["name"] == "transcribe"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_transcripts(self):
|
||||
"""Handles multiple transcripts in one call."""
|
||||
events_t1 = [
|
||||
{
|
||||
"event": "DAG_STATUS",
|
||||
"data": {
|
||||
"workflow_run_id": "r1",
|
||||
"tasks": [{"name": "a", "status": "completed"}],
|
||||
},
|
||||
},
|
||||
]
|
||||
events_t2 = [
|
||||
{
|
||||
"event": "DAG_STATUS",
|
||||
"data": {
|
||||
"workflow_run_id": "r2",
|
||||
"tasks": [{"name": "b", "status": "running"}],
|
||||
},
|
||||
},
|
||||
]
|
||||
mock_rows = [
|
||||
{"id": "t1", "events": events_t1},
|
||||
{"id": "t2", "events": events_t2},
|
||||
]
|
||||
|
||||
with patch("reflector.db.search.get_database") as mock_db:
|
||||
mock_db.return_value.fetch_all = AsyncMock(return_value=mock_rows)
|
||||
result = await _fetch_dag_statuses(["t1", "t2"])
|
||||
|
||||
assert "t1" in result
|
||||
assert "t2" in result
|
||||
assert result["t1"][0]["name"] == "a"
|
||||
assert result["t2"][0]["name"] == "b"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dag_status_without_tasks_key_skipped(self):
|
||||
"""DAG_STATUS event with no tasks key in data should be skipped."""
|
||||
events = [
|
||||
{"event": "DAG_STATUS", "data": {"workflow_run_id": "r1"}},
|
||||
]
|
||||
mock_row = {"id": "t1", "events": events}
|
||||
|
||||
with patch("reflector.db.search.get_database") as mock_db:
|
||||
mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row])
|
||||
result = await _fetch_dag_statuses(["t1"])
|
||||
|
||||
assert result == {}
|
||||
|
||||
|
||||
def _extract_dag_status_from_transcript(transcript):
|
||||
"""Replicate the dag_status extraction logic from transcript_get view.
|
||||
|
||||
This mirrors the code in reflector/views/transcripts.py lines 495-500:
|
||||
dag_status = None
|
||||
if transcript.status == "processing" and transcript.events:
|
||||
for ev in reversed(transcript.events):
|
||||
if ev.event == "DAG_STATUS":
|
||||
dag_status = ev.data.get("tasks") if isinstance(ev.data, dict) else None
|
||||
break
|
||||
"""
|
||||
dag_status = None
|
||||
if transcript.status == "processing" and transcript.events:
|
||||
for ev in reversed(transcript.events):
|
||||
if ev.event == "DAG_STATUS":
|
||||
dag_status = ev.data.get("tasks") if isinstance(ev.data, dict) else None
|
||||
break
|
||||
return dag_status
|
||||
|
||||
|
||||
class TestTranscriptGetDagStatusExtraction:
|
||||
"""Test dag_status extraction logic from transcript_get endpoint.
|
||||
|
||||
The actual endpoint is complex to set up, so we test the extraction
|
||||
logic directly using the same code pattern from the view.
|
||||
"""
|
||||
|
||||
def test_processing_transcript_with_dag_status_events(self):
|
||||
"""Processing transcript with DAG_STATUS events returns tasks from last event."""
|
||||
transcript = SimpleNamespace(
|
||||
status="processing",
|
||||
events=[
|
||||
TranscriptEvent(event="STATUS", data={"value": "processing"}),
|
||||
TranscriptEvent(
|
||||
event="DAG_STATUS",
|
||||
data={
|
||||
"workflow_run_id": "r1",
|
||||
"tasks": [{"name": "get_recording", "status": "completed"}],
|
||||
},
|
||||
),
|
||||
TranscriptEvent(
|
||||
event="DAG_STATUS",
|
||||
data={
|
||||
"workflow_run_id": "r1",
|
||||
"tasks": [
|
||||
{"name": "get_recording", "status": "completed"},
|
||||
{"name": "transcribe", "status": "running"},
|
||||
],
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
result = _extract_dag_status_from_transcript(transcript)
|
||||
|
||||
assert result is not None
|
||||
assert len(result) == 2
|
||||
assert result[0]["name"] == "get_recording"
|
||||
assert result[1]["name"] == "transcribe"
|
||||
assert result[1]["status"] == "running"
|
||||
|
||||
def test_processing_transcript_without_dag_status_events(self):
|
||||
"""Processing transcript with only non-DAG_STATUS events returns None."""
|
||||
transcript = SimpleNamespace(
|
||||
status="processing",
|
||||
events=[
|
||||
TranscriptEvent(event="STATUS", data={"value": "processing"}),
|
||||
TranscriptEvent(event="DURATION", data={"duration": 1000}),
|
||||
],
|
||||
)
|
||||
|
||||
result = _extract_dag_status_from_transcript(transcript)
|
||||
assert result is None
|
||||
|
||||
def test_ended_transcript_with_dag_status_events(self):
|
||||
"""Ended transcript with DAG_STATUS events returns None (status check)."""
|
||||
transcript = SimpleNamespace(
|
||||
status="ended",
|
||||
events=[
|
||||
TranscriptEvent(
|
||||
event="DAG_STATUS",
|
||||
data={
|
||||
"workflow_run_id": "r1",
|
||||
"tasks": [{"name": "transcribe", "status": "completed"}],
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
result = _extract_dag_status_from_transcript(transcript)
|
||||
assert result is None
|
||||
|
||||
def test_processing_transcript_with_empty_events(self):
|
||||
"""Processing transcript with empty events list returns None."""
|
||||
transcript = SimpleNamespace(
|
||||
status="processing",
|
||||
events=[],
|
||||
)
|
||||
|
||||
result = _extract_dag_status_from_transcript(transcript)
|
||||
assert result is None
|
||||
|
||||
def test_processing_transcript_with_none_events(self):
|
||||
"""Processing transcript with None events returns None."""
|
||||
transcript = SimpleNamespace(
|
||||
status="processing",
|
||||
events=None,
|
||||
)
|
||||
|
||||
result = _extract_dag_status_from_transcript(transcript)
|
||||
assert result is None
|
||||
|
||||
def test_extracts_last_dag_status_not_first(self):
|
||||
"""Should pick the last DAG_STATUS event (most recent), not the first."""
|
||||
transcript = SimpleNamespace(
|
||||
status="processing",
|
||||
events=[
|
||||
TranscriptEvent(
|
||||
event="DAG_STATUS",
|
||||
data={
|
||||
"workflow_run_id": "r1",
|
||||
"tasks": [{"name": "a", "status": "running"}],
|
||||
},
|
||||
),
|
||||
TranscriptEvent(event="STATUS", data={"value": "processing"}),
|
||||
TranscriptEvent(
|
||||
event="DAG_STATUS",
|
||||
data={
|
||||
"workflow_run_id": "r1",
|
||||
"tasks": [
|
||||
{"name": "a", "status": "completed"},
|
||||
{"name": "b", "status": "running"},
|
||||
],
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
result = _extract_dag_status_from_transcript(transcript)
|
||||
assert len(result) == 2
|
||||
assert result[0]["status"] == "completed"
|
||||
assert result[1]["name"] == "b"
|
||||
|
||||
|
||||
class TestSearchEnrichmentIntegration:
|
||||
"""Test DAG status enrichment in search results.
|
||||
|
||||
The search function enriches processing transcripts with dag_status
|
||||
by calling _fetch_dag_statuses for processing IDs and assigning results.
|
||||
We test this enrichment logic by mocking _fetch_dag_statuses.
|
||||
"""
|
||||
|
||||
def _make_search_result(self, id: str, status: str) -> SearchResult:
|
||||
"""Create a minimal SearchResult for testing."""
|
||||
return SearchResult(
|
||||
id=id,
|
||||
title=f"Transcript {id}",
|
||||
user_id="u1",
|
||||
room_id=None,
|
||||
room_name=None,
|
||||
source_kind="live",
|
||||
created_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
|
||||
status=status,
|
||||
rank=1.0,
|
||||
duration=60.0,
|
||||
search_snippets=[],
|
||||
total_match_count=0,
|
||||
dag_status=None,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_processing_result_gets_dag_status(self):
|
||||
"""SearchResult with status='processing' and matching DAG_STATUS events
|
||||
gets dag_status populated."""
|
||||
results = [self._make_search_result("t1", "processing")]
|
||||
dag_tasks = [
|
||||
{"name": "get_recording", "status": "completed"},
|
||||
{"name": "transcribe", "status": "running"},
|
||||
]
|
||||
|
||||
with patch.object(
|
||||
search_module,
|
||||
"_fetch_dag_statuses",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"t1": dag_tasks},
|
||||
) as mock_fetch:
|
||||
# Replicate the enrichment logic from SearchController.search_transcripts
|
||||
processing_ids = [r.id for r in results if r.status == "processing"]
|
||||
if processing_ids:
|
||||
dag_statuses = await search_module._fetch_dag_statuses(processing_ids)
|
||||
for r in results:
|
||||
if r.id in dag_statuses:
|
||||
r.dag_status = dag_statuses[r.id]
|
||||
|
||||
mock_fetch.assert_called_once_with(["t1"])
|
||||
|
||||
assert results[0].dag_status == dag_tasks
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ended_result_does_not_trigger_fetch(self):
|
||||
"""SearchResult with status='ended' does NOT trigger _fetch_dag_statuses."""
|
||||
results = [self._make_search_result("t1", "ended")]
|
||||
|
||||
with patch.object(
|
||||
search_module,
|
||||
"_fetch_dag_statuses",
|
||||
new_callable=AsyncMock,
|
||||
return_value={},
|
||||
) as mock_fetch:
|
||||
processing_ids = [r.id for r in results if r.status == "processing"]
|
||||
if processing_ids:
|
||||
dag_statuses = await search_module._fetch_dag_statuses(processing_ids)
|
||||
for r in results:
|
||||
if r.id in dag_statuses:
|
||||
r.dag_status = dag_statuses[r.id]
|
||||
|
||||
mock_fetch.assert_not_called()
|
||||
|
||||
assert results[0].dag_status is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_processing_and_ended_results(self):
|
||||
"""Only processing results get enriched; ended results stay None."""
|
||||
results = [
|
||||
self._make_search_result("t1", "processing"),
|
||||
self._make_search_result("t2", "ended"),
|
||||
self._make_search_result("t3", "processing"),
|
||||
]
|
||||
dag_tasks_t1 = [{"name": "transcribe", "status": "running"}]
|
||||
dag_tasks_t3 = [{"name": "diarize", "status": "completed"}]
|
||||
|
||||
with patch.object(
|
||||
search_module,
|
||||
"_fetch_dag_statuses",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"t1": dag_tasks_t1, "t3": dag_tasks_t3},
|
||||
) as mock_fetch:
|
||||
processing_ids = [r.id for r in results if r.status == "processing"]
|
||||
if processing_ids:
|
||||
dag_statuses = await search_module._fetch_dag_statuses(processing_ids)
|
||||
for r in results:
|
||||
if r.id in dag_statuses:
|
||||
r.dag_status = dag_statuses[r.id]
|
||||
|
||||
mock_fetch.assert_called_once_with(["t1", "t3"])
|
||||
|
||||
assert results[0].dag_status == dag_tasks_t1
|
||||
assert results[1].dag_status is None
|
||||
assert results[2].dag_status == dag_tasks_t3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_processing_result_without_dag_events_stays_none(self):
|
||||
"""Processing result with no DAG_STATUS events in DB stays dag_status=None."""
|
||||
results = [self._make_search_result("t1", "processing")]
|
||||
|
||||
with patch.object(
|
||||
search_module,
|
||||
"_fetch_dag_statuses",
|
||||
new_callable=AsyncMock,
|
||||
return_value={},
|
||||
) as mock_fetch:
|
||||
processing_ids = [r.id for r in results if r.status == "processing"]
|
||||
if processing_ids:
|
||||
dag_statuses = await search_module._fetch_dag_statuses(processing_ids)
|
||||
for r in results:
|
||||
if r.id in dag_statuses:
|
||||
r.dag_status = dag_statuses[r.id]
|
||||
|
||||
mock_fetch.assert_called_once_with(["t1"])
|
||||
|
||||
assert results[0].dag_status is None
|
||||
@@ -8,6 +8,7 @@ from pydantic import BaseModel, Field
from workflows.errors import WorkflowRuntimeError, WorkflowTimeoutError

from reflector.llm import LLM, LLMParseError, StructuredOutputWorkflow
from reflector.settings import Settings
from reflector.utils.retry import RetryException


@@ -26,6 +27,57 @@ def make_completion_response(text: str):
    return response


class TestLLMEnableThinking:
    """Test that LLM_ENABLE_THINKING setting is passed through to OpenAILike"""

    def test_enable_thinking_false_passed_in_extra_body(self):
        """enable_thinking=False should be in extra_body when LLM_ENABLE_THINKING=False"""
        settings = Settings(
            LLM_ENABLE_THINKING=False,
            LLM_URL="http://fake",
            LLM_API_KEY="fake",
        )

        with (
            patch("reflector.llm.OpenAILike") as mock_openai,
            patch("reflector.llm.Settings"),
        ):
            LLM(settings=settings)
            extra_body = mock_openai.call_args.kwargs["additional_kwargs"]["extra_body"]
            assert extra_body["enable_thinking"] is False

    def test_enable_thinking_true_passed_in_extra_body(self):
        """enable_thinking=True should be in extra_body when LLM_ENABLE_THINKING=True"""
        settings = Settings(
            LLM_ENABLE_THINKING=True,
            LLM_URL="http://fake",
            LLM_API_KEY="fake",
        )

        with (
            patch("reflector.llm.OpenAILike") as mock_openai,
            patch("reflector.llm.Settings"),
        ):
            LLM(settings=settings)
            extra_body = mock_openai.call_args.kwargs["additional_kwargs"]["extra_body"]
            assert extra_body["enable_thinking"] is True

    def test_enable_thinking_none_not_in_extra_body(self):
        """enable_thinking should not be in extra_body when LLM_ENABLE_THINKING is None (default)"""
        settings = Settings(
            LLM_URL="http://fake",
            LLM_API_KEY="fake",
        )

        with (
            patch("reflector.llm.OpenAILike") as mock_openai,
            patch("reflector.llm.Settings"),
        ):
            LLM(settings=settings)
            extra_body = mock_openai.call_args.kwargs["additional_kwargs"]["extra_body"]
            assert "enable_thinking" not in extra_body
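
For orientation, a hedged sketch of the pass-through these tests assert, assuming LLM.__init__ builds the OpenAILike client roughly like this (only additional_kwargs/extra_body come from the assertions above; the other keyword names are illustrative):

# Illustrative sketch only — not the actual reflector.llm implementation.
# `settings` and `OpenAILike` are assumed to be in scope in that module.
extra_body: dict = {}
if settings.LLM_ENABLE_THINKING is not None:
    # Forward the flag only when it is explicitly set; omitting it keeps the
    # request payload unchanged for backends that do not know the option.
    extra_body["enable_thinking"] = settings.LLM_ENABLE_THINKING

client = OpenAILike(
    api_base=settings.LLM_URL,  # illustrative parameter names
    api_key=settings.LLM_API_KEY,
    additional_kwargs={"extra_body": extra_body},
)
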

class TestLLMParseErrorRecovery:
    """Test parse error recovery with Workflow feedback loop"""
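
The body of TestLLMParseErrorRecovery sits outside this hunk. Purely as orientation for the docstring above, a generic feedback-loop shape such tests usually exercise (illustrative names and signatures, not the project's StructuredOutputWorkflow):

async def complete_with_parse_feedback(llm, prompt, parse, max_attempts=3):
    # On a parse failure, feed the error text back into the next attempt.
    feedback = ""
    for _ in range(max_attempts):
        raw = await llm.acomplete(prompt + feedback)  # assumed completion API
        try:
            return parse(raw)
        except LLMParseError as exc:
            feedback = f"\n\nThe previous output could not be parsed ({exc}); please correct it."
    raise LLMParseError("parse-error retries exhausted")
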
@@ -1,331 +0,0 @@
|
||||
"""WebSocket broadcast delivery tests for STATUS and DAG_STATUS events.
|
||||
|
||||
Tests the full chain identified in DEBUG.md:
|
||||
broadcast_event() → ws_manager.send_json() → Redis/in-memory pub/sub
|
||||
→ _pubsub_data_reader() → socket.send_json() → WebSocket client
|
||||
|
||||
Covers:
|
||||
1. STATUS event delivery to transcript room WS
|
||||
2. DAG_STATUS event delivery to transcript room WS
|
||||
3. Full broadcast_event() chain (requires broadcast.py patching)
|
||||
4. _pubsub_data_reader resilience when a client disconnects
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from httpx_ws import aconnect_ws
|
||||
from uvicorn import Config, Server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def appserver_ws_broadcast(setup_database, monkeypatch):
|
||||
"""Start real uvicorn server for WebSocket broadcast tests.
|
||||
|
||||
Also patches broadcast.py's get_ws_manager (missing from conftest autouse fixture).
|
||||
"""
|
||||
# Patch broadcast.py's get_ws_manager — conftest.py misses this module.
|
||||
# Without this, broadcast_event() creates a real Redis ws_manager.
|
||||
import reflector.ws_manager as ws_mod
|
||||
from reflector.app import app
|
||||
from reflector.db import get_database
|
||||
|
||||
monkeypatch.setattr(
|
||||
"reflector.hatchet.broadcast.get_ws_manager", ws_mod.get_ws_manager
|
||||
)
|
||||
|
||||
host = "127.0.0.1"
|
||||
port = 1259
|
||||
server_started = threading.Event()
|
||||
server_exception = None
|
||||
server_instance = None
|
||||
|
||||
def run_server():
|
||||
nonlocal server_exception, server_instance
|
||||
try:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
config = Config(app=app, host=host, port=port, loop=loop)
|
||||
server_instance = Server(config)
|
||||
|
||||
async def start_server():
|
||||
database = get_database()
|
||||
await database.connect()
|
||||
try:
|
||||
await server_instance.serve()
|
||||
finally:
|
||||
await database.disconnect()
|
||||
|
||||
server_started.set()
|
||||
loop.run_until_complete(start_server())
|
||||
except Exception as e:
|
||||
server_exception = e
|
||||
server_started.set()
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
server_thread = threading.Thread(target=run_server, daemon=True)
|
||||
server_thread.start()
|
||||
|
||||
server_started.wait(timeout=30)
|
||||
if server_exception:
|
||||
raise server_exception
|
||||
|
||||
time.sleep(0.5)
|
||||
|
||||
yield host, port
|
||||
|
||||
if server_instance:
|
||||
server_instance.should_exit = True
|
||||
server_thread.join(timeout=2.0)
|
||||
|
||||
from reflector.ws_manager import reset_ws_manager
|
||||
|
||||
reset_ws_manager()
|
||||
|
||||
|
||||
async def _create_transcript(host: str, port: int, name: str) -> str:
|
||||
"""Create a transcript via ASGI transport and return its ID."""
|
||||
from reflector.app import app
|
||||
|
||||
async with AsyncClient(app=app, base_url=f"http://{host}:{port}/v1") as ac:
|
||||
resp = await ac.post("/transcripts", json={"name": name})
|
||||
assert resp.status_code == 200, f"Failed to create transcript: {resp.text}"
|
||||
return resp.json()["id"]
|
||||
|
||||
|
||||
async def _drain_historical_events(ws, timeout: float = 0.5) -> list[dict]:
|
||||
"""Read all historical events sent on WS connect (non-blocking drain)."""
|
||||
events = []
|
||||
deadline = asyncio.get_event_loop().time() + timeout
|
||||
while asyncio.get_event_loop().time() < deadline:
|
||||
try:
|
||||
msg = await asyncio.wait_for(ws.receive_json(), timeout=0.1)
|
||||
events.append(msg)
|
||||
except (asyncio.TimeoutError, Exception):
|
||||
break
|
||||
return events
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 1: STATUS event delivery via ws_manager.send_json
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_ws_receives_status_via_send_json(appserver_ws_broadcast):
|
||||
"""STATUS event published via ws_manager.send_json() arrives at transcript room WS."""
|
||||
host, port = appserver_ws_broadcast
|
||||
transcript_id = await _create_transcript(host, port, "Status send_json test")
|
||||
|
||||
ws_url = f"http://{host}:{port}/v1/transcripts/{transcript_id}/events"
|
||||
async with aconnect_ws(ws_url) as ws:
|
||||
await _drain_historical_events(ws)
|
||||
|
||||
import reflector.ws_manager as ws_mod
|
||||
|
||||
ws_manager = ws_mod.get_ws_manager()
|
||||
await ws_manager.send_json(
|
||||
room_id=f"ts:{transcript_id}",
|
||||
message={"event": "STATUS", "data": {"value": "processing"}},
|
||||
)
|
||||
|
||||
msg = await asyncio.wait_for(ws.receive_json(), timeout=5.0)
|
||||
assert msg["event"] == "STATUS"
|
||||
assert msg["data"]["value"] == "processing"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 2: DAG_STATUS event delivery via ws_manager.send_json
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_ws_receives_dag_status_via_send_json(appserver_ws_broadcast):
|
||||
"""DAG_STATUS event published via ws_manager.send_json() arrives at transcript room WS."""
|
||||
host, port = appserver_ws_broadcast
|
||||
transcript_id = await _create_transcript(host, port, "DAG_STATUS send_json test")
|
||||
|
||||
dag_payload = {
|
||||
"event": "DAG_STATUS",
|
||||
"data": {
|
||||
"workflow_run_id": "test-run-123",
|
||||
"tasks": [
|
||||
{
|
||||
"name": "get_recording",
|
||||
"status": "completed",
|
||||
"started_at": "2025-01-01T00:00:00Z",
|
||||
"finished_at": "2025-01-01T00:00:05Z",
|
||||
"duration_seconds": 5.0,
|
||||
"parents": [],
|
||||
"error": None,
|
||||
"children_total": None,
|
||||
"children_completed": None,
|
||||
"progress_pct": None,
|
||||
},
|
||||
{
|
||||
"name": "process_tracks",
|
||||
"status": "running",
|
||||
"started_at": "2025-01-01T00:00:05Z",
|
||||
"finished_at": None,
|
||||
"duration_seconds": None,
|
||||
"parents": ["get_recording"],
|
||||
"error": None,
|
||||
"children_total": 3,
|
||||
"children_completed": 1,
|
||||
"progress_pct": 33.3,
|
||||
},
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
ws_url = f"http://{host}:{port}/v1/transcripts/{transcript_id}/events"
|
||||
async with aconnect_ws(ws_url) as ws:
|
||||
await _drain_historical_events(ws)
|
||||
|
||||
import reflector.ws_manager as ws_mod
|
||||
|
||||
ws_manager = ws_mod.get_ws_manager()
|
||||
await ws_manager.send_json(
|
||||
room_id=f"ts:{transcript_id}",
|
||||
message=dag_payload,
|
||||
)
|
||||
|
||||
msg = await asyncio.wait_for(ws.receive_json(), timeout=5.0)
|
||||
assert msg["event"] == "DAG_STATUS"
|
||||
assert msg["data"]["workflow_run_id"] == "test-run-123"
|
||||
assert len(msg["data"]["tasks"]) == 2
|
||||
assert msg["data"]["tasks"][0]["name"] == "get_recording"
|
||||
assert msg["data"]["tasks"][0]["status"] == "completed"
|
||||
assert msg["data"]["tasks"][1]["name"] == "process_tracks"
|
||||
assert msg["data"]["tasks"][1]["children_completed"] == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 3: Full broadcast_event() chain for STATUS
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_event_delivers_status_to_transcript_ws(appserver_ws_broadcast):
|
||||
"""broadcast_event() end-to-end: STATUS event reaches transcript room WS."""
|
||||
host, port = appserver_ws_broadcast
|
||||
transcript_id = await _create_transcript(host, port, "broadcast_event STATUS test")
|
||||
|
||||
ws_url = f"http://{host}:{port}/v1/transcripts/{transcript_id}/events"
|
||||
async with aconnect_ws(ws_url) as ws:
|
||||
await _drain_historical_events(ws)
|
||||
|
||||
from reflector.db.transcripts import TranscriptEvent
|
||||
from reflector.hatchet.broadcast import broadcast_event
|
||||
from reflector.logger import logger
|
||||
|
||||
log = logger.bind(transcript_id=transcript_id)
|
||||
event = TranscriptEvent(event="STATUS", data={"value": "processing"})
|
||||
await broadcast_event(transcript_id, event, logger=log)
|
||||
|
||||
msg = await asyncio.wait_for(ws.receive_json(), timeout=5.0)
|
||||
assert msg["event"] == "STATUS"
|
||||
assert msg["data"]["value"] == "processing"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 4: Full broadcast_event() chain for DAG_STATUS
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_event_delivers_dag_status_to_transcript_ws(
|
||||
appserver_ws_broadcast,
|
||||
):
|
||||
"""broadcast_event() end-to-end: DAG_STATUS event reaches transcript room WS."""
|
||||
host, port = appserver_ws_broadcast
|
||||
transcript_id = await _create_transcript(host, port, "broadcast_event DAG test")
|
||||
|
||||
ws_url = f"http://{host}:{port}/v1/transcripts/{transcript_id}/events"
|
||||
async with aconnect_ws(ws_url) as ws:
|
||||
await _drain_historical_events(ws)
|
||||
|
||||
from reflector.db.transcripts import TranscriptEvent
|
||||
from reflector.hatchet.broadcast import broadcast_event
|
||||
from reflector.logger import logger
|
||||
|
||||
log = logger.bind(transcript_id=transcript_id)
|
||||
event = TranscriptEvent(
|
||||
event="DAG_STATUS",
|
||||
data={
|
||||
"workflow_run_id": "test-run-456",
|
||||
"tasks": [
|
||||
{
|
||||
"name": "get_recording",
|
||||
"status": "running",
|
||||
"started_at": None,
|
||||
"finished_at": None,
|
||||
"duration_seconds": None,
|
||||
"parents": [],
|
||||
"error": None,
|
||||
"children_total": None,
|
||||
"children_completed": None,
|
||||
"progress_pct": None,
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
await broadcast_event(transcript_id, event, logger=log)
|
||||
|
||||
msg = await asyncio.wait_for(ws.receive_json(), timeout=5.0)
|
||||
assert msg["event"] == "DAG_STATUS"
|
||||
assert msg["data"]["tasks"][0]["name"] == "get_recording"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 5: Multiple rapid events arrive in order
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_events_arrive_in_order(appserver_ws_broadcast):
|
||||
"""Multiple STATUS then DAG_STATUS events arrive in correct order."""
|
||||
host, port = appserver_ws_broadcast
|
||||
transcript_id = await _create_transcript(host, port, "ordering test")
|
||||
|
||||
ws_url = f"http://{host}:{port}/v1/transcripts/{transcript_id}/events"
|
||||
async with aconnect_ws(ws_url) as ws:
|
||||
await _drain_historical_events(ws)
|
||||
|
||||
import reflector.ws_manager as ws_mod
|
||||
|
||||
ws_manager = ws_mod.get_ws_manager()
|
||||
|
||||
await ws_manager.send_json(
|
||||
room_id=f"ts:{transcript_id}",
|
||||
message={"event": "STATUS", "data": {"value": "processing"}},
|
||||
)
|
||||
await ws_manager.send_json(
|
||||
room_id=f"ts:{transcript_id}",
|
||||
message={
|
||||
"event": "DAG_STATUS",
|
||||
"data": {"workflow_run_id": "r1", "tasks": []},
|
||||
},
|
||||
)
|
||||
await ws_manager.send_json(
|
||||
room_id=f"ts:{transcript_id}",
|
||||
message={
|
||||
"event": "DAG_STATUS",
|
||||
"data": {
|
||||
"workflow_run_id": "r1",
|
||||
"tasks": [{"name": "a", "status": "running"}],
|
||||
},
|
||||
},
|
||||
)
|
||||
await ws_manager.send_json(
|
||||
room_id=f"ts:{transcript_id}",
|
||||
message={"event": "STATUS", "data": {"value": "ended"}},
|
||||
)
|
||||
|
||||
msgs = []
|
||||
for _ in range(4):
|
||||
msg = await asyncio.wait_for(ws.receive_json(), timeout=5.0)
|
||||
msgs.append(msg)
|
||||
|
||||
assert msgs[0]["event"] == "STATUS"
|
||||
assert msgs[0]["data"]["value"] == "processing"
|
||||
assert msgs[1]["event"] == "DAG_STATUS"
|
||||
assert msgs[1]["data"]["tasks"] == []
|
||||
assert msgs[2]["event"] == "DAG_STATUS"
|
||||
assert len(msgs[2]["data"]["tasks"]) == 1
|
||||
assert msgs[3]["event"] == "STATUS"
|
||||
assert msgs[3]["data"]["value"] == "ended"
|
||||
@@ -1,61 +0,0 @@
import React from "react";
import { Box, Flex } from "@chakra-ui/react";
import type { DagTask } from "../../../lib/UserEventsProvider";

const pulseKeyframes = `
@keyframes dagDotPulse {
  0%, 100% { opacity: 1; }
  50% { opacity: 0.3; }
}
`;

function humanizeTaskName(name: string): string {
  return name
    .split("_")
    .map((word) => word.charAt(0).toUpperCase() + word.slice(1))
    .join(" ");
}

function dotProps(status: DagTask["status"]): Record<string, unknown> {
  switch (status) {
    case "completed":
      return { bg: "green.500" };
    case "running":
      return {
        bg: "blue.500",
        style: { animation: "dagDotPulse 1.5s ease-in-out infinite" },
      };
    case "failed":
      return { bg: "red.500" };
    case "cancelled":
      return { bg: "gray.400" };
    case "queued":
    default:
      return {
        bg: "transparent",
        border: "1px solid",
        borderColor: "gray.400",
      };
  }
}

export default function DagProgressDots({ tasks }: { tasks: DagTask[] }) {
  return (
    <>
      <style>{pulseKeyframes}</style>
      <Flex gap="2px" alignItems="center" flexWrap="wrap">
        {tasks.map((task) => (
          <Box
            key={task.name}
            w="4px"
            h="4px"
            borderRadius="full"
            flexShrink={0}
            title={humanizeTaskName(task.name)}
            {...dotProps(task.status)}
          />
        ))}
      </Flex>
    </>
  );
}
@@ -19,7 +19,6 @@ import {
  generateTextFragment,
} from "../../../lib/textHighlight";
import type { components } from "../../../reflector-api";
import type { DagTask } from "../../../lib/UserEventsProvider";

type SearchResult = components["schemas"]["SearchResult"];
type SourceKind = components["schemas"]["SourceKind"];
@@ -30,7 +29,6 @@ interface TranscriptCardsProps {
  isLoading?: boolean;
  onDelete: (transcriptId: string) => void;
  onReprocess: (transcriptId: string) => void;
  dagStatusMap?: Map<string, DagTask[]>;
}

function highlightText(text: string, query: string): React.ReactNode {
@@ -104,13 +102,11 @@ function TranscriptCard({
  query,
  onDelete,
  onReprocess,
  dagStatusMap,
}: {
  result: SearchResult;
  query: string;
  onDelete: (transcriptId: string) => void;
  onReprocess: (transcriptId: string) => void;
  dagStatusMap?: Map<string, DagTask[]>;
}) {
  const [isExpanded, setIsExpanded] = useState(false);

@@ -141,16 +137,7 @@ function TranscriptCard({
    <Box borderWidth={1} p={4} borderRadius="md" fontSize="sm">
      <Flex justify="space-between" alignItems="flex-start" gap="2">
        <Box>
          <TranscriptStatusIcon
            status={result.status}
            dagStatus={
              dagStatusMap?.get(result.id) ??
              ((result as Record<string, unknown>).dag_status as
                | DagTask[]
                | null) ??
              null
            }
          />
          <TranscriptStatusIcon status={result.status} />
        </Box>
        <Box flex="1">
          {/* Title with highlighting and text fragment for deep linking */}
@@ -297,7 +284,6 @@ export default function TranscriptCards({
  isLoading,
  onDelete,
  onReprocess,
  dagStatusMap,
}: TranscriptCardsProps) {
  return (
    <Box position="relative">
@@ -329,7 +315,6 @@ export default function TranscriptCards({
            query={query}
            onDelete={onDelete}
            onReprocess={onReprocess}
            dagStatusMap={dagStatusMap}
          />
        ))}
      </Stack>
@@ -8,17 +8,13 @@ import {
  FaGear,
} from "react-icons/fa6";
import { TranscriptStatus } from "../../../lib/transcript";
import type { DagTask } from "../../../lib/UserEventsProvider";
import DagProgressDots from "./DagProgressDots";

interface TranscriptStatusIconProps {
  status: TranscriptStatus;
  dagStatus?: DagTask[] | null;
}

export default function TranscriptStatusIcon({
  status,
  dagStatus,
}: TranscriptStatusIconProps) {
  switch (status) {
    case "ended":
@@ -40,9 +36,6 @@ export default function TranscriptStatusIcon({
        </Box>
      );
    case "processing":
      if (dagStatus && dagStatus.length > 0) {
        return <DagProgressDots tasks={dagStatus} />;
      }
      return (
        <Box as="span" title="Processing in progress">
          <Icon color="gray.500" as={FaGear} />
@@ -43,7 +43,6 @@ import DeleteTranscriptDialog from "./_components/DeleteTranscriptDialog";
import { formatLocalDate } from "../../lib/time";
import { RECORD_A_MEETING_URL } from "../../api/urls";
import { useUserName } from "../../lib/useUserName";
import { useDagStatusMap } from "../../lib/UserEventsProvider";

const SEARCH_FORM_QUERY_INPUT_NAME = "query" as const;

@@ -274,7 +273,6 @@ export default function TranscriptBrowser() {
  }, [JSON.stringify(searchFilters)]);

  const userName = useUserName();
  const dagStatusMap = useDagStatusMap();
  const [deletionLoading, setDeletionLoading] = useState(false);
  const cancelRef = React.useRef(null);
  const [transcriptToDeleteId, setTranscriptToDeleteId] =
@@ -410,7 +408,6 @@ export default function TranscriptBrowser() {
          isLoading={searchLoading}
          onDelete={setTranscriptToDeleteId}
          onReprocess={handleProcessTranscript}
          dagStatusMap={dagStatusMap}
        />

        {!searchLoading && results.length === 0 && (
@@ -1,190 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useEffect, useState } from "react";
|
||||
import { Table, Box, Icon, Spinner, Text, Badge } from "@chakra-ui/react";
|
||||
import { FaCheck, FaXmark, FaClock, FaMinus } from "react-icons/fa6";
|
||||
import type { DagTask, DagTaskStatus } from "../../useWebSockets";
|
||||
|
||||
function humanizeTaskName(name: string): string {
|
||||
return name
|
||||
.split("_")
|
||||
.map((word) => word.charAt(0).toUpperCase() + word.slice(1))
|
||||
.join(" ");
|
||||
}
|
||||
|
||||
function formatDuration(seconds: number): string {
|
||||
if (seconds < 60) {
|
||||
return `${Math.round(seconds)}s`;
|
||||
}
|
||||
const minutes = Math.floor(seconds / 60);
|
||||
const remainingSeconds = Math.round(seconds % 60);
|
||||
return `${minutes}m ${remainingSeconds}s`;
|
||||
}
|
||||
|
||||
function StatusIcon({ status }: { status: DagTaskStatus }) {
|
||||
switch (status) {
|
||||
case "completed":
|
||||
return (
|
||||
<Box as="span" title="Completed">
|
||||
<Icon color="green.500" as={FaCheck} />
|
||||
</Box>
|
||||
);
|
||||
case "running":
|
||||
return <Spinner size="sm" color="blue.500" />;
|
||||
case "failed":
|
||||
return (
|
||||
<Box as="span" title="Failed">
|
||||
<Icon color="red.500" as={FaXmark} />
|
||||
</Box>
|
||||
);
|
||||
case "queued":
|
||||
return (
|
||||
<Box as="span" title="Queued">
|
||||
<Icon color="gray.400" as={FaClock} />
|
||||
</Box>
|
||||
);
|
||||
case "cancelled":
|
||||
return (
|
||||
<Box as="span" title="Cancelled">
|
||||
<Icon color="gray.400" as={FaMinus} />
|
||||
</Box>
|
||||
);
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function ElapsedTimer({ startedAt }: { startedAt: string }) {
|
||||
const [elapsed, setElapsed] = useState<number>(() => {
|
||||
return (Date.now() - new Date(startedAt).getTime()) / 1000;
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
const interval = setInterval(() => {
|
||||
setElapsed((Date.now() - new Date(startedAt).getTime()) / 1000);
|
||||
}, 1000);
|
||||
return () => clearInterval(interval);
|
||||
}, [startedAt]);
|
||||
|
||||
return <Text fontSize="sm">{formatDuration(elapsed)}</Text>;
|
||||
}
|
||||
|
||||
function DurationCell({ task }: { task: DagTask }) {
|
||||
if (task.status === "completed" && task.duration_seconds !== null) {
|
||||
return <Text fontSize="sm">{formatDuration(task.duration_seconds)}</Text>;
|
||||
}
|
||||
if (task.status === "running" && task.started_at) {
|
||||
return <ElapsedTimer startedAt={task.started_at} />;
|
||||
}
|
||||
return (
|
||||
<Text fontSize="sm" color="gray.400">
|
||||
--
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
|
||||
function ProgressCell({ task }: { task: DagTask }) {
|
||||
if (task.progress_pct === null && task.children_total === null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Box>
|
||||
{task.progress_pct !== null && (
|
||||
<Box
|
||||
w="100%"
|
||||
h="6px"
|
||||
bg="gray.200"
|
||||
borderRadius="full"
|
||||
overflow="hidden"
|
||||
>
|
||||
<Box
|
||||
h="100%"
|
||||
w={`${Math.min(100, Math.max(0, task.progress_pct))}%`}
|
||||
bg={task.status === "failed" ? "red.400" : "blue.400"}
|
||||
borderRadius="full"
|
||||
transition="width 0.3s ease"
|
||||
/>
|
||||
</Box>
|
||||
)}
|
||||
{task.children_total !== null && (
|
||||
<Badge
|
||||
size="sm"
|
||||
colorPalette="gray"
|
||||
mt={task.progress_pct !== null ? 1 : 0}
|
||||
>
|
||||
{task.children_completed ?? 0}/{task.children_total}
|
||||
</Badge>
|
||||
)}
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
function TaskRow({ task }: { task: DagTask }) {
|
||||
const [expanded, setExpanded] = useState(false);
|
||||
const hasFailed = task.status === "failed" && task.error;
|
||||
|
||||
return (
|
||||
<>
|
||||
<Table.Row
|
||||
cursor={hasFailed ? "pointer" : "default"}
|
||||
onClick={hasFailed ? () => setExpanded((prev) => !prev) : undefined}
|
||||
_hover={hasFailed ? { bg: "gray.50" } : undefined}
|
||||
>
|
||||
<Table.Cell>
|
||||
<Text fontSize="sm" fontWeight="medium">
|
||||
{humanizeTaskName(task.name)}
|
||||
</Text>
|
||||
</Table.Cell>
|
||||
<Table.Cell>
|
||||
<StatusIcon status={task.status} />
|
||||
</Table.Cell>
|
||||
<Table.Cell>
|
||||
<DurationCell task={task} />
|
||||
</Table.Cell>
|
||||
<Table.Cell>
|
||||
<ProgressCell task={task} />
|
||||
</Table.Cell>
|
||||
</Table.Row>
|
||||
{hasFailed && expanded && (
|
||||
<Table.Row>
|
||||
<Table.Cell colSpan={4}>
|
||||
<Box bg="red.50" p={3} borderRadius="md">
|
||||
<Text fontSize="xs" color="red.700" whiteSpace="pre-wrap">
|
||||
{task.error}
|
||||
</Text>
|
||||
</Box>
|
||||
</Table.Cell>
|
||||
</Table.Row>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
export default function DagProgressTable({ tasks }: { tasks: DagTask[] }) {
|
||||
return (
|
||||
<Box w="100%" overflowX="auto">
|
||||
<Table.Root size="sm">
|
||||
<Table.Header>
|
||||
<Table.Row>
|
||||
<Table.ColumnHeader fontWeight="600">Task</Table.ColumnHeader>
|
||||
<Table.ColumnHeader fontWeight="600" width="80px">
|
||||
Status
|
||||
</Table.ColumnHeader>
|
||||
<Table.ColumnHeader fontWeight="600" width="100px">
|
||||
Duration
|
||||
</Table.ColumnHeader>
|
||||
<Table.ColumnHeader fontWeight="600" width="140px">
|
||||
Progress
|
||||
</Table.ColumnHeader>
|
||||
</Table.Row>
|
||||
</Table.Header>
|
||||
<Table.Body>
|
||||
{tasks.map((task) => (
|
||||
<TaskRow key={task.name} task={task} />
|
||||
))}
|
||||
</Table.Body>
|
||||
</Table.Root>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
@@ -11,10 +11,6 @@ import {
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useTranscriptGet } from "../../../../lib/apiHooks";
|
||||
import { parseNonEmptyString } from "../../../../lib/utils";
|
||||
import { useWebSockets } from "../../useWebSockets";
|
||||
import type { DagTask } from "../../useWebSockets";
|
||||
import { useDagStatusMap } from "../../../../lib/UserEventsProvider";
|
||||
import DagProgressTable from "./DagProgressTable";
|
||||
|
||||
type TranscriptProcessing = {
|
||||
params: Promise<{
|
||||
@@ -28,21 +24,9 @@ export default function TranscriptProcessing(details: TranscriptProcessing) {
|
||||
const router = useRouter();
|
||||
|
||||
const transcript = useTranscriptGet(transcriptId);
|
||||
const { status: wsStatus, dagStatus: wsDagStatus } =
|
||||
useWebSockets(transcriptId);
|
||||
const userDagStatusMap = useDagStatusMap();
|
||||
const userDagStatus = userDagStatusMap.get(transcriptId) ?? null;
|
||||
|
||||
const restDagStatus: DagTask[] | null =
|
||||
((transcript.data as Record<string, unknown>)?.dag_status as
|
||||
| DagTask[]
|
||||
| null) ?? null;
|
||||
|
||||
// Prefer transcript room WS (most granular), then user room WS, then REST
|
||||
const dagStatus = wsDagStatus ?? userDagStatus ?? restDagStatus;
|
||||
|
||||
useEffect(() => {
|
||||
const status = wsStatus?.value ?? transcript.data?.status;
|
||||
const status = transcript.data?.status;
|
||||
if (!status) return;
|
||||
|
||||
if (status === "ended" || status === "error") {
|
||||
@@ -57,7 +41,6 @@ export default function TranscriptProcessing(details: TranscriptProcessing) {
|
||||
router.replace(dest);
|
||||
}
|
||||
}, [
|
||||
wsStatus?.value,
|
||||
transcript.data?.status,
|
||||
transcript.data?.source_kind,
|
||||
router,
|
||||
@@ -91,29 +74,11 @@ export default function TranscriptProcessing(details: TranscriptProcessing) {
|
||||
w={{ base: "full", md: "container.xl" }}
|
||||
>
|
||||
<Center h={"full"} w="full">
|
||||
<VStack
|
||||
gap={10}
|
||||
bg="gray.100"
|
||||
p={10}
|
||||
borderRadius="md"
|
||||
maxW="600px"
|
||||
w="full"
|
||||
>
|
||||
{dagStatus ? (
|
||||
<>
|
||||
<Heading size={"md"} textAlign="center">
|
||||
Processing recording
|
||||
</Heading>
|
||||
<DagProgressTable tasks={dagStatus} />
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Spinner size="xl" color="blue.500" />
|
||||
<Heading size={"md"} textAlign="center">
|
||||
Processing recording
|
||||
</Heading>
|
||||
</>
|
||||
)}
|
||||
<VStack gap={10} bg="gray.100" p={10} borderRadius="md" maxW="500px">
|
||||
<Spinner size="xl" color="blue.500" />
|
||||
<Heading size={"md"} textAlign="center">
|
||||
Processing recording
|
||||
</Heading>
|
||||
<Text color="gray.600" textAlign="center">
|
||||
You can safely return to the library while your recording is being
|
||||
processed.
|
||||
|
||||
@@ -14,9 +14,6 @@ import {
|
||||
} from "../../lib/apiHooks";
|
||||
import { NonEmptyString } from "../../lib/utils";
|
||||
|
||||
import type { DagTask } from "../../lib/dagTypes";
|
||||
export type { DagTask, DagTaskStatus } from "../../lib/dagTypes";
|
||||
|
||||
export type UseWebSockets = {
|
||||
transcriptTextLive: string;
|
||||
translateText: string;
|
||||
@@ -27,7 +24,6 @@ export type UseWebSockets = {
|
||||
status: Status | null;
|
||||
waveform: AudioWaveform | null;
|
||||
duration: number | null;
|
||||
dagStatus: DagTask[] | null;
|
||||
};
|
||||
|
||||
export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
|
||||
@@ -44,7 +40,6 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
|
||||
summary: "",
|
||||
});
|
||||
const [status, setStatus] = useState<Status | null>(null);
|
||||
const [dagStatus, setDagStatus] = useState<DagTask[] | null>(null);
|
||||
const { setError } = useError();
|
||||
|
||||
const queryClient = useQueryClient();
|
||||
@@ -436,31 +431,11 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
|
||||
);
|
||||
}
|
||||
setStatus(message.data);
|
||||
invalidateTranscript(queryClient, transcriptId as NonEmptyString);
|
||||
if (message.data.value === "ended") {
|
||||
ws.close();
|
||||
}
|
||||
break;
|
||||
|
||||
case "DAG_STATUS":
|
||||
if (message.data?.tasks) {
|
||||
setDagStatus(message.data.tasks);
|
||||
}
|
||||
break;
|
||||
|
||||
case "DAG_TASK_PROGRESS":
|
||||
if (message.data) {
|
||||
setDagStatus(
|
||||
(prev) =>
|
||||
prev?.map((t) =>
|
||||
t.name === message.data.task_name
|
||||
? { ...t, progress_pct: message.data.progress_pct }
|
||||
: t,
|
||||
) ?? null,
|
||||
);
|
||||
}
|
||||
break;
|
||||
|
||||
default:
|
||||
setError(
|
||||
new Error(`Received unknown WebSocket event: ${message.event}`),
|
||||
@@ -518,6 +493,5 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
|
||||
status,
|
||||
waveform,
|
||||
duration,
|
||||
dagStatus,
|
||||
};
|
||||
};
|
||||
|
||||
@@ -1,25 +1,11 @@
|
||||
"use client";
|
||||
|
||||
import React, { useEffect, useRef, useState } from "react";
|
||||
import React, { useEffect, useRef } from "react";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { WEBSOCKET_URL } from "./apiClient";
|
||||
import { useAuth } from "./AuthProvider";
|
||||
import { z } from "zod";
|
||||
import {
|
||||
invalidateTranscript,
|
||||
invalidateTranscriptLists,
|
||||
TRANSCRIPT_SEARCH_URL,
|
||||
} from "./apiHooks";
|
||||
import type { NonEmptyString } from "./utils";
|
||||
|
||||
import type { DagTask } from "./dagTypes";
|
||||
export type { DagTask, DagTaskStatus } from "./dagTypes";
|
||||
|
||||
const DagStatusContext = React.createContext<Map<string, DagTask[]>>(new Map());
|
||||
|
||||
export function useDagStatusMap() {
|
||||
return React.useContext(DagStatusContext);
|
||||
}
|
||||
import { invalidateTranscriptLists, TRANSCRIPT_SEARCH_URL } from "./apiHooks";
|
||||
|
||||
const UserEvent = z.object({
|
||||
event: z.string(),
|
||||
@@ -109,9 +95,6 @@ export function UserEventsProvider({
|
||||
const queryClient = useQueryClient();
|
||||
const tokenRef = useRef<string | null>(null);
|
||||
const detachRef = useRef<(() => void) | null>(null);
|
||||
const [dagStatusMap, setDagStatusMap] = useState<Map<string, DagTask[]>>(
|
||||
new Map(),
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
// Only tear down when the user is truly unauthenticated
|
||||
@@ -150,52 +133,20 @@ export function UserEventsProvider({
|
||||
if (!detachRef.current) {
|
||||
const onMessage = (event: MessageEvent) => {
|
||||
try {
|
||||
const fullMsg = JSON.parse(event.data);
|
||||
const msg = UserEvent.parse(fullMsg);
|
||||
const msg = UserEvent.parse(JSON.parse(event.data));
|
||||
const eventName = msg.event;
|
||||
|
||||
const invalidateList = () => invalidateTranscriptLists(queryClient);
|
||||
|
||||
switch (eventName) {
|
||||
case "TRANSCRIPT_CREATED":
|
||||
case "TRANSCRIPT_DELETED":
|
||||
case "TRANSCRIPT_STATUS":
|
||||
case "TRANSCRIPT_FINAL_TITLE":
|
||||
case "TRANSCRIPT_DURATION":
|
||||
invalidateList().then(() => {});
|
||||
break;
|
||||
|
||||
case "TRANSCRIPT_STATUS": {
|
||||
invalidateList().then(() => {});
|
||||
const transcriptId = fullMsg.data?.id as string | undefined;
|
||||
if (transcriptId) {
|
||||
invalidateTranscript(
|
||||
queryClient,
|
||||
transcriptId as NonEmptyString,
|
||||
).then(() => {});
|
||||
}
|
||||
const status = fullMsg.data?.value as string | undefined;
|
||||
if (transcriptId && status && status !== "processing") {
|
||||
setDagStatusMap((prev) => {
|
||||
const next = new Map(prev);
|
||||
next.delete(transcriptId);
|
||||
return next;
|
||||
});
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case "TRANSCRIPT_DAG_STATUS": {
|
||||
const transcriptId = fullMsg.data?.id as string | undefined;
|
||||
const tasks = fullMsg.data?.tasks as DagTask[] | undefined;
|
||||
if (transcriptId && tasks) {
|
||||
setDagStatusMap((prev) => {
|
||||
const next = new Map(prev);
|
||||
next.set(transcriptId, tasks);
|
||||
return next;
|
||||
});
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
// Ignore other content events for list updates
|
||||
break;
|
||||
@@ -225,9 +176,5 @@ export function UserEventsProvider({
|
||||
};
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<DagStatusContext.Provider value={dagStatusMap}>
|
||||
{children}
|
||||
</DagStatusContext.Provider>
|
||||
);
|
||||
return <>{children}</>;
|
||||
}
|
||||
|
||||
@@ -1,19 +0,0 @@
export type DagTaskStatus =
  | "queued"
  | "running"
  | "completed"
  | "failed"
  | "cancelled";

export type DagTask = {
  name: string;
  status: DagTaskStatus;
  started_at: string | null;
  finished_at: string | null;
  duration_seconds: number | null;
  parents: string[];
  error: string | null;
  children_total: number | null;
  children_completed: number | null;
  progress_pct: number | null;
};