feat: add dag_status REST enrichment to search and transcript GET

This commit is contained in:
Igor Loskutov
2026-02-09 12:56:46 -05:00
parent 4b79b0c989
commit 025e6da539
3 changed files with 201 additions and 0 deletions

View File

@@ -1,6 +1,7 @@
"""Search functionality for transcripts and other entities.""" """Search functionality for transcripts and other entities."""
import itertools import itertools
import json
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from io import StringIO from io import StringIO
@@ -172,6 +173,9 @@ class SearchResult(BaseModel):
total_match_count: NonNegativeInt = Field( total_match_count: NonNegativeInt = Field(
default=0, description="Total number of matches found in the transcript" default=0, description="Total number of matches found in the transcript"
) )
dag_status: list[dict] | None = Field(
default=None, description="Latest DAG task status for processing transcripts"
)
@field_serializer("created_at", when_used="json") @field_serializer("created_at", when_used="json")
def serialize_datetime(self, dt: datetime) -> str: def serialize_datetime(self, dt: datetime) -> str:
@@ -328,6 +332,42 @@ class SnippetGenerator:
return summary_snippets + webvtt_snippets, total_matches return summary_snippets + webvtt_snippets, total_matches
async def _fetch_dag_statuses(transcript_ids: list[str]) -> dict[str, list[dict]]:
"""Fetch latest DAG_STATUS event data for given transcript IDs.
Returns dict mapping transcript_id -> tasks list from the last DAG_STATUS event.
"""
if not transcript_ids:
return {}
db = get_database()
query = sqlalchemy.select(
[
transcripts.c.id,
transcripts.c.events,
]
).where(transcripts.c.id.in_(transcript_ids))
rows = await db.fetch_all(query)
result: dict[str, list[dict]] = {}
for row in rows:
events_raw = row["events"]
if not events_raw:
continue
# events is stored as JSON list
events = events_raw if isinstance(events_raw, list) else json.loads(events_raw)
# Find last DAG_STATUS event
for ev in reversed(events):
if isinstance(ev, dict) and ev.get("event") == "DAG_STATUS":
tasks = ev.get("data", {}).get("tasks")
if tasks:
result[row["id"]] = tasks
break
return result
class SearchController: class SearchController:
"""Controller for search operations across different entities.""" """Controller for search operations across different entities."""
@@ -470,6 +510,14 @@ class SearchController:
logger.error(f"Error processing search results: {e}", exc_info=True) logger.error(f"Error processing search results: {e}", exc_info=True)
raise raise
# Enrich processing transcripts with DAG status
processing_ids = [r.id for r in results if r.status == "processing"]
if processing_ids:
dag_statuses = await _fetch_dag_statuses(processing_ids)
for r in results:
if r.id in dag_statuses:
r.dag_status = dag_statuses[r.id]
return results, total return results, total

View File

@@ -111,6 +111,7 @@ class GetTranscriptMinimal(BaseModel):
room_id: str | None = None room_id: str | None = None
room_name: str | None = None room_name: str | None = None
audio_deleted: bool | None = None audio_deleted: bool | None = None
dag_status: list[dict] | None = None
class TranscriptParticipantWithEmail(TranscriptParticipant): class TranscriptParticipantWithEmail(TranscriptParticipant):
@@ -491,6 +492,13 @@ async def transcript_get(
) )
) )
dag_status = None
if transcript.status == "processing" and transcript.events:
for ev in reversed(transcript.events):
if ev.event == "DAG_STATUS":
dag_status = ev.data.get("tasks") if isinstance(ev.data, dict) else None
break
base_data = { base_data = {
"id": transcript.id, "id": transcript.id,
"user_id": transcript.user_id, "user_id": transcript.user_id,
@@ -512,6 +520,7 @@ async def transcript_get(
"room_id": transcript.room_id, "room_id": transcript.room_id,
"room_name": room_name, "room_name": room_name,
"audio_deleted": transcript.audio_deleted, "audio_deleted": transcript.audio_deleted,
"dag_status": dag_status,
"participants": participants, "participants": participants,
} }

View File

@@ -0,0 +1,144 @@
"""Tests for DAG status REST enrichment on search and transcript GET endpoints."""
from unittest.mock import AsyncMock, patch
import pytest
from reflector.db.search import _fetch_dag_statuses
class TestFetchDagStatuses:
"""Test the _fetch_dag_statuses helper."""
@pytest.mark.asyncio
async def test_returns_empty_for_empty_ids(self):
result = await _fetch_dag_statuses([])
assert result == {}
@pytest.mark.asyncio
async def test_extracts_last_dag_status(self):
events = [
{"event": "STATUS", "data": {"value": "processing"}},
{
"event": "DAG_STATUS",
"data": {
"workflow_run_id": "r1",
"tasks": [{"name": "get_recording", "status": "completed"}],
},
},
{
"event": "DAG_STATUS",
"data": {
"workflow_run_id": "r1",
"tasks": [
{"name": "get_recording", "status": "completed"},
{"name": "process_tracks", "status": "running"},
],
},
},
]
mock_row = {"id": "t1", "events": events}
with patch("reflector.db.search.get_database") as mock_db:
mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row])
result = await _fetch_dag_statuses(["t1"])
assert "t1" in result
assert len(result["t1"]) == 2 # Last DAG_STATUS had 2 tasks
@pytest.mark.asyncio
async def test_skips_transcripts_without_events(self):
mock_row = {"id": "t1", "events": None}
with patch("reflector.db.search.get_database") as mock_db:
mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row])
result = await _fetch_dag_statuses(["t1"])
assert result == {}
@pytest.mark.asyncio
async def test_skips_transcripts_without_dag_status(self):
events = [
{"event": "STATUS", "data": {"value": "processing"}},
{"event": "DURATION", "data": {"duration": 1000}},
]
mock_row = {"id": "t1", "events": events}
with patch("reflector.db.search.get_database") as mock_db:
mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row])
result = await _fetch_dag_statuses(["t1"])
assert result == {}
@pytest.mark.asyncio
async def test_handles_json_string_events(self):
"""Events stored as JSON string rather than already-parsed list."""
import json
events = [
{
"event": "DAG_STATUS",
"data": {
"workflow_run_id": "r1",
"tasks": [{"name": "transcribe", "status": "running"}],
},
},
]
mock_row = {"id": "t1", "events": json.dumps(events)}
with patch("reflector.db.search.get_database") as mock_db:
mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row])
result = await _fetch_dag_statuses(["t1"])
assert "t1" in result
assert len(result["t1"]) == 1
assert result["t1"][0]["name"] == "transcribe"
@pytest.mark.asyncio
async def test_multiple_transcripts(self):
"""Handles multiple transcripts in one call."""
events_t1 = [
{
"event": "DAG_STATUS",
"data": {
"workflow_run_id": "r1",
"tasks": [{"name": "a", "status": "completed"}],
},
},
]
events_t2 = [
{
"event": "DAG_STATUS",
"data": {
"workflow_run_id": "r2",
"tasks": [{"name": "b", "status": "running"}],
},
},
]
mock_rows = [
{"id": "t1", "events": events_t1},
{"id": "t2", "events": events_t2},
]
with patch("reflector.db.search.get_database") as mock_db:
mock_db.return_value.fetch_all = AsyncMock(return_value=mock_rows)
result = await _fetch_dag_statuses(["t1", "t2"])
assert "t1" in result
assert "t2" in result
assert result["t1"][0]["name"] == "a"
assert result["t2"][0]["name"] == "b"
@pytest.mark.asyncio
async def test_dag_status_without_tasks_key_skipped(self):
"""DAG_STATUS event with no tasks key in data should be skipped."""
events = [
{"event": "DAG_STATUS", "data": {"workflow_run_id": "r1"}},
]
mock_row = {"id": "t1", "events": events}
with patch("reflector.db.search.get_database") as mock_db:
mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row])
result = await _fetch_dag_statuses(["t1"])
assert result == {}