From 025e6da5393d1538b2b519c15f6bfa75a00c0016 Mon Sep 17 00:00:00 2001 From: Igor Loskutov Date: Mon, 9 Feb 2026 12:56:46 -0500 Subject: [PATCH] feat: add dag_status REST enrichment to search and transcript GET --- server/reflector/db/search.py | 48 +++++++++ server/reflector/views/transcripts.py | 9 ++ server/tests/test_dag_progress_rest.py | 144 +++++++++++++++++++++++++ 3 files changed, 201 insertions(+) create mode 100644 server/tests/test_dag_progress_rest.py diff --git a/server/reflector/db/search.py b/server/reflector/db/search.py index 5d9bc507..6ddbc656 100644 --- a/server/reflector/db/search.py +++ b/server/reflector/db/search.py @@ -1,6 +1,7 @@ """Search functionality for transcripts and other entities.""" import itertools +import json from dataclasses import dataclass from datetime import datetime from io import StringIO @@ -172,6 +173,9 @@ class SearchResult(BaseModel): total_match_count: NonNegativeInt = Field( default=0, description="Total number of matches found in the transcript" ) + dag_status: list[dict] | None = Field( + default=None, description="Latest DAG task status for processing transcripts" + ) @field_serializer("created_at", when_used="json") def serialize_datetime(self, dt: datetime) -> str: @@ -328,6 +332,42 @@ class SnippetGenerator: return summary_snippets + webvtt_snippets, total_matches +async def _fetch_dag_statuses(transcript_ids: list[str]) -> dict[str, list[dict]]: + """Fetch latest DAG_STATUS event data for given transcript IDs. + + Returns dict mapping transcript_id -> tasks list from the last DAG_STATUS event. + """ + if not transcript_ids: + return {} + + db = get_database() + query = sqlalchemy.select( + [ + transcripts.c.id, + transcripts.c.events, + ] + ).where(transcripts.c.id.in_(transcript_ids)) + + rows = await db.fetch_all(query) + result: dict[str, list[dict]] = {} + + for row in rows: + events_raw = row["events"] + if not events_raw: + continue + # events is stored as JSON list + events = events_raw if isinstance(events_raw, list) else json.loads(events_raw) + # Find last DAG_STATUS event + for ev in reversed(events): + if isinstance(ev, dict) and ev.get("event") == "DAG_STATUS": + tasks = ev.get("data", {}).get("tasks") + if tasks: + result[row["id"]] = tasks + break + + return result + + class SearchController: """Controller for search operations across different entities.""" @@ -470,6 +510,14 @@ class SearchController: logger.error(f"Error processing search results: {e}", exc_info=True) raise + # Enrich processing transcripts with DAG status + processing_ids = [r.id for r in results if r.status == "processing"] + if processing_ids: + dag_statuses = await _fetch_dag_statuses(processing_ids) + for r in results: + if r.id in dag_statuses: + r.dag_status = dag_statuses[r.id] + return results, total diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 2e1c9d30..afda77f6 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -111,6 +111,7 @@ class GetTranscriptMinimal(BaseModel): room_id: str | None = None room_name: str | None = None audio_deleted: bool | None = None + dag_status: list[dict] | None = None class TranscriptParticipantWithEmail(TranscriptParticipant): @@ -491,6 +492,13 @@ async def transcript_get( ) ) + dag_status = None + if transcript.status == "processing" and transcript.events: + for ev in reversed(transcript.events): + if ev.event == "DAG_STATUS": + dag_status = ev.data.get("tasks") if isinstance(ev.data, dict) else None + break + base_data = { "id": transcript.id, "user_id": transcript.user_id, @@ -512,6 +520,7 @@ async def transcript_get( "room_id": transcript.room_id, "room_name": room_name, "audio_deleted": transcript.audio_deleted, + "dag_status": dag_status, "participants": participants, } diff --git a/server/tests/test_dag_progress_rest.py b/server/tests/test_dag_progress_rest.py new file mode 100644 index 00000000..030ed0ec --- /dev/null +++ b/server/tests/test_dag_progress_rest.py @@ -0,0 +1,144 @@ +"""Tests for DAG status REST enrichment on search and transcript GET endpoints.""" + +from unittest.mock import AsyncMock, patch + +import pytest + +from reflector.db.search import _fetch_dag_statuses + + +class TestFetchDagStatuses: + """Test the _fetch_dag_statuses helper.""" + + @pytest.mark.asyncio + async def test_returns_empty_for_empty_ids(self): + result = await _fetch_dag_statuses([]) + assert result == {} + + @pytest.mark.asyncio + async def test_extracts_last_dag_status(self): + events = [ + {"event": "STATUS", "data": {"value": "processing"}}, + { + "event": "DAG_STATUS", + "data": { + "workflow_run_id": "r1", + "tasks": [{"name": "get_recording", "status": "completed"}], + }, + }, + { + "event": "DAG_STATUS", + "data": { + "workflow_run_id": "r1", + "tasks": [ + {"name": "get_recording", "status": "completed"}, + {"name": "process_tracks", "status": "running"}, + ], + }, + }, + ] + mock_row = {"id": "t1", "events": events} + + with patch("reflector.db.search.get_database") as mock_db: + mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row]) + result = await _fetch_dag_statuses(["t1"]) + + assert "t1" in result + assert len(result["t1"]) == 2 # Last DAG_STATUS had 2 tasks + + @pytest.mark.asyncio + async def test_skips_transcripts_without_events(self): + mock_row = {"id": "t1", "events": None} + + with patch("reflector.db.search.get_database") as mock_db: + mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row]) + result = await _fetch_dag_statuses(["t1"]) + + assert result == {} + + @pytest.mark.asyncio + async def test_skips_transcripts_without_dag_status(self): + events = [ + {"event": "STATUS", "data": {"value": "processing"}}, + {"event": "DURATION", "data": {"duration": 1000}}, + ] + mock_row = {"id": "t1", "events": events} + + with patch("reflector.db.search.get_database") as mock_db: + mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row]) + result = await _fetch_dag_statuses(["t1"]) + + assert result == {} + + @pytest.mark.asyncio + async def test_handles_json_string_events(self): + """Events stored as JSON string rather than already-parsed list.""" + import json + + events = [ + { + "event": "DAG_STATUS", + "data": { + "workflow_run_id": "r1", + "tasks": [{"name": "transcribe", "status": "running"}], + }, + }, + ] + mock_row = {"id": "t1", "events": json.dumps(events)} + + with patch("reflector.db.search.get_database") as mock_db: + mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row]) + result = await _fetch_dag_statuses(["t1"]) + + assert "t1" in result + assert len(result["t1"]) == 1 + assert result["t1"][0]["name"] == "transcribe" + + @pytest.mark.asyncio + async def test_multiple_transcripts(self): + """Handles multiple transcripts in one call.""" + events_t1 = [ + { + "event": "DAG_STATUS", + "data": { + "workflow_run_id": "r1", + "tasks": [{"name": "a", "status": "completed"}], + }, + }, + ] + events_t2 = [ + { + "event": "DAG_STATUS", + "data": { + "workflow_run_id": "r2", + "tasks": [{"name": "b", "status": "running"}], + }, + }, + ] + mock_rows = [ + {"id": "t1", "events": events_t1}, + {"id": "t2", "events": events_t2}, + ] + + with patch("reflector.db.search.get_database") as mock_db: + mock_db.return_value.fetch_all = AsyncMock(return_value=mock_rows) + result = await _fetch_dag_statuses(["t1", "t2"]) + + assert "t1" in result + assert "t2" in result + assert result["t1"][0]["name"] == "a" + assert result["t2"][0]["name"] == "b" + + @pytest.mark.asyncio + async def test_dag_status_without_tasks_key_skipped(self): + """DAG_STATUS event with no tasks key in data should be skipped.""" + events = [ + {"event": "DAG_STATUS", "data": {"workflow_run_id": "r1"}}, + ] + mock_row = {"id": "t1", "events": events} + + with patch("reflector.db.search.get_database") as mock_db: + mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row]) + result = await _fetch_dag_statuses(["t1"]) + + assert result == {}