feat: migrate file and live post-processing pipelines from Celery to Hatchet workflow engine (#911)

* feat: migrate file and live post-processing pipelines from Celery to Hatchet workflow engine

* fix: always force reprocessing

* fix: ci tests with live pipelines

* fix: ci tests with live pipelines
Juan Diego García
2026-03-16 16:07:16 -05:00
committed by GitHub
parent 72dca7cacc
commit 37a1f01850
22 changed files with 2140 additions and 353 deletions
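
At its core, the migration swaps Celery's fire-and-forget task dispatch for an async Hatchet workflow trigger. A minimal before/after sketch, with names taken from the test assertions in this diff (real call sites may pass more):

    # Before (Celery): enqueue the file pipeline task
    task_pipeline_file_process.delay(transcript_id=transcript_id)

    # After (Hatchet): trigger the FilePipeline workflow by name
    await HatchetClientManager.start_workflow(
        "FilePipeline",
        {"transcript_id": transcript_id},
        additional_metadata={"transcript_id": transcript_id},
    )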

View File

@@ -1,6 +1,6 @@
 import os
 from contextlib import asynccontextmanager
-from unittest.mock import patch
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
@@ -538,18 +538,59 @@ def fake_mp3_upload():
 @pytest.fixture(autouse=True)
-def reset_hatchet_client():
-    """Reset HatchetClientManager singleton before and after each test.
-
-    This ensures test isolation - each test starts with a fresh client state.
-    The fixture is autouse=True so it applies to all tests automatically.
-    """
+def mock_hatchet_client():
+    """Mock HatchetClientManager for all tests.
+
+    Prevents tests from connecting to a real Hatchet server. The dummy token
+    in [tool.pytest_env] prevents the import-time ValueError, but the SDK
+    would still try to connect when get_client() is called. This fixture
+    mocks get_client to return a MagicMock and start_workflow to return a
+    dummy workflow ID.
+    """
     from reflector.hatchet.client import HatchetClientManager
 
-    # Reset before test
-    HatchetClientManager.reset()
-    yield
-    # Reset after test to clean up
+    mock_client = MagicMock()
+    mock_client.workflow.return_value = MagicMock()
+    with (
+        patch.object(
+            HatchetClientManager,
+            "get_client",
+            return_value=mock_client,
+        ),
+        patch.object(
+            HatchetClientManager,
+            "start_workflow",
+            new_callable=AsyncMock,
+            return_value="mock-workflow-id",
+        ),
+        patch.object(
+            HatchetClientManager,
+            "get_workflow_run_status",
+            new_callable=AsyncMock,
+            return_value=None,
+        ),
+        patch.object(
+            HatchetClientManager,
+            "can_replay",
+            new_callable=AsyncMock,
+            return_value=False,
+        ),
+        patch.object(
+            HatchetClientManager,
+            "cancel_workflow",
+            new_callable=AsyncMock,
+        ),
+        patch.object(
+            HatchetClientManager,
+            "replay_workflow",
+            new_callable=AsyncMock,
+        ),
+    ):
+        yield mock_client
     HatchetClientManager.reset()
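
Because the fixture is autouse, every test runs under these patches automatically; a test only names the fixture when it wants the mock handle or needs to assert on dispatch. A sketch of the consumption pattern (the test name here is illustrative; HatchetClientManager is assumed imported from reflector.hatchet.client):

    async def test_upload_dispatches_workflow(client, mock_hatchet_client):
        # ...hit an endpoint that should trigger a workflow...
        # start_workflow is patched as an AsyncMock on the class itself
        HatchetClientManager.start_workflow.assert_called_once()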

View File

@@ -37,18 +37,3 @@ async def test_hatchet_client_can_replay_handles_exception():
     # Should return False on error (workflow might be gone)
     assert can_replay is False
-
-
-def test_hatchet_client_raises_without_token():
-    """Test that get_client raises ValueError without token.
-
-    Useful: Catches if someone removes the token validation,
-    which would cause cryptic errors later.
-    """
-    from reflector.hatchet.client import HatchetClientManager
-
-    with patch("reflector.hatchet.client.settings") as mock_settings:
-        mock_settings.HATCHET_CLIENT_TOKEN = None
-        with pytest.raises(ValueError, match="HATCHET_CLIENT_TOKEN must be set"):
-            HatchetClientManager.get_client()

View File

@@ -0,0 +1,233 @@
"""
Tests for the FilePipeline Hatchet workflow.
Tests verify:
1. with_error_handling behavior for file pipeline input model
2. on_workflow_failure logic (don't overwrite 'ended' status)
3. Input model validation
"""
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from hatchet_sdk import NonRetryableException
@asynccontextmanager
async def _noop_db_context():
"""Async context manager that yields without touching the DB."""
yield None
@pytest.fixture(scope="module")
def file_pipeline_module():
"""Import file_pipeline with Hatchet client mocked."""
mock_client = MagicMock()
mock_client.workflow.return_value = MagicMock()
with patch(
"reflector.hatchet.client.HatchetClientManager.get_client",
return_value=mock_client,
):
from reflector.hatchet.workflows import file_pipeline
return file_pipeline
@pytest.fixture
def mock_file_input():
"""Minimal FilePipelineInput for tests."""
from reflector.hatchet.workflows.file_pipeline import FilePipelineInput
return FilePipelineInput(
transcript_id="ts-file-123",
room_id="room-456",
)
@pytest.fixture
def mock_ctx():
"""Minimal Context-like object."""
ctx = MagicMock()
ctx.log = MagicMock()
return ctx
def test_file_pipeline_input_model():
"""Test FilePipelineInput validation."""
from reflector.hatchet.workflows.file_pipeline import FilePipelineInput
# Valid input with room_id
input_with_room = FilePipelineInput(transcript_id="ts-123", room_id="room-456")
assert input_with_room.transcript_id == "ts-123"
assert input_with_room.room_id == "room-456"
# Valid input without room_id
input_no_room = FilePipelineInput(transcript_id="ts-123")
assert input_no_room.room_id is None
@pytest.mark.asyncio
async def test_file_pipeline_error_handling_transient(
file_pipeline_module, mock_file_input, mock_ctx
):
"""Transient exception must NOT set error status."""
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
TaskName,
with_error_handling,
)
async def failing_task(input, ctx):
raise httpx.TimeoutException("timed out")
wrapped = with_error_handling(TaskName.EXTRACT_AUDIO)(failing_task)
with patch(
"reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
new_callable=AsyncMock,
) as mock_set_error:
with pytest.raises(httpx.TimeoutException):
await wrapped(mock_file_input, mock_ctx)
mock_set_error.assert_not_called()
@pytest.mark.asyncio
async def test_file_pipeline_error_handling_hard_fail(
file_pipeline_module, mock_file_input, mock_ctx
):
"""Hard-fail (ValueError) must set error status and raise NonRetryableException."""
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
TaskName,
with_error_handling,
)
async def failing_task(input, ctx):
raise ValueError("No audio file found")
wrapped = with_error_handling(TaskName.EXTRACT_AUDIO)(failing_task)
with patch(
"reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
new_callable=AsyncMock,
) as mock_set_error:
with pytest.raises(NonRetryableException) as exc_info:
await wrapped(mock_file_input, mock_ctx)
assert "No audio file found" in str(exc_info.value)
mock_set_error.assert_called_once_with("ts-file-123")
def test_diarize_result_uses_plain_dicts():
"""DiarizationSegment is a TypedDict (plain dict), not a Pydantic model.
The diarize task must serialize segments as plain dicts (not call .model_dump()),
and assemble_transcript must be able to reconstruct them with DiarizationSegment(**s).
This was a real bug: 'dict' object has no attribute 'model_dump'.
"""
from reflector.hatchet.workflows.file_pipeline import DiarizeResult
from reflector.processors.types import DiarizationSegment
# DiarizationSegment is a TypedDict — instances are plain dicts
segments = [
DiarizationSegment(start=0.0, end=1.5, speaker=0),
DiarizationSegment(start=1.5, end=3.0, speaker=1),
]
assert isinstance(segments[0], dict), "DiarizationSegment should be a plain dict"
# DiarizeResult should accept list[dict] directly (no model_dump needed)
result = DiarizeResult(diarization=segments)
assert result.diarization is not None
assert len(result.diarization) == 2
# Consumer (assemble_transcript) reconstructs via DiarizationSegment(**s)
reconstructed = [DiarizationSegment(**s) for s in result.diarization]
assert reconstructed[0]["start"] == 0.0
assert reconstructed[0]["speaker"] == 0
assert reconstructed[1]["end"] == 3.0
assert reconstructed[1]["speaker"] == 1
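
# For context on the bug above (a sketch, not the actual task source): the
# diarize task presumably did something like
#     segments = [s.model_dump() for s in diarization]  # AttributeError on a dict
# and the fix is to pass the TypedDict instances through as the plain dicts
# they already are:
#     segments = list(diarization)
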
def test_diarize_result_handles_none():
    """DiarizeResult with no diarization data (diarization disabled)."""
    from reflector.hatchet.workflows.file_pipeline import DiarizeResult

    result = DiarizeResult(diarization=None)
    assert result.diarization is None

    result_default = DiarizeResult()
    assert result_default.diarization is None


def test_transcribe_result_words_are_pydantic():
    """TranscribeResult words come from Pydantic Word.model_dump() — verify roundtrip."""
    from reflector.hatchet.workflows.file_pipeline import TranscribeResult
    from reflector.processors.types import Word

    words = [
        Word(text="hello", start=0.0, end=0.5),
        Word(text="world", start=0.5, end=1.0),
    ]

    # Words are Pydantic models, so model_dump() works
    word_dicts = [w.model_dump() for w in words]
    result = TranscribeResult(words=word_dicts)

    # Consumer reconstructs via Word(**w)
    reconstructed = [Word(**w) for w in result.words]
    assert reconstructed[0].text == "hello"
    assert reconstructed[1].start == 0.5


@pytest.mark.asyncio
async def test_file_pipeline_on_failure_sets_error_status(
    file_pipeline_module, mock_file_input, mock_ctx
):
    """on_workflow_failure sets error status when transcript is processing."""
    from reflector.hatchet.workflows.file_pipeline import on_workflow_failure

    transcript_processing = MagicMock()
    transcript_processing.status = "processing"

    with patch(
        "reflector.hatchet.workflows.file_pipeline.fresh_db_connection",
        _noop_db_context,
    ):
        with patch(
            "reflector.db.transcripts.transcripts_controller.get_by_id",
            new_callable=AsyncMock,
            return_value=transcript_processing,
        ):
            with patch(
                "reflector.hatchet.workflows.file_pipeline.set_workflow_error_status",
                new_callable=AsyncMock,
            ) as mock_set_error:
                await on_workflow_failure(mock_file_input, mock_ctx)
                mock_set_error.assert_called_once_with(mock_file_input.transcript_id)


@pytest.mark.asyncio
async def test_file_pipeline_on_failure_does_not_overwrite_ended(
    file_pipeline_module, mock_file_input, mock_ctx
):
    """on_workflow_failure must NOT overwrite 'ended' status."""
    from reflector.hatchet.workflows.file_pipeline import on_workflow_failure

    transcript_ended = MagicMock()
    transcript_ended.status = "ended"

    with patch(
        "reflector.hatchet.workflows.file_pipeline.fresh_db_connection",
        _noop_db_context,
    ):
        with patch(
            "reflector.db.transcripts.transcripts_controller.get_by_id",
            new_callable=AsyncMock,
            return_value=transcript_ended,
        ):
            with patch(
                "reflector.hatchet.workflows.file_pipeline.set_workflow_error_status",
                new_callable=AsyncMock,
            ) as mock_set_error:
                await on_workflow_failure(mock_file_input, mock_ctx)
                mock_set_error.assert_not_called()
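
Together, the transient and hard-fail tests pin down the decorator contract: transient network errors propagate unchanged so Hatchet can retry the task, while anything else marks the transcript as errored and converts to a non-retryable failure. A hypothetical sketch of that shape (the real decorator lives in daily_multitrack_pipeline.py and may classify more exception types as transient):

    import functools

    import httpx
    from hatchet_sdk import NonRetryableException

    # set_workflow_error_status is the project's own helper, assumed imported


    def with_error_handling(task_name):
        def decorator(task_fn):
            @functools.wraps(task_fn)
            async def wrapper(input, ctx):
                try:
                    return await task_fn(input, ctx)
                except httpx.TimeoutException:
                    # transient: re-raise unchanged so Hatchet retries the task
                    raise
                except Exception as exc:
                    # hard failure: persist the error status, then stop retries
                    await set_workflow_error_status(input.transcript_id)
                    raise NonRetryableException(str(exc)) from exc

            return wrapper

        return decorator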

View File

@@ -0,0 +1,218 @@
"""
Tests for the LivePostProcessingPipeline Hatchet workflow.
Tests verify:
1. with_error_handling behavior for live post pipeline input model
2. on_workflow_failure logic (don't overwrite 'ended' status)
3. Input model validation
4. pipeline_post() now triggers Hatchet instead of Celery chord
"""
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from hatchet_sdk import NonRetryableException
@asynccontextmanager
async def _noop_db_context():
"""Async context manager that yields without touching the DB."""
yield None
@pytest.fixture(scope="module")
def live_pipeline_module():
"""Import live_post_pipeline with Hatchet client mocked."""
mock_client = MagicMock()
mock_client.workflow.return_value = MagicMock()
with patch(
"reflector.hatchet.client.HatchetClientManager.get_client",
return_value=mock_client,
):
from reflector.hatchet.workflows import live_post_pipeline
return live_post_pipeline
@pytest.fixture
def mock_live_input():
"""Minimal LivePostPipelineInput for tests."""
from reflector.hatchet.workflows.live_post_pipeline import LivePostPipelineInput
return LivePostPipelineInput(
transcript_id="ts-live-789",
room_id="room-abc",
)
@pytest.fixture
def mock_ctx():
"""Minimal Context-like object."""
ctx = MagicMock()
ctx.log = MagicMock()
return ctx
def test_live_post_pipeline_input_model():
"""Test LivePostPipelineInput validation."""
from reflector.hatchet.workflows.live_post_pipeline import LivePostPipelineInput
# Valid input with room_id
input_with_room = LivePostPipelineInput(transcript_id="ts-123", room_id="room-456")
assert input_with_room.transcript_id == "ts-123"
assert input_with_room.room_id == "room-456"
# Valid input without room_id
input_no_room = LivePostPipelineInput(transcript_id="ts-123")
assert input_no_room.room_id is None
@pytest.mark.asyncio
async def test_live_pipeline_error_handling_transient(
live_pipeline_module, mock_live_input, mock_ctx
):
"""Transient exception must NOT set error status."""
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
TaskName,
with_error_handling,
)
async def failing_task(input, ctx):
raise httpx.TimeoutException("timed out")
wrapped = with_error_handling(TaskName.WAVEFORM)(failing_task)
with patch(
"reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
new_callable=AsyncMock,
) as mock_set_error:
with pytest.raises(httpx.TimeoutException):
await wrapped(mock_live_input, mock_ctx)
mock_set_error.assert_not_called()
@pytest.mark.asyncio
async def test_live_pipeline_error_handling_hard_fail(
live_pipeline_module, mock_live_input, mock_ctx
):
"""Hard-fail must set error status and raise NonRetryableException."""
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
TaskName,
with_error_handling,
)
async def failing_task(input, ctx):
raise ValueError("Transcript not found")
wrapped = with_error_handling(TaskName.WAVEFORM)(failing_task)
with patch(
"reflector.hatchet.workflows.daily_multitrack_pipeline.set_workflow_error_status",
new_callable=AsyncMock,
) as mock_set_error:
with pytest.raises(NonRetryableException) as exc_info:
await wrapped(mock_live_input, mock_ctx)
assert "Transcript not found" in str(exc_info.value)
mock_set_error.assert_called_once_with("ts-live-789")
@pytest.mark.asyncio
async def test_live_pipeline_on_failure_sets_error_status(
live_pipeline_module, mock_live_input, mock_ctx
):
"""on_workflow_failure sets error status when transcript is processing."""
from reflector.hatchet.workflows.live_post_pipeline import on_workflow_failure
transcript_processing = MagicMock()
transcript_processing.status = "processing"
with patch(
"reflector.hatchet.workflows.live_post_pipeline.fresh_db_connection",
_noop_db_context,
):
with patch(
"reflector.db.transcripts.transcripts_controller.get_by_id",
new_callable=AsyncMock,
return_value=transcript_processing,
):
with patch(
"reflector.hatchet.workflows.live_post_pipeline.set_workflow_error_status",
new_callable=AsyncMock,
) as mock_set_error:
await on_workflow_failure(mock_live_input, mock_ctx)
mock_set_error.assert_called_once_with(mock_live_input.transcript_id)
@pytest.mark.asyncio
async def test_live_pipeline_on_failure_does_not_overwrite_ended(
live_pipeline_module, mock_live_input, mock_ctx
):
"""on_workflow_failure must NOT overwrite 'ended' status."""
from reflector.hatchet.workflows.live_post_pipeline import on_workflow_failure
transcript_ended = MagicMock()
transcript_ended.status = "ended"
with patch(
"reflector.hatchet.workflows.live_post_pipeline.fresh_db_connection",
_noop_db_context,
):
with patch(
"reflector.db.transcripts.transcripts_controller.get_by_id",
new_callable=AsyncMock,
return_value=transcript_ended,
):
with patch(
"reflector.hatchet.workflows.live_post_pipeline.set_workflow_error_status",
new_callable=AsyncMock,
) as mock_set_error:
await on_workflow_failure(mock_live_input, mock_ctx)
mock_set_error.assert_not_called()
@pytest.mark.asyncio
async def test_pipeline_post_triggers_hatchet():
"""pipeline_post() should trigger Hatchet LivePostProcessingPipeline workflow."""
with patch(
"reflector.hatchet.client.HatchetClientManager.start_workflow",
new_callable=AsyncMock,
return_value="workflow-run-id",
) as mock_start:
from reflector.pipelines.main_live_pipeline import pipeline_post
await pipeline_post(transcript_id="ts-test-123", room_id="room-test")
mock_start.assert_called_once_with(
"LivePostProcessingPipeline",
{
"transcript_id": "ts-test-123",
"room_id": "room-test",
},
additional_metadata={"transcript_id": "ts-test-123"},
)
@pytest.mark.asyncio
async def test_pipeline_post_triggers_hatchet_without_room_id():
"""pipeline_post() should handle None room_id."""
with patch(
"reflector.hatchet.client.HatchetClientManager.start_workflow",
new_callable=AsyncMock,
return_value="workflow-run-id",
) as mock_start:
from reflector.pipelines.main_live_pipeline import pipeline_post
await pipeline_post(transcript_id="ts-test-456")
mock_start.assert_called_once_with(
"LivePostProcessingPipeline",
{
"transcript_id": "ts-test-456",
"room_id": None,
},
additional_metadata={"transcript_id": "ts-test-456"},
)
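
The last two tests fully determine the new trigger's call shape, so pipeline_post presumably reduces to something close to this (a sketch inferred from the assertions above; HatchetClientManager comes from reflector.hatchet.client, and the real function may do more, such as persisting the returned workflow run id):

    async def pipeline_post(transcript_id: str, room_id: str | None = None) -> None:
        # fire-and-forget dispatch of post-processing to Hatchet
        await HatchetClientManager.start_workflow(
            "LivePostProcessingPipeline",
            {"transcript_id": transcript_id, "room_id": room_id},
            additional_metadata={"transcript_id": transcript_id},
        )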

View File

@@ -0,0 +1,90 @@
"""
Tests verifying Celery-to-Hatchet trigger migration.
Ensures that:
1. process_recording triggers FilePipeline via Hatchet (not Celery)
2. transcript_record_upload triggers FilePipeline via Hatchet (not Celery)
3. Old Celery task references are no longer in active call sites
"""
def test_process_recording_does_not_import_celery_file_task():
"""Verify process.py no longer imports task_pipeline_file_process."""
import inspect
from reflector.worker import process
source = inspect.getsource(process)
# Should not contain the old Celery task import
assert "task_pipeline_file_process" not in source
def test_transcripts_upload_does_not_import_celery_file_task():
"""Verify transcripts_upload.py no longer imports task_pipeline_file_process."""
import inspect
from reflector.views import transcripts_upload
source = inspect.getsource(transcripts_upload)
# Should not contain the old Celery task import
assert "task_pipeline_file_process" not in source
def test_transcripts_upload_imports_hatchet():
"""Verify transcripts_upload.py imports HatchetClientManager."""
import inspect
from reflector.views import transcripts_upload
source = inspect.getsource(transcripts_upload)
assert "HatchetClientManager" in source
def test_pipeline_post_is_async():
"""Verify pipeline_post is now async (Hatchet trigger)."""
import asyncio
from reflector.pipelines.main_live_pipeline import pipeline_post
assert asyncio.iscoroutinefunction(pipeline_post)
def test_transcript_process_service_does_not_import_celery_file_task():
"""Verify transcript_process.py service no longer imports task_pipeline_file_process."""
import inspect
from reflector.services import transcript_process
source = inspect.getsource(transcript_process)
assert "task_pipeline_file_process" not in source
def test_transcript_process_service_dispatch_uses_hatchet():
"""Verify dispatch_transcript_processing uses HatchetClientManager for file processing."""
import inspect
from reflector.services import transcript_process
source = inspect.getsource(transcript_process.dispatch_transcript_processing)
assert "HatchetClientManager" in source
assert "FilePipeline" in source
def test_new_task_names_exist():
"""Verify new TaskName constants were added for file and live pipelines."""
from reflector.hatchet.constants import TaskName
# File pipeline tasks
assert TaskName.EXTRACT_AUDIO == "extract_audio"
assert TaskName.UPLOAD_AUDIO == "upload_audio"
assert TaskName.TRANSCRIBE == "transcribe"
assert TaskName.DIARIZE == "diarize"
assert TaskName.ASSEMBLE_TRANSCRIPT == "assemble_transcript"
assert TaskName.GENERATE_SUMMARIES == "generate_summaries"
# Live post-processing pipeline tasks
assert TaskName.WAVEFORM == "waveform"
assert TaskName.CONVERT_MP3 == "convert_mp3"
assert TaskName.UPLOAD_MP3 == "upload_mp3"
assert TaskName.REMOVE_UPLOAD == "remove_upload"
assert TaskName.FINAL_SUMMARIES == "final_summaries"
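
The equality assertions above imply TaskName values compare equal to plain strings, i.e. something like a str-based enum. A sketch consistent with the tests (the actual definition lives in reflector/hatchet/constants.py and may differ):

    from enum import Enum


    class TaskName(str, Enum):
        # file pipeline
        EXTRACT_AUDIO = "extract_audio"
        UPLOAD_AUDIO = "upload_audio"
        TRANSCRIBE = "transcribe"
        DIARIZE = "diarize"
        ASSEMBLE_TRANSCRIPT = "assemble_transcript"
        GENERATE_SUMMARIES = "generate_summaries"
        # live post-processing pipeline
        WAVEFORM = "waveform"
        CONVERT_MP3 = "convert_mp3"
        UPLOAD_MP3 = "upload_mp3"
        REMOVE_UPLOAD = "remove_upload"
        FINAL_SUMMARIES = "final_summaries"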

View File

@@ -1,5 +1,3 @@
-import asyncio
-import time
 from unittest.mock import AsyncMock, patch
 
 import pytest
@@ -27,8 +25,6 @@ async def client(app_lifespan):
 @pytest.mark.usefixtures("setup_database")
-@pytest.mark.usefixtures("celery_session_app")
-@pytest.mark.usefixtures("celery_session_worker")
 @pytest.mark.asyncio
 async def test_transcript_process(
     tmpdir,
@@ -39,8 +35,13 @@ async def test_transcript_process(
     dummy_storage,
     client,
     monkeypatch,
+    mock_hatchet_client,
 ):
-    # public mode: this test uses an anonymous client; allow anonymous transcript creation
+    """Test upload + process dispatch via Hatchet.
+
+    The file pipeline is now dispatched to Hatchet (fire-and-forget),
+    so we verify the workflow was triggered rather than polling for completion.
+    """
     monkeypatch.setattr(settings, "PUBLIC_MODE", True)
 
     # create a transcript
@@ -63,51 +64,43 @@ async def test_transcript_process(
     assert response.status_code == 200
     assert response.json()["status"] == "ok"
 
-    # wait for processing to finish (max 1 minute)
-    timeout_seconds = 60
-    start_time = time.monotonic()
-    while (time.monotonic() - start_time) < timeout_seconds:
-        # fetch the transcript and check if it is ended
-        resp = await client.get(f"/transcripts/{tid}")
-        assert resp.status_code == 200
-        if resp.json()["status"] in ("ended", "error"):
-            break
-        await asyncio.sleep(1)
-    else:
-        pytest.fail(f"Initial processing timed out after {timeout_seconds} seconds")
-
-    # restart the processing
-    response = await client.post(
-        f"/transcripts/{tid}/process",
-    )
-    assert response.status_code == 200
-    assert response.json()["status"] == "ok"
-    await asyncio.sleep(2)
-
-    # wait for processing to finish (max 1 minute)
-    timeout_seconds = 60
-    start_time = time.monotonic()
-    while (time.monotonic() - start_time) < timeout_seconds:
-        # fetch the transcript and check if it is ended
-        resp = await client.get(f"/transcripts/{tid}")
-        assert resp.status_code == 200
-        if resp.json()["status"] in ("ended", "error"):
-            break
-        await asyncio.sleep(1)
-    else:
-        pytest.fail(f"Restart processing timed out after {timeout_seconds} seconds")
-
-    # check the transcript is ended
-    transcript = resp.json()
-    assert transcript["status"] == "ended"
-    assert transcript["short_summary"] == "LLM SHORT SUMMARY"
-    assert transcript["title"] == "Llm Title"
-
-    # check topics and transcript
-    response = await client.get(f"/transcripts/{tid}/topics")
-    assert response.status_code == 200
-    assert len(response.json()) == 1
-    assert "Hello world. How are you today?" in response.json()[0]["transcript"]
+    # Verify Hatchet workflow was dispatched (from upload endpoint)
+    from reflector.hatchet.client import HatchetClientManager
+
+    HatchetClientManager.start_workflow.assert_called_once_with(
+        "FilePipeline",
+        {"transcript_id": tid},
+        additional_metadata={"transcript_id": tid},
+    )
+
+    # Verify transcript status was set to "uploaded"
+    resp = await client.get(f"/transcripts/{tid}")
+    assert resp.status_code == 200
+    assert resp.json()["status"] == "uploaded"
+
+    # Reset mock for reprocess test
+    HatchetClientManager.start_workflow.reset_mock()
+
+    # Clear workflow_run_id so /process endpoint can dispatch again
+    from reflector.db.transcripts import transcripts_controller
+
+    transcript = await transcripts_controller.get_by_id(tid)
+    await transcripts_controller.update(transcript, {"workflow_run_id": None})
+
+    # Reprocess via /process endpoint
+    with patch(
+        "reflector.services.transcript_process.task_is_scheduled_or_active",
+        return_value=False,
+    ):
+        response = await client.post(f"/transcripts/{tid}/process")
+        assert response.status_code == 200
+        assert response.json()["status"] == "ok"
+
+    # Verify second Hatchet dispatch (from /process endpoint)
+    HatchetClientManager.start_workflow.assert_called_once()
+    call_kwargs = HatchetClientManager.start_workflow.call_args.kwargs
+    assert call_kwargs["workflow_name"] == "FilePipeline"
+    assert call_kwargs["input_data"]["transcript_id"] == tid
 
 
 @pytest.mark.usefixtures("setup_database")
@@ -150,20 +143,25 @@ async def test_whereby_recording_uses_file_pipeline(monkeypatch, client):
     with (
         patch(
-            "reflector.services.transcript_process.task_pipeline_file_process"
-        ) as mock_file_pipeline,
+            "reflector.services.transcript_process.task_is_scheduled_or_active",
+            return_value=False,
+        ),
         patch(
             "reflector.services.transcript_process.HatchetClientManager"
         ) as mock_hatchet,
     ):
+        mock_hatchet.start_workflow = AsyncMock(return_value="test-workflow-id")
         response = await client.post(f"/transcripts/{transcript.id}/process")
         assert response.status_code == 200
         assert response.json()["status"] == "ok"
 
-        # Whereby recordings should use file pipeline, not Hatchet
-        mock_file_pipeline.delay.assert_called_once_with(transcript_id=transcript.id)
-        mock_hatchet.start_workflow.assert_not_called()
+        # Whereby recordings should use Hatchet FilePipeline
+        mock_hatchet.start_workflow.assert_called_once()
+        call_kwargs = mock_hatchet.start_workflow.call_args.kwargs
+        assert call_kwargs["workflow_name"] == "FilePipeline"
+        assert call_kwargs["input_data"]["transcript_id"] == transcript.id
@pytest.mark.usefixtures("setup_database")
@@ -224,8 +222,9 @@ async def test_dailyco_recording_uses_multitrack_pipeline(monkeypatch, client):
     with (
         patch(
-            "reflector.services.transcript_process.task_pipeline_file_process"
-        ) as mock_file_pipeline,
+            "reflector.services.transcript_process.task_is_scheduled_or_active",
+            return_value=False,
+        ),
         patch(
             "reflector.services.transcript_process.HatchetClientManager"
         ) as mock_hatchet,
@@ -237,7 +236,7 @@ async def test_dailyco_recording_uses_multitrack_pipeline(monkeypatch, client):
         assert response.status_code == 200
         assert response.json()["status"] == "ok"
 
-        # Daily.co multitrack recordings should use Hatchet workflow
+        # Daily.co multitrack recordings should use Hatchet DiarizationPipeline
         mock_hatchet.start_workflow.assert_called_once()
         call_kwargs = mock_hatchet.start_workflow.call_args.kwargs
         assert call_kwargs["workflow_name"] == "DiarizationPipeline"
@@ -246,7 +245,6 @@ async def test_dailyco_recording_uses_multitrack_pipeline(monkeypatch, client):
         assert call_kwargs["input_data"]["tracks"] == [
             {"s3_key": k} for k in track_keys
         ]
-        mock_file_pipeline.delay.assert_not_called()
 
 
 @pytest.mark.usefixtures("setup_database")
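
These tests also outline the dispatch guard in the service layer: task_is_scheduled_or_active is patched to return False before /process is allowed to dispatch, and the call_args assertions fix the keyword names workflow_name and input_data. A rough sketch of the presumed logic (hypothetical; the guard's exact signature and the recording-type branching are not shown in this diff):

    async def dispatch_transcript_processing(transcript):
        # refuse to double-dispatch while a run is already scheduled or active
        if task_is_scheduled_or_active(transcript):  # hypothetical signature
            return

        # default (e.g. Whereby) recordings go through the Hatchet FilePipeline
        await HatchetClientManager.start_workflow(
            workflow_name="FilePipeline",
            input_data={"transcript_id": transcript.id},
            additional_metadata={"transcript_id": transcript.id},
        )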

View File

@@ -2,6 +2,10 @@
 # FIXME test status of transcript
 # FIXME test websocket connection after RTC is finished still send the full events
 # FIXME try with locked session, RTC should not work
+# TODO: add integration tests for post-processing (LivePostPipeline) with a real
+# Hatchet instance. These tests currently only cover the live pipeline.
+# Post-processing events (WAVEFORM, FINAL_*, DURATION, STATUS=ended, mp3)
+# are now dispatched via Hatchet and tested in test_hatchet_live_post_pipeline.py.
 
 import asyncio
 import json
@@ -49,7 +53,7 @@ class ThreadedUvicorn:
 @pytest.fixture
-def appserver(tmpdir, setup_database, celery_session_app, celery_session_worker):
+def appserver(tmpdir, setup_database):
     import threading
 
     from reflector.app import app
@@ -119,8 +123,6 @@ def appserver(tmpdir, setup_database, celery_session_app, celery_session_worker)
 @pytest.mark.usefixtures("setup_database")
-@pytest.mark.usefixtures("celery_session_app")
-@pytest.mark.usefixtures("celery_session_worker")
 @pytest.mark.asyncio
 async def test_transcript_rtc_and_websocket(
     tmpdir,
@@ -134,6 +136,7 @@ async def test_transcript_rtc_and_websocket(
     appserver,
     client,
     monkeypatch,
+    mock_hatchet_client,
 ):
     # goal: start the server, exchange RTC, receive websocket events
     # because of that, we need to start the server in a thread
@@ -208,35 +211,30 @@ async def test_transcript_rtc_and_websocket(
     stream_client.channel.send(json.dumps({"cmd": "STOP"}))
     await stream_client.stop()
 
-    # wait the processing to finish
-    timeout = 120
+    # Wait for live pipeline to flush (it dispatches post-processing to Hatchet)
+    timeout = 30
     while True:
         # fetch the transcript and check if it is ended
         resp = await client.get(f"/transcripts/{tid}")
         assert resp.status_code == 200
-        if resp.json()["status"] in ("ended", "error"):
+        if resp.json()["status"] in ("processing", "ended", "error"):
             break
         await asyncio.sleep(1)
         timeout -= 1
         if timeout < 0:
-            raise TimeoutError("Timeout while waiting for transcript to be ended")
-
-    if resp.json()["status"] != "ended":
-        raise TimeoutError("Transcript processing failed")
+            raise TimeoutError("Timeout waiting for live pipeline to finish")
 
     # stop websocket task
     websocket_task.cancel()
 
-    # check events
+    # check live pipeline events
     assert len(events) > 0
 
     from pprint import pprint
 
     pprint(events)
 
     # get events list
     eventnames = [e["event"] for e in events]
 
-    # check events
+    # Live pipeline produces TRANSCRIPT and TOPIC events during RTC
     assert "TRANSCRIPT" in eventnames
     ev = events[eventnames.index("TRANSCRIPT")]
     assert ev["data"]["text"].startswith("Hello world.")
@@ -249,50 +247,18 @@ async def test_transcript_rtc_and_websocket(
assert ev["data"]["transcript"].startswith("Hello world.")
assert ev["data"]["timestamp"] == 0.0
assert "FINAL_LONG_SUMMARY" in eventnames
ev = events[eventnames.index("FINAL_LONG_SUMMARY")]
assert ev["data"]["long_summary"] == "LLM LONG SUMMARY"
assert "FINAL_SHORT_SUMMARY" in eventnames
ev = events[eventnames.index("FINAL_SHORT_SUMMARY")]
assert ev["data"]["short_summary"] == "LLM SHORT SUMMARY"
assert "FINAL_TITLE" in eventnames
ev = events[eventnames.index("FINAL_TITLE")]
assert ev["data"]["title"] == "Llm Title"
assert "WAVEFORM" in eventnames
ev = events[eventnames.index("WAVEFORM")]
assert isinstance(ev["data"]["waveform"], list)
assert len(ev["data"]["waveform"]) >= 250
waveform_resp = await client.get(f"/transcripts/{tid}/audio/waveform")
assert waveform_resp.status_code == 200
assert waveform_resp.headers["content-type"] == "application/json"
assert isinstance(waveform_resp.json()["data"], list)
assert len(waveform_resp.json()["data"]) >= 250
# check status order
# Live pipeline status progression
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
assert "recording" in statuses
assert "processing" in statuses
assert statuses.index("recording") < statuses.index("processing")
assert statuses.index("processing") < statuses.index("ended")
# ensure the last event received is ended
assert events[-1]["event"] == "STATUS"
assert events[-1]["data"]["value"] == "ended"
# check on the latest response that the audio duration is > 0
assert resp.json()["duration"] > 0
assert "DURATION" in eventnames
# check that audio/mp3 is available
audio_resp = await client.get(f"/transcripts/{tid}/audio/mp3")
assert audio_resp.status_code == 200
assert audio_resp.headers["Content-Type"] == "audio/mpeg"
# Post-processing (WAVEFORM, FINAL_*, DURATION, mp3, STATUS=ended) is now
# dispatched to Hatchet via LivePostPipeline — not tested here.
# See test_hatchet_live_post_pipeline.py for post-processing tests.
 @pytest.mark.usefixtures("setup_database")
-@pytest.mark.usefixtures("celery_session_app")
-@pytest.mark.usefixtures("celery_session_worker")
 @pytest.mark.asyncio
 async def test_transcript_rtc_and_websocket_and_fr(
     tmpdir,
@@ -306,6 +272,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
     appserver,
     client,
     monkeypatch,
+    mock_hatchet_client,
 ):
     # goal: start the server, exchange RTC, receive websocket events
     # because of that, we need to start the server in a thread
@@ -382,42 +349,34 @@ async def test_transcript_rtc_and_websocket_and_fr(
     # instead of waiting a long time, we just send a STOP
     stream_client.channel.send(json.dumps({"cmd": "STOP"}))
-    # wait the processing to finish
     await asyncio.sleep(2)
     await stream_client.stop()
 
-    # wait the processing to finish
-    timeout = 120
+    # Wait for live pipeline to flush
+    timeout = 30
     while True:
         # fetch the transcript and check if it is ended
         resp = await client.get(f"/transcripts/{tid}")
         assert resp.status_code == 200
-        if resp.json()["status"] == "ended":
+        if resp.json()["status"] in ("processing", "ended", "error"):
             break
         await asyncio.sleep(1)
         timeout -= 1
         if timeout < 0:
-            raise TimeoutError("Timeout while waiting for transcript to be ended")
-
-    if resp.json()["status"] != "ended":
-        raise TimeoutError("Transcript processing failed")
-
-    await asyncio.sleep(2)
+            raise TimeoutError("Timeout waiting for live pipeline to finish")
 
     # stop websocket task
     websocket_task.cancel()
 
-    # check events
+    # check live pipeline events
     assert len(events) > 0
 
     from pprint import pprint
 
     pprint(events)
 
     # get events list
     eventnames = [e["event"] for e in events]
 
-    # check events
+    # Live pipeline produces TRANSCRIPT with translation
     assert "TRANSCRIPT" in eventnames
     ev = events[eventnames.index("TRANSCRIPT")]
     assert ev["data"]["text"].startswith("Hello world.")
@@ -430,23 +389,11 @@ async def test_transcript_rtc_and_websocket_and_fr(
assert ev["data"]["transcript"].startswith("Hello world.")
assert ev["data"]["timestamp"] == 0.0
assert "FINAL_LONG_SUMMARY" in eventnames
ev = events[eventnames.index("FINAL_LONG_SUMMARY")]
assert ev["data"]["long_summary"] == "LLM LONG SUMMARY"
assert "FINAL_SHORT_SUMMARY" in eventnames
ev = events[eventnames.index("FINAL_SHORT_SUMMARY")]
assert ev["data"]["short_summary"] == "LLM SHORT SUMMARY"
assert "FINAL_TITLE" in eventnames
ev = events[eventnames.index("FINAL_TITLE")]
assert ev["data"]["title"] == "Llm Title"
# check status order
# Live pipeline status progression
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
assert "recording" in statuses
assert "processing" in statuses
assert statuses.index("recording") < statuses.index("processing")
assert statuses.index("processing") < statuses.index("ended")
# ensure the last event received is ended
assert events[-1]["event"] == "STATUS"
assert events[-1]["data"]["value"] == "ended"
# Post-processing (FINAL_*, STATUS=ended) is now dispatched to Hatchet
# via LivePostPipeline — not tested here.

View File

@@ -1,12 +1,7 @@
-import asyncio
-import time
-
 import pytest
 
 
 @pytest.mark.usefixtures("setup_database")
-@pytest.mark.usefixtures("celery_session_app")
-@pytest.mark.usefixtures("celery_session_worker")
 @pytest.mark.asyncio
 async def test_transcript_upload_file(
     tmpdir,
@@ -17,6 +12,7 @@ async def test_transcript_upload_file(
     dummy_storage,
     client,
     monkeypatch,
+    mock_hatchet_client,
 ):
     from reflector.settings import settings
@@ -43,27 +39,16 @@ async def test_transcript_upload_file(
     assert response.status_code == 200
     assert response.json()["status"] == "ok"
 
-    # wait the processing to finish (max 1 minute)
-    timeout_seconds = 60
-    start_time = time.monotonic()
-    while (time.monotonic() - start_time) < timeout_seconds:
-        # fetch the transcript and check if it is ended
-        resp = await client.get(f"/transcripts/{tid}")
-        assert resp.status_code == 200
-        if resp.json()["status"] in ("ended", "error"):
-            break
-        await asyncio.sleep(1)
-    else:
-        return pytest.fail(f"Processing timed out after {timeout_seconds} seconds")
+    # Verify Hatchet workflow was dispatched for file processing
+    from reflector.hatchet.client import HatchetClientManager
 
-    # check the transcript is ended
-    transcript = resp.json()
-    assert transcript["status"] == "ended"
-    assert transcript["short_summary"] == "LLM SHORT SUMMARY"
-    assert transcript["title"] == "Llm Title"
+    HatchetClientManager.start_workflow.assert_called_once_with(
+        "FilePipeline",
+        {"transcript_id": tid},
+        additional_metadata={"transcript_id": tid},
+    )
 
-    # check topics and transcript
-    response = await client.get(f"/transcripts/{tid}/topics")
-    assert response.status_code == 200
-    assert len(response.json()) == 1
-    assert "Hello world. How are you today?" in response.json()[0]["transcript"]
+    # Verify transcript status was updated to "uploaded"
+    resp = await client.get(f"/transcripts/{tid}")
+    assert resp.status_code == 200
+    assert resp.json()["status"] == "uploaded"