mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-04-24 22:25:19 +00:00
feat: migrate file and live post-processing pipelines from Celery to Hatchet workflow engine (#911)
* feat: migrate file and live post-processing pipelines from Celery to Hatchet workflow engine * fix: always force reprocessing * fix: ci tests with live pipelines * fix: ci tests with live pipelines
This commit is contained in:
committed by
GitHub
parent
72dca7cacc
commit
37a1f01850
233
server/tests/test_hatchet_file_pipeline.py
Normal file
233
server/tests/test_hatchet_file_pipeline.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""
|
||||
Tests for the FilePipeline Hatchet workflow.
|
||||
|
||||
Tests verify:
|
||||
1. with_error_handling behavior for file pipeline input model
|
||||
2. on_workflow_failure logic (don't overwrite 'ended' status)
|
||||
3. Input model validation
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from hatchet_sdk import NonRetryableException
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _noop_db_context():
|
||||
"""Async context manager that yields without touching the DB."""
|
||||
yield None
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def file_pipeline_module():
|
||||
"""Import file_pipeline with Hatchet client mocked."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.workflow.return_value = MagicMock()
|
||||
with patch(
|
||||
"reflector.hatchet.client.HatchetClientManager.get_client",
|
||||
return_value=mock_client,
|
||||
):
|
||||
from reflector.hatchet.workflows import file_pipeline
|
||||
|
||||
return file_pipeline
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_file_input():
|
||||
"""Minimal FilePipelineInput for tests."""
|
||||
from reflector.hatchet.workflows.file_pipeline import FilePipelineInput
|
||||
|
||||
return FilePipelineInput(
|
||||
transcript_id="ts-file-123",
|
||||
room_id="room-456",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ctx():
|
||||
"""Minimal Context-like object."""
|
||||
ctx = MagicMock()
|
||||
ctx.log = MagicMock()
|
||||
return ctx
|
||||
|
||||
|
||||
def test_file_pipeline_input_model():
|
||||
"""Test FilePipelineInput validation."""
|
||||
from reflector.hatchet.workflows.file_pipeline import FilePipelineInput
|
||||
|
||||
# Valid input with room_id
|
||||
input_with_room = FilePipelineInput(transcript_id="ts-123", room_id="room-456")
|
||||
assert input_with_room.transcript_id == "ts-123"
|
||||
assert input_with_room.room_id == "room-456"
|
||||
|
||||
# Valid input without room_id
|
||||
input_no_room = FilePipelineInput(transcript_id="ts-123")
|
||||
assert input_no_room.room_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_pipeline_error_handling_transient(
|
||||
file_pipeline_module, mock_file_input, mock_ctx
|
||||
):
|
||||
"""Transient exception must NOT set error status."""
|
||||
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||
TaskName,
|
||||
with_error_handling,
|
||||
)
|
||||
|
||||
async def failing_task(input, ctx):
|
||||
raise httpx.TimeoutException("timed out")
|
||||
|
||||
wrapped = with_error_handling(TaskName.EXTRACT_AUDIO)(failing_task)
|
||||
|
||||
with patch(
|
||||
"reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_set_error:
|
||||
with pytest.raises(httpx.TimeoutException):
|
||||
await wrapped(mock_file_input, mock_ctx)
|
||||
|
||||
mock_set_error.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_pipeline_error_handling_hard_fail(
|
||||
file_pipeline_module, mock_file_input, mock_ctx
|
||||
):
|
||||
"""Hard-fail (ValueError) must set error status and raise NonRetryableException."""
|
||||
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
||||
TaskName,
|
||||
with_error_handling,
|
||||
)
|
||||
|
||||
async def failing_task(input, ctx):
|
||||
raise ValueError("No audio file found")
|
||||
|
||||
wrapped = with_error_handling(TaskName.EXTRACT_AUDIO)(failing_task)
|
||||
|
||||
with patch(
|
||||
"reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_set_error:
|
||||
with pytest.raises(NonRetryableException) as exc_info:
|
||||
await wrapped(mock_file_input, mock_ctx)
|
||||
|
||||
assert "No audio file found" in str(exc_info.value)
|
||||
mock_set_error.assert_called_once_with("ts-file-123")
|
||||
|
||||
|
||||
def test_diarize_result_uses_plain_dicts():
|
||||
"""DiarizationSegment is a TypedDict (plain dict), not a Pydantic model.
|
||||
|
||||
The diarize task must serialize segments as plain dicts (not call .model_dump()),
|
||||
and assemble_transcript must be able to reconstruct them with DiarizationSegment(**s).
|
||||
This was a real bug: 'dict' object has no attribute 'model_dump'.
|
||||
"""
|
||||
from reflector.hatchet.workflows.file_pipeline import DiarizeResult
|
||||
from reflector.processors.types import DiarizationSegment
|
||||
|
||||
# DiarizationSegment is a TypedDict — instances are plain dicts
|
||||
segments = [
|
||||
DiarizationSegment(start=0.0, end=1.5, speaker=0),
|
||||
DiarizationSegment(start=1.5, end=3.0, speaker=1),
|
||||
]
|
||||
assert isinstance(segments[0], dict), "DiarizationSegment should be a plain dict"
|
||||
|
||||
# DiarizeResult should accept list[dict] directly (no model_dump needed)
|
||||
result = DiarizeResult(diarization=segments)
|
||||
assert result.diarization is not None
|
||||
assert len(result.diarization) == 2
|
||||
|
||||
# Consumer (assemble_transcript) reconstructs via DiarizationSegment(**s)
|
||||
reconstructed = [DiarizationSegment(**s) for s in result.diarization]
|
||||
assert reconstructed[0]["start"] == 0.0
|
||||
assert reconstructed[0]["speaker"] == 0
|
||||
assert reconstructed[1]["end"] == 3.0
|
||||
assert reconstructed[1]["speaker"] == 1
|
||||
|
||||
|
||||
def test_diarize_result_handles_none():
|
||||
"""DiarizeResult with no diarization data (diarization disabled)."""
|
||||
from reflector.hatchet.workflows.file_pipeline import DiarizeResult
|
||||
|
||||
result = DiarizeResult(diarization=None)
|
||||
assert result.diarization is None
|
||||
|
||||
result_default = DiarizeResult()
|
||||
assert result_default.diarization is None
|
||||
|
||||
|
||||
def test_transcribe_result_words_are_pydantic():
|
||||
"""TranscribeResult words come from Pydantic Word.model_dump() — verify roundtrip."""
|
||||
from reflector.hatchet.workflows.file_pipeline import TranscribeResult
|
||||
from reflector.processors.types import Word
|
||||
|
||||
words = [
|
||||
Word(text="hello", start=0.0, end=0.5),
|
||||
Word(text="world", start=0.5, end=1.0),
|
||||
]
|
||||
# Words are Pydantic models, so model_dump() works
|
||||
word_dicts = [w.model_dump() for w in words]
|
||||
result = TranscribeResult(words=word_dicts)
|
||||
|
||||
# Consumer reconstructs via Word(**w)
|
||||
reconstructed = [Word(**w) for w in result.words]
|
||||
assert reconstructed[0].text == "hello"
|
||||
assert reconstructed[1].start == 0.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_pipeline_on_failure_sets_error_status(
|
||||
file_pipeline_module, mock_file_input, mock_ctx
|
||||
):
|
||||
"""on_workflow_failure sets error status when transcript is processing."""
|
||||
from reflector.hatchet.workflows.file_pipeline import on_workflow_failure
|
||||
|
||||
transcript_processing = MagicMock()
|
||||
transcript_processing.status = "processing"
|
||||
|
||||
with patch(
|
||||
"reflector.hatchet.workflows.file_pipeline.fresh_db_connection",
|
||||
_noop_db_context,
|
||||
):
|
||||
with patch(
|
||||
"reflector.db.transcripts.transcripts_controller.get_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=transcript_processing,
|
||||
):
|
||||
with patch(
|
||||
"reflector.hatchet.workflows.file_pipeline.set_workflow_error_status",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_set_error:
|
||||
await on_workflow_failure(mock_file_input, mock_ctx)
|
||||
mock_set_error.assert_called_once_with(mock_file_input.transcript_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_pipeline_on_failure_does_not_overwrite_ended(
|
||||
file_pipeline_module, mock_file_input, mock_ctx
|
||||
):
|
||||
"""on_workflow_failure must NOT overwrite 'ended' status."""
|
||||
from reflector.hatchet.workflows.file_pipeline import on_workflow_failure
|
||||
|
||||
transcript_ended = MagicMock()
|
||||
transcript_ended.status = "ended"
|
||||
|
||||
with patch(
|
||||
"reflector.hatchet.workflows.file_pipeline.fresh_db_connection",
|
||||
_noop_db_context,
|
||||
):
|
||||
with patch(
|
||||
"reflector.db.transcripts.transcripts_controller.get_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=transcript_ended,
|
||||
):
|
||||
with patch(
|
||||
"reflector.hatchet.workflows.file_pipeline.set_workflow_error_status",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_set_error:
|
||||
await on_workflow_failure(mock_file_input, mock_ctx)
|
||||
mock_set_error.assert_not_called()
|
||||
Reference in New Issue
Block a user