mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-03-31 03:16:46 +00:00
feat: migrate file and live post-processing pipelines from Celery to Hatchet workflow engine (#911)
* feat: migrate file and live post-processing pipelines from Celery to Hatchet workflow engine * fix: always force reprocessing * fix: ci tests with live pipelines * fix: ci tests with live pipelines
This commit is contained in:
committed by
GitHub
parent
72dca7cacc
commit
37a1f01850
@@ -1,5 +1,3 @@
|
||||
import asyncio
|
||||
import time
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -27,8 +25,6 @@ async def client(app_lifespan):
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("setup_database")
|
||||
@pytest.mark.usefixtures("celery_session_app")
|
||||
@pytest.mark.usefixtures("celery_session_worker")
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_process(
|
||||
tmpdir,
|
||||
@@ -39,8 +35,13 @@ async def test_transcript_process(
|
||||
dummy_storage,
|
||||
client,
|
||||
monkeypatch,
|
||||
mock_hatchet_client,
|
||||
):
|
||||
# public mode: this test uses an anonymous client; allow anonymous transcript creation
|
||||
"""Test upload + process dispatch via Hatchet.
|
||||
|
||||
The file pipeline is now dispatched to Hatchet (fire-and-forget),
|
||||
so we verify the workflow was triggered rather than polling for completion.
|
||||
"""
|
||||
monkeypatch.setattr(settings, "PUBLIC_MODE", True)
|
||||
|
||||
# create a transcript
|
||||
@@ -63,51 +64,43 @@ async def test_transcript_process(
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "ok"
|
||||
|
||||
# wait for processing to finish (max 1 minute)
|
||||
timeout_seconds = 60
|
||||
start_time = time.monotonic()
|
||||
while (time.monotonic() - start_time) < timeout_seconds:
|
||||
# fetch the transcript and check if it is ended
|
||||
resp = await client.get(f"/transcripts/{tid}")
|
||||
assert resp.status_code == 200
|
||||
if resp.json()["status"] in ("ended", "error"):
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
else:
|
||||
pytest.fail(f"Initial processing timed out after {timeout_seconds} seconds")
|
||||
# Verify Hatchet workflow was dispatched (from upload endpoint)
|
||||
from reflector.hatchet.client import HatchetClientManager
|
||||
|
||||
# restart the processing
|
||||
response = await client.post(
|
||||
f"/transcripts/{tid}/process",
|
||||
HatchetClientManager.start_workflow.assert_called_once_with(
|
||||
"FilePipeline",
|
||||
{"transcript_id": tid},
|
||||
additional_metadata={"transcript_id": tid},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "ok"
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# wait for processing to finish (max 1 minute)
|
||||
timeout_seconds = 60
|
||||
start_time = time.monotonic()
|
||||
while (time.monotonic() - start_time) < timeout_seconds:
|
||||
# fetch the transcript and check if it is ended
|
||||
resp = await client.get(f"/transcripts/{tid}")
|
||||
assert resp.status_code == 200
|
||||
if resp.json()["status"] in ("ended", "error"):
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
else:
|
||||
pytest.fail(f"Restart processing timed out after {timeout_seconds} seconds")
|
||||
# Verify transcript status was set to "uploaded"
|
||||
resp = await client.get(f"/transcripts/{tid}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "uploaded"
|
||||
|
||||
# check the transcript is ended
|
||||
transcript = resp.json()
|
||||
assert transcript["status"] == "ended"
|
||||
assert transcript["short_summary"] == "LLM SHORT SUMMARY"
|
||||
assert transcript["title"] == "Llm Title"
|
||||
# Reset mock for reprocess test
|
||||
HatchetClientManager.start_workflow.reset_mock()
|
||||
|
||||
# check topics and transcript
|
||||
response = await client.get(f"/transcripts/{tid}/topics")
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 1
|
||||
assert "Hello world. How are you today?" in response.json()[0]["transcript"]
|
||||
# Clear workflow_run_id so /process endpoint can dispatch again
|
||||
from reflector.db.transcripts import transcripts_controller
|
||||
|
||||
transcript = await transcripts_controller.get_by_id(tid)
|
||||
await transcripts_controller.update(transcript, {"workflow_run_id": None})
|
||||
|
||||
# Reprocess via /process endpoint
|
||||
with patch(
|
||||
"reflector.services.transcript_process.task_is_scheduled_or_active",
|
||||
return_value=False,
|
||||
):
|
||||
response = await client.post(f"/transcripts/{tid}/process")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "ok"
|
||||
|
||||
# Verify second Hatchet dispatch (from /process endpoint)
|
||||
HatchetClientManager.start_workflow.assert_called_once()
|
||||
call_kwargs = HatchetClientManager.start_workflow.call_args.kwargs
|
||||
assert call_kwargs["workflow_name"] == "FilePipeline"
|
||||
assert call_kwargs["input_data"]["transcript_id"] == tid
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("setup_database")
|
||||
@@ -150,20 +143,25 @@ async def test_whereby_recording_uses_file_pipeline(monkeypatch, client):
|
||||
|
||||
with (
|
||||
patch(
|
||||
"reflector.services.transcript_process.task_pipeline_file_process"
|
||||
) as mock_file_pipeline,
|
||||
"reflector.services.transcript_process.task_is_scheduled_or_active",
|
||||
return_value=False,
|
||||
),
|
||||
patch(
|
||||
"reflector.services.transcript_process.HatchetClientManager"
|
||||
) as mock_hatchet,
|
||||
):
|
||||
mock_hatchet.start_workflow = AsyncMock(return_value="test-workflow-id")
|
||||
|
||||
response = await client.post(f"/transcripts/{transcript.id}/process")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "ok"
|
||||
|
||||
# Whereby recordings should use file pipeline, not Hatchet
|
||||
mock_file_pipeline.delay.assert_called_once_with(transcript_id=transcript.id)
|
||||
mock_hatchet.start_workflow.assert_not_called()
|
||||
# Whereby recordings should use Hatchet FilePipeline
|
||||
mock_hatchet.start_workflow.assert_called_once()
|
||||
call_kwargs = mock_hatchet.start_workflow.call_args.kwargs
|
||||
assert call_kwargs["workflow_name"] == "FilePipeline"
|
||||
assert call_kwargs["input_data"]["transcript_id"] == transcript.id
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("setup_database")
|
||||
@@ -224,8 +222,9 @@ async def test_dailyco_recording_uses_multitrack_pipeline(monkeypatch, client):
|
||||
|
||||
with (
|
||||
patch(
|
||||
"reflector.services.transcript_process.task_pipeline_file_process"
|
||||
) as mock_file_pipeline,
|
||||
"reflector.services.transcript_process.task_is_scheduled_or_active",
|
||||
return_value=False,
|
||||
),
|
||||
patch(
|
||||
"reflector.services.transcript_process.HatchetClientManager"
|
||||
) as mock_hatchet,
|
||||
@@ -237,7 +236,7 @@ async def test_dailyco_recording_uses_multitrack_pipeline(monkeypatch, client):
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "ok"
|
||||
|
||||
# Daily.co multitrack recordings should use Hatchet workflow
|
||||
# Daily.co multitrack recordings should use Hatchet DiarizationPipeline
|
||||
mock_hatchet.start_workflow.assert_called_once()
|
||||
call_kwargs = mock_hatchet.start_workflow.call_args.kwargs
|
||||
assert call_kwargs["workflow_name"] == "DiarizationPipeline"
|
||||
@@ -246,7 +245,6 @@ async def test_dailyco_recording_uses_multitrack_pipeline(monkeypatch, client):
|
||||
assert call_kwargs["input_data"]["tracks"] == [
|
||||
{"s3_key": k} for k in track_keys
|
||||
]
|
||||
mock_file_pipeline.delay.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("setup_database")
|
||||
|
||||
Reference in New Issue
Block a user