mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-04-25 22:55:18 +00:00
feat: add dag_status REST enrichment to search and transcript GET
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
"""Search functionality for transcripts and other entities."""
|
"""Search functionality for transcripts and other entities."""
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
|
import json
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
@@ -172,6 +173,9 @@ class SearchResult(BaseModel):
|
|||||||
total_match_count: NonNegativeInt = Field(
|
total_match_count: NonNegativeInt = Field(
|
||||||
default=0, description="Total number of matches found in the transcript"
|
default=0, description="Total number of matches found in the transcript"
|
||||||
)
|
)
|
||||||
|
dag_status: list[dict] | None = Field(
|
||||||
|
default=None, description="Latest DAG task status for processing transcripts"
|
||||||
|
)
|
||||||
|
|
||||||
@field_serializer("created_at", when_used="json")
|
@field_serializer("created_at", when_used="json")
|
||||||
def serialize_datetime(self, dt: datetime) -> str:
|
def serialize_datetime(self, dt: datetime) -> str:
|
||||||
@@ -328,6 +332,42 @@ class SnippetGenerator:
|
|||||||
return summary_snippets + webvtt_snippets, total_matches
|
return summary_snippets + webvtt_snippets, total_matches
|
||||||
|
|
||||||
|
|
||||||
|
async def _fetch_dag_statuses(transcript_ids: list[str]) -> dict[str, list[dict]]:
    """Fetch latest DAG_STATUS event data for given transcript IDs.

    Looks at each transcript's stored ``events`` list (JSON), scanning from the
    end for the most recent ``DAG_STATUS`` event, and collects its ``tasks``
    payload.

    Args:
        transcript_ids: IDs of transcripts to look up. An empty list short-circuits
            to an empty result without touching the database.

    Returns:
        Dict mapping transcript_id -> tasks list from the last DAG_STATUS event.
        Transcripts with no events, no DAG_STATUS event, or a DAG_STATUS event
        lacking a non-empty ``tasks`` entry are omitted.
    """
    if not transcript_ids:
        return {}

    db = get_database()
    # SQLAlchemy 2.0: columns are passed positionally to select();
    # the legacy select([...]) list form was removed in 2.0.
    query = sqlalchemy.select(
        transcripts.c.id,
        transcripts.c.events,
    ).where(transcripts.c.id.in_(transcript_ids))

    rows = await db.fetch_all(query)
    result: dict[str, list[dict]] = {}

    for row in rows:
        events_raw = row["events"]
        if not events_raw:
            continue
        # events is stored as a JSON list; some drivers return it already
        # deserialized, others as a raw JSON string.
        events = events_raw if isinstance(events_raw, list) else json.loads(events_raw)
        # Find the last DAG_STATUS event; stop at the first one found scanning
        # backwards, even if it carries no tasks (older ones would be stale).
        for ev in reversed(events):
            if isinstance(ev, dict) and ev.get("event") == "DAG_STATUS":
                tasks = ev.get("data", {}).get("tasks")
                if tasks:
                    result[row["id"]] = tasks
                break

    return result
|
|
||||||
|
|
||||||
class SearchController:
|
class SearchController:
|
||||||
"""Controller for search operations across different entities."""
|
"""Controller for search operations across different entities."""
|
||||||
|
|
||||||
@@ -470,6 +510,14 @@ class SearchController:
|
|||||||
logger.error(f"Error processing search results: {e}", exc_info=True)
|
logger.error(f"Error processing search results: {e}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
# Enrich processing transcripts with DAG status
|
||||||
|
processing_ids = [r.id for r in results if r.status == "processing"]
|
||||||
|
if processing_ids:
|
||||||
|
dag_statuses = await _fetch_dag_statuses(processing_ids)
|
||||||
|
for r in results:
|
||||||
|
if r.id in dag_statuses:
|
||||||
|
r.dag_status = dag_statuses[r.id]
|
||||||
|
|
||||||
return results, total
|
return results, total
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -111,6 +111,7 @@ class GetTranscriptMinimal(BaseModel):
|
|||||||
room_id: str | None = None
|
room_id: str | None = None
|
||||||
room_name: str | None = None
|
room_name: str | None = None
|
||||||
audio_deleted: bool | None = None
|
audio_deleted: bool | None = None
|
||||||
|
dag_status: list[dict] | None = None
|
||||||
|
|
||||||
|
|
||||||
class TranscriptParticipantWithEmail(TranscriptParticipant):
|
class TranscriptParticipantWithEmail(TranscriptParticipant):
|
||||||
@@ -491,6 +492,13 @@ async def transcript_get(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
dag_status = None
|
||||||
|
if transcript.status == "processing" and transcript.events:
|
||||||
|
for ev in reversed(transcript.events):
|
||||||
|
if ev.event == "DAG_STATUS":
|
||||||
|
dag_status = ev.data.get("tasks") if isinstance(ev.data, dict) else None
|
||||||
|
break
|
||||||
|
|
||||||
base_data = {
|
base_data = {
|
||||||
"id": transcript.id,
|
"id": transcript.id,
|
||||||
"user_id": transcript.user_id,
|
"user_id": transcript.user_id,
|
||||||
@@ -512,6 +520,7 @@ async def transcript_get(
|
|||||||
"room_id": transcript.room_id,
|
"room_id": transcript.room_id,
|
||||||
"room_name": room_name,
|
"room_name": room_name,
|
||||||
"audio_deleted": transcript.audio_deleted,
|
"audio_deleted": transcript.audio_deleted,
|
||||||
|
"dag_status": dag_status,
|
||||||
"participants": participants,
|
"participants": participants,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
144
server/tests/test_dag_progress_rest.py
Normal file
144
server/tests/test_dag_progress_rest.py
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
"""Tests for DAG status REST enrichment on search and transcript GET endpoints."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from reflector.db.search import _fetch_dag_statuses
|
||||||
|
|
||||||
|
|
||||||
|
class TestFetchDagStatuses:
|
||||||
|
"""Test the _fetch_dag_statuses helper."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_empty_for_empty_ids(self):
|
||||||
|
result = await _fetch_dag_statuses([])
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extracts_last_dag_status(self):
|
||||||
|
events = [
|
||||||
|
{"event": "STATUS", "data": {"value": "processing"}},
|
||||||
|
{
|
||||||
|
"event": "DAG_STATUS",
|
||||||
|
"data": {
|
||||||
|
"workflow_run_id": "r1",
|
||||||
|
"tasks": [{"name": "get_recording", "status": "completed"}],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"event": "DAG_STATUS",
|
||||||
|
"data": {
|
||||||
|
"workflow_run_id": "r1",
|
||||||
|
"tasks": [
|
||||||
|
{"name": "get_recording", "status": "completed"},
|
||||||
|
{"name": "process_tracks", "status": "running"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
mock_row = {"id": "t1", "events": events}
|
||||||
|
|
||||||
|
with patch("reflector.db.search.get_database") as mock_db:
|
||||||
|
mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row])
|
||||||
|
result = await _fetch_dag_statuses(["t1"])
|
||||||
|
|
||||||
|
assert "t1" in result
|
||||||
|
assert len(result["t1"]) == 2 # Last DAG_STATUS had 2 tasks
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_skips_transcripts_without_events(self):
|
||||||
|
mock_row = {"id": "t1", "events": None}
|
||||||
|
|
||||||
|
with patch("reflector.db.search.get_database") as mock_db:
|
||||||
|
mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row])
|
||||||
|
result = await _fetch_dag_statuses(["t1"])
|
||||||
|
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_skips_transcripts_without_dag_status(self):
|
||||||
|
events = [
|
||||||
|
{"event": "STATUS", "data": {"value": "processing"}},
|
||||||
|
{"event": "DURATION", "data": {"duration": 1000}},
|
||||||
|
]
|
||||||
|
mock_row = {"id": "t1", "events": events}
|
||||||
|
|
||||||
|
with patch("reflector.db.search.get_database") as mock_db:
|
||||||
|
mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row])
|
||||||
|
result = await _fetch_dag_statuses(["t1"])
|
||||||
|
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handles_json_string_events(self):
|
||||||
|
"""Events stored as JSON string rather than already-parsed list."""
|
||||||
|
import json
|
||||||
|
|
||||||
|
events = [
|
||||||
|
{
|
||||||
|
"event": "DAG_STATUS",
|
||||||
|
"data": {
|
||||||
|
"workflow_run_id": "r1",
|
||||||
|
"tasks": [{"name": "transcribe", "status": "running"}],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
mock_row = {"id": "t1", "events": json.dumps(events)}
|
||||||
|
|
||||||
|
with patch("reflector.db.search.get_database") as mock_db:
|
||||||
|
mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row])
|
||||||
|
result = await _fetch_dag_statuses(["t1"])
|
||||||
|
|
||||||
|
assert "t1" in result
|
||||||
|
assert len(result["t1"]) == 1
|
||||||
|
assert result["t1"][0]["name"] == "transcribe"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multiple_transcripts(self):
|
||||||
|
"""Handles multiple transcripts in one call."""
|
||||||
|
events_t1 = [
|
||||||
|
{
|
||||||
|
"event": "DAG_STATUS",
|
||||||
|
"data": {
|
||||||
|
"workflow_run_id": "r1",
|
||||||
|
"tasks": [{"name": "a", "status": "completed"}],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
events_t2 = [
|
||||||
|
{
|
||||||
|
"event": "DAG_STATUS",
|
||||||
|
"data": {
|
||||||
|
"workflow_run_id": "r2",
|
||||||
|
"tasks": [{"name": "b", "status": "running"}],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
mock_rows = [
|
||||||
|
{"id": "t1", "events": events_t1},
|
||||||
|
{"id": "t2", "events": events_t2},
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch("reflector.db.search.get_database") as mock_db:
|
||||||
|
mock_db.return_value.fetch_all = AsyncMock(return_value=mock_rows)
|
||||||
|
result = await _fetch_dag_statuses(["t1", "t2"])
|
||||||
|
|
||||||
|
assert "t1" in result
|
||||||
|
assert "t2" in result
|
||||||
|
assert result["t1"][0]["name"] == "a"
|
||||||
|
assert result["t2"][0]["name"] == "b"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dag_status_without_tasks_key_skipped(self):
|
||||||
|
"""DAG_STATUS event with no tasks key in data should be skipped."""
|
||||||
|
events = [
|
||||||
|
{"event": "DAG_STATUS", "data": {"workflow_run_id": "r1"}},
|
||||||
|
]
|
||||||
|
mock_row = {"id": "t1", "events": events}
|
||||||
|
|
||||||
|
with patch("reflector.db.search.get_database") as mock_db:
|
||||||
|
mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row])
|
||||||
|
result = await _fetch_dag_statuses(["t1"])
|
||||||
|
|
||||||
|
assert result == {}
|
||||||
Reference in New Issue
Block a user