mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-03-21 22:56:47 +00:00
* feat: migrate file and live post-processing pipelines from Celery to Hatchet workflow engine * fix: always force reprocessing * fix: ci tests with live pipelines * fix: ci tests with live pipelines
234 lines
7.8 KiB
Python
234 lines
7.8 KiB
Python
"""
Tests for the FilePipeline Hatchet workflow.

Tests verify:
1. with_error_handling behavior when wrapping tasks that receive a FilePipelineInput
2. on_workflow_failure logic (don't overwrite 'ended' status)
3. Input model validation
4. Task result models (DiarizeResult, TranscribeResult) round-trip through plain dicts
"""
|
|
|
|
from contextlib import asynccontextmanager
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import httpx
|
|
import pytest
|
|
from hatchet_sdk import NonRetryableException
|
|
|
|
|
|
@asynccontextmanager
|
|
async def _noop_db_context():
|
|
"""Async context manager that yields without touching the DB."""
|
|
yield None
|
|
|
|
|
|
@pytest.fixture(scope="module")
def file_pipeline_module():
    """Import ``reflector.hatchet.workflows.file_pipeline`` with a fake Hatchet client.

    ``HatchetClientManager.get_client`` is patched only while the import runs.
    Anything the module builds from the client at import time therefore comes
    from the fake. If the module was already imported elsewhere, Python returns
    the cached module and the patch has no effect.
    """
    fake_client = MagicMock()
    fake_client.workflow.return_value = MagicMock()

    client_patch = patch(
        "reflector.hatchet.client.HatchetClientManager.get_client",
        return_value=fake_client,
    )
    with client_patch:
        from reflector.hatchet.workflows import file_pipeline as loaded

    return loaded
|
|
|
|
|
|
@pytest.fixture
def mock_file_input(file_pipeline_module):
    """Minimal FilePipelineInput for tests.

    ``file_pipeline_module`` is requested only for its side effect. The workflow
    module is loaded through that fixture, which mocks the Hatchet client during
    import, rather than through a bare import whose outcome depends on test
    ordering.
    """
    from reflector.hatchet.workflows.file_pipeline import FilePipelineInput

    return FilePipelineInput(
        transcript_id="ts-file-123",
        room_id="room-456",
    )
|
|
|
|
|
|
@pytest.fixture
def mock_ctx():
    """Minimal stand-in for the context object handed to workflow tasks.

    A MagicMock whose ``log`` attribute is itself a MagicMock.
    """
    return MagicMock(log=MagicMock())
|
|
|
|
|
|
@pytest.mark.usefixtures("file_pipeline_module")
def test_file_pipeline_input_model():
    """Test FilePipelineInput validation.

    Requests ``file_pipeline_module`` so the workflow module is imported with
    the Hatchet client mocked. Without it, this test (first in file order)
    would perform the import unmocked.
    """
    from reflector.hatchet.workflows.file_pipeline import FilePipelineInput

    # Valid input with room_id
    input_with_room = FilePipelineInput(transcript_id="ts-123", room_id="room-456")
    assert input_with_room.transcript_id == "ts-123"
    assert input_with_room.room_id == "room-456"

    # Valid input without room_id: room_id defaults to None
    input_no_room = FilePipelineInput(transcript_id="ts-123")
    assert input_no_room.room_id is None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_file_pipeline_error_handling_transient(
    file_pipeline_module, mock_file_input, mock_ctx
):
    """Transient exception must NOT set error status.

    An httpx timeout raised by the wrapped task must surface as
    ``httpx.TimeoutException``, and the error-status setter must stay untouched.
    """
    from reflector.hatchet.workflows.daily_multitrack_pipeline import (
        TaskName,
        with_error_handling,
    )

    async def times_out(input, ctx):
        raise httpx.TimeoutException("timed out")

    guarded = with_error_handling(TaskName.EXTRACT_AUDIO)(times_out)

    error_status_patch = patch(
        "reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
        new_callable=AsyncMock,
    )
    with error_status_patch as mock_set_error:
        with pytest.raises(httpx.TimeoutException):
            await guarded(mock_file_input, mock_ctx)

    mock_set_error.assert_not_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_file_pipeline_error_handling_hard_fail(
    file_pipeline_module, mock_file_input, mock_ctx
):
    """Hard-fail (ValueError) must set error status and raise NonRetryableException.

    The error status is expected for the transcript carried by the task input.
    The assertion therefore reads the ID from ``mock_file_input`` instead of
    repeating the literal, matching the on_workflow_failure tests.
    """
    from reflector.hatchet.workflows.daily_multitrack_pipeline import (
        TaskName,
        with_error_handling,
    )

    async def failing_task(input, ctx):
        raise ValueError("No audio file found")

    wrapped = with_error_handling(TaskName.EXTRACT_AUDIO)(failing_task)

    with patch(
        "reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
        new_callable=AsyncMock,
    ) as mock_set_error:
        with pytest.raises(NonRetryableException) as exc_info:
            await wrapped(mock_file_input, mock_ctx)

    # The original error message must still be visible after the conversion.
    assert "No audio file found" in str(exc_info.value)
    mock_set_error.assert_called_once_with(mock_file_input.transcript_id)
|
|
|
|
|
|
@pytest.mark.usefixtures("file_pipeline_module")
def test_diarize_result_uses_plain_dicts():
    """DiarizationSegment is a TypedDict (plain dict), not a Pydantic model.

    The diarize task must serialize segments as plain dicts (not call .model_dump()),
    and assemble_transcript must be able to reconstruct them with DiarizationSegment(**s).
    This was a real bug: 'dict' object has no attribute 'model_dump'.

    Requests ``file_pipeline_module`` so the workflow module is imported with
    the Hatchet client mocked, independent of test order.
    """
    from reflector.hatchet.workflows.file_pipeline import DiarizeResult
    from reflector.processors.types import DiarizationSegment

    # DiarizationSegment is a TypedDict, so instances are plain dicts
    segments = [
        DiarizationSegment(start=0.0, end=1.5, speaker=0),
        DiarizationSegment(start=1.5, end=3.0, speaker=1),
    ]
    assert isinstance(segments[0], dict), "DiarizationSegment should be a plain dict"

    # DiarizeResult should accept list[dict] directly (no model_dump needed)
    result = DiarizeResult(diarization=segments)
    assert result.diarization is not None
    assert len(result.diarization) == 2

    # Consumer (assemble_transcript) reconstructs via DiarizationSegment(**s)
    reconstructed = [DiarizationSegment(**s) for s in result.diarization]
    assert reconstructed[0]["start"] == 0.0
    assert reconstructed[0]["speaker"] == 0
    assert reconstructed[1]["end"] == 3.0
    assert reconstructed[1]["speaker"] == 1
|
|
|
|
|
|
@pytest.mark.usefixtures("file_pipeline_module")
def test_diarize_result_handles_none():
    """DiarizeResult with no diarization data (diarization disabled).

    Requests ``file_pipeline_module`` so the workflow module is imported with
    the Hatchet client mocked, independent of test order.
    """
    from reflector.hatchet.workflows.file_pipeline import DiarizeResult

    # Explicit None
    result = DiarizeResult(diarization=None)
    assert result.diarization is None

    # Field omitted: the default is also None
    result_default = DiarizeResult()
    assert result_default.diarization is None
|
|
|
|
|
|
@pytest.mark.usefixtures("file_pipeline_module")
def test_transcribe_result_words_are_pydantic():
    """TranscribeResult words come from Pydantic Word.model_dump() — verify roundtrip.

    Requests ``file_pipeline_module`` so the workflow module is imported with
    the Hatchet client mocked, independent of test order.
    """
    from reflector.hatchet.workflows.file_pipeline import TranscribeResult
    from reflector.processors.types import Word

    words = [
        Word(text="hello", start=0.0, end=0.5),
        Word(text="world", start=0.5, end=1.0),
    ]
    # Words are Pydantic models, so model_dump() works
    word_dicts = [w.model_dump() for w in words]
    result = TranscribeResult(words=word_dicts)

    # Consumer reconstructs via Word(**w)
    reconstructed = [Word(**w) for w in result.words]
    assert reconstructed[0].text == "hello"
    assert reconstructed[1].start == 0.5
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_file_pipeline_on_failure_sets_error_status(
    file_pipeline_module, mock_file_input, mock_ctx
):
    """on_workflow_failure sets error status when transcript is processing.

    The DB connection is replaced by a no-op, the transcript lookup returns an
    object whose status is "processing", and the error-status setter is mocked.
    """
    from reflector.hatchet.workflows.file_pipeline import on_workflow_failure

    processing = MagicMock()
    processing.status = "processing"

    db_patch = patch(
        "reflector.hatchet.workflows.file_pipeline.fresh_db_connection",
        _noop_db_context,
    )
    lookup_patch = patch(
        "reflector.db.transcripts.transcripts_controller.get_by_id",
        new_callable=AsyncMock,
        return_value=processing,
    )
    error_patch = patch(
        "reflector.hatchet.workflows.file_pipeline.set_workflow_error_status",
        new_callable=AsyncMock,
    )

    # Entered in the same order as nested `with` blocks would be.
    with db_patch, lookup_patch, error_patch as mock_set_error:
        await on_workflow_failure(mock_file_input, mock_ctx)

    mock_set_error.assert_called_once_with(mock_file_input.transcript_id)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_file_pipeline_on_failure_does_not_overwrite_ended(
    file_pipeline_module, mock_file_input, mock_ctx
):
    """on_workflow_failure must NOT overwrite 'ended' status.

    The test also checks that the transcript lookup was actually awaited. That
    way it cannot pass vacuously if on_workflow_failure skipped setting the
    error status without ever reading the transcript's status.
    """
    from reflector.hatchet.workflows.file_pipeline import on_workflow_failure

    transcript_ended = MagicMock()
    transcript_ended.status = "ended"

    with patch(
        "reflector.hatchet.workflows.file_pipeline.fresh_db_connection",
        _noop_db_context,
    ):
        with patch(
            "reflector.db.transcripts.transcripts_controller.get_by_id",
            new_callable=AsyncMock,
            return_value=transcript_ended,
        ) as mock_get_by_id:
            with patch(
                "reflector.hatchet.workflows.file_pipeline.set_workflow_error_status",
                new_callable=AsyncMock,
            ) as mock_set_error:
                await on_workflow_failure(mock_file_input, mock_ctx)

    # The 'ended' status must have been read for the guard to be exercised.
    mock_get_by_id.assert_awaited()
    mock_set_error.assert_not_called()
|