Files
reflector/server/tests/test_hatchet_file_pipeline.py
Juan Diego García 37a1f01850 feat: migrate file and live post-processing pipelines from Celery to Hatchet workflow engine (#911)
* feat: migrate file and live post-processing pipelines from Celery to Hatchet workflow engine

* fix: always force reprocessing

* fix: ci tests with live pipelines

* fix: ci tests with live pipelines
2026-03-16 16:07:16 -05:00

234 lines
7.8 KiB
Python

"""
Tests for the FilePipeline Hatchet workflow.

Tests verify:
1. with_error_handling behavior for file pipeline input model
2. on_workflow_failure logic (don't overwrite 'ended' status)
3. Input model validation
4. Task result models (DiarizeResult, TranscribeResult) round-trip their payloads
"""
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from hatchet_sdk import NonRetryableException
@asynccontextmanager
async def _noop_db_context():
"""Async context manager that yields without touching the DB."""
yield None
@pytest.fixture(scope="module")
def file_pipeline_module():
    """Provide the ``file_pipeline`` module, imported while the Hatchet client is stubbed.

    ``HatchetClientManager.get_client`` is patched for the duration of the
    import so that it returns a ``MagicMock`` instead of a real client.
    """
    fake_client = MagicMock()
    fake_client.workflow.return_value = MagicMock()

    get_client_patch = patch(
        "reflector.hatchet.client.HatchetClientManager.get_client",
        return_value=fake_client,
    )
    with get_client_patch:
        from reflector.hatchet.workflows import file_pipeline

        return file_pipeline
@pytest.fixture
def mock_file_input():
    """Build a ``FilePipelineInput`` carrying a fixed transcript id and room id."""
    from reflector.hatchet.workflows.file_pipeline import FilePipelineInput

    fields = {"transcript_id": "ts-file-123", "room_id": "room-456"}
    return FilePipelineInput(**fields)
@pytest.fixture
def mock_ctx():
    """Return a ``MagicMock`` standing in for the task context, with a mocked ``log``."""
    return MagicMock(log=MagicMock())
def test_file_pipeline_input_model():
    """FilePipelineInput keeps the ids it is given and defaults ``room_id`` to ``None``."""
    from reflector.hatchet.workflows.file_pipeline import FilePipelineInput

    # room_id supplied explicitly
    with_room = FilePipelineInput(transcript_id="ts-123", room_id="room-456")
    assert with_room.transcript_id == "ts-123"
    assert with_room.room_id == "room-456"

    # room_id omitted
    without_room = FilePipelineInput(transcript_id="ts-123")
    assert without_room.room_id is None
@pytest.mark.asyncio
async def test_file_pipeline_error_handling_transient(
    file_pipeline_module, mock_file_input, mock_ctx
):
    """A transient failure propagates unchanged and the error status is not set."""
    from reflector.hatchet.workflows.daily_multitrack_pipeline import (
        TaskName,
        with_error_handling,
    )

    async def failing_task(input, ctx):
        raise httpx.TimeoutException("timed out")

    decorate = with_error_handling(TaskName.EXTRACT_AUDIO)
    wrapped_task = decorate(failing_task)

    status_patch = patch(
        "reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
        new_callable=AsyncMock,
    )
    with status_patch as mock_set_error:
        # The original timeout must surface as-is, not as a wrapped exception.
        with pytest.raises(httpx.TimeoutException):
            await wrapped_task(mock_file_input, mock_ctx)

        mock_set_error.assert_not_called()
@pytest.mark.asyncio
async def test_file_pipeline_error_handling_hard_fail(
    file_pipeline_module, mock_file_input, mock_ctx
):
    """A ValueError sets the error status and is re-raised as NonRetryableException."""
    from reflector.hatchet.workflows.daily_multitrack_pipeline import (
        TaskName,
        with_error_handling,
    )

    async def failing_task(input, ctx):
        raise ValueError("No audio file found")

    decorate = with_error_handling(TaskName.EXTRACT_AUDIO)
    wrapped_task = decorate(failing_task)

    status_patch = patch(
        "reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
        new_callable=AsyncMock,
    )
    with status_patch as mock_set_error:
        with pytest.raises(NonRetryableException) as exc_info:
            await wrapped_task(mock_file_input, mock_ctx)

        # Checked after the raises block: the original message must be preserved.
        assert "No audio file found" in str(exc_info.value)
        mock_set_error.assert_called_once_with("ts-file-123")
def test_diarize_result_uses_plain_dicts():
    """DiarizeResult must carry diarization segments as plain dicts.

    DiarizationSegment is a TypedDict (plain dict), not a Pydantic model, so
    the diarize task must serialize segments as plain dicts (not call
    .model_dump()), and assemble_transcript must be able to rebuild them with
    DiarizationSegment(**s).

    This was a real bug: 'dict' object has no attribute 'model_dump'.
    """
    from reflector.hatchet.workflows.file_pipeline import DiarizeResult
    from reflector.processors.types import DiarizationSegment

    # TypedDict instances are ordinary dicts at runtime.
    first_segment = DiarizationSegment(start=0.0, end=1.5, speaker=0)
    second_segment = DiarizationSegment(start=1.5, end=3.0, speaker=1)
    segments = [first_segment, second_segment]
    assert isinstance(segments[0], dict), "DiarizationSegment should be a plain dict"

    # The result model takes the list of dicts as-is; no model_dump step.
    result = DiarizeResult(diarization=segments)
    assert result.diarization is not None
    assert len(result.diarization) == 2

    # Rebuild the segments the way the consumer (assemble_transcript) does.
    rebuilt = [DiarizationSegment(**segment) for segment in result.diarization]
    head, tail = rebuilt[0], rebuilt[1]
    assert head["start"] == 0.0
    assert head["speaker"] == 0
    assert tail["end"] == 3.0
    assert tail["speaker"] == 1
def test_diarize_result_handles_none():
    """DiarizeResult accepts ``None`` and defaults to ``None`` (diarization disabled)."""
    from reflector.hatchet.workflows.file_pipeline import DiarizeResult

    explicit_none = DiarizeResult(diarization=None)
    assert explicit_none.diarization is None

    defaulted = DiarizeResult()
    assert defaulted.diarization is None
def test_transcribe_result_words_are_pydantic():
    """Words dumped via Word.model_dump() survive a trip through TranscribeResult."""
    from reflector.hatchet.workflows.file_pipeline import TranscribeResult
    from reflector.processors.types import Word

    hello = Word(text="hello", start=0.0, end=0.5)
    world = Word(text="world", start=0.5, end=1.0)

    # Word is a Pydantic model, so model_dump() is available here.
    dumped_words = [word.model_dump() for word in (hello, world)]
    result = TranscribeResult(words=dumped_words)

    # Rebuild the words the way a consumer does: Word(**w).
    rebuilt = [Word(**payload) for payload in result.words]
    assert rebuilt[0].text == "hello"
    assert rebuilt[1].start == 0.5
@pytest.mark.asyncio
async def test_file_pipeline_on_failure_sets_error_status(
    file_pipeline_module, mock_file_input, mock_ctx
):
    """on_workflow_failure records the error when the transcript is 'processing'."""
    from reflector.hatchet.workflows.file_pipeline import on_workflow_failure

    processing_transcript = MagicMock()
    processing_transcript.status = "processing"

    db_patch = patch(
        "reflector.hatchet.workflows.file_pipeline.fresh_db_connection",
        _noop_db_context,
    )
    lookup_patch = patch(
        "reflector.db.transcripts.transcripts_controller.get_by_id",
        new_callable=AsyncMock,
        return_value=processing_transcript,
    )
    status_patch = patch(
        "reflector.hatchet.workflows.file_pipeline.set_workflow_error_status",
        new_callable=AsyncMock,
    )

    # Patches are entered left to right and exited in reverse order.
    with db_patch, lookup_patch, status_patch as mock_set_error:
        await on_workflow_failure(mock_file_input, mock_ctx)
        mock_set_error.assert_called_once_with(mock_file_input.transcript_id)
@pytest.mark.asyncio
async def test_file_pipeline_on_failure_does_not_overwrite_ended(
    file_pipeline_module, mock_file_input, mock_ctx
):
    """on_workflow_failure leaves a transcript whose status is 'ended' alone."""
    from reflector.hatchet.workflows.file_pipeline import on_workflow_failure

    ended_transcript = MagicMock()
    ended_transcript.status = "ended"

    db_patch = patch(
        "reflector.hatchet.workflows.file_pipeline.fresh_db_connection",
        _noop_db_context,
    )
    lookup_patch = patch(
        "reflector.db.transcripts.transcripts_controller.get_by_id",
        new_callable=AsyncMock,
        return_value=ended_transcript,
    )
    status_patch = patch(
        "reflector.hatchet.workflows.file_pipeline.set_workflow_error_status",
        new_callable=AsyncMock,
    )

    # Patches are entered left to right and exited in reverse order.
    with db_patch, lookup_patch, status_patch as mock_set_error:
        await on_workflow_failure(mock_file_input, mock_ctx)
        mock_set_error.assert_not_called()