feat: migrate to taskiq

This commit is contained in:
2025-09-24 19:02:45 -06:00
parent b7f8e8ef8d
commit d86dc59bf2
35 changed files with 1210 additions and 667 deletions

View File

@@ -1,7 +1,6 @@
import asyncio
import os
import sys
from tempfile import NamedTemporaryFile
from unittest.mock import patch
import pytest
@@ -322,26 +321,60 @@ async def dummy_storage():
yield
@pytest.fixture(scope="session")
def celery_enable_logging():
return True
# from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
# from sqlalchemy.orm import sessionmaker
@pytest.fixture(scope="session")
def celery_config():
with NamedTemporaryFile() as f:
yield {
"broker_url": "memory://",
"result_backend": f"db+sqlite:///{f.name}",
}
# @pytest.fixture()
# async def db_connection(sqla_engine):
# connection = await sqla_engine.connect()
# try:
# yield connection
# finally:
# await connection.close()
@pytest.fixture(scope="session")
def celery_includes():
return [
"reflector.pipelines.main_live_pipeline",
"reflector.pipelines.main_file_pipeline",
]
# @pytest.fixture()
# async def db_session_maker(db_connection):
# Session = async_sessionmaker(
# db_connection,
# expire_on_commit=False,
# class_=AsyncSession,
# )
# yield Session
# @pytest.fixture()
# async def db_session(db_session_maker, db_connection):
# """
# Fixture that returns a SQLAlchemy session with a SAVEPOINT, and the rollback to it
# after the test completes.
# """
# session = db_session_maker(
# bind=db_connection,
# join_transaction_mode="create_savepoint",
# )
# try:
# yield session
# finally:
# await session.close()
# @pytest.fixture(autouse=True)
# async def ensure_db_session_in_app(db_connection, db_session_maker):
# async def mock_get_session():
# session = db_session_maker(
# bind=db_connection, join_transaction_mode="create_savepoint"
# )
# try:
# yield session
# finally:
# await session.close()
# with patch("reflector.db._get_session", side_effect=mock_get_session):
# yield
@pytest.fixture(autouse=True)
@@ -372,6 +405,18 @@ def fake_mp3_upload():
yield
@pytest.fixture
async def taskiq_broker():
from reflector.worker.app import taskiq_broker
await taskiq_broker.startup()
try:
yield taskiq_broker
finally:
await taskiq_broker.shutdown()
@pytest.fixture
async def fake_transcript_with_topics(tmpdir, client, db_session):
import shutil

View File

@@ -130,15 +130,15 @@ async def test_sync_all_ics_calendars(db_session):
ics_enabled=False,
)
with patch("reflector.worker.ics_sync.sync_room_ics.delay") as mock_delay:
with patch("reflector.worker.ics_sync.sync_room_ics.kiq") as mock_kiq:
ics_enabled_rooms = await rooms_controller.get_ics_enabled(db_session)
for room in ics_enabled_rooms:
if room and _should_sync(room):
sync_room_ics.delay(room.id)
await sync_room_ics.kiq(room.id)
assert mock_delay.call_count == 2
called_room_ids = [call.args[0] for call in mock_delay.call_args_list]
assert mock_kiq.call_count == 2
called_room_ids = [call.args[0] for call in mock_kiq.call_args_list]
assert room1.id in called_room_ids
assert room2.id in called_room_ids
assert room3.id not in called_room_ids
@@ -210,15 +210,15 @@ async def test_sync_respects_fetch_interval(db_session):
{"ics_last_sync": now - timedelta(seconds=100)},
)
with patch("reflector.worker.ics_sync.sync_room_ics.delay") as mock_delay:
with patch("reflector.worker.ics_sync.sync_room_ics.kiq") as mock_kiq:
ics_enabled_rooms = await rooms_controller.get_ics_enabled(db_session)
for room in ics_enabled_rooms:
if room and _should_sync(room):
sync_room_ics.delay(room.id)
await sync_room_ics.kiq(room.id)
assert mock_delay.call_count == 1
assert mock_delay.call_args[0][0] == room2.id
assert mock_kiq.call_count == 1
assert mock_kiq.call_args[0][0] == room2.id
@pytest.mark.asyncio

View File

@@ -1,9 +1,11 @@
import asyncio
import time
import os
import pytest
from httpx import ASGITransport, AsyncClient
# Set environment for TaskIQ to use InMemoryBroker
os.environ["ENVIRONMENT"] = "pytest"
@pytest.fixture
async def app_lifespan():
@@ -23,8 +25,16 @@ async def client(app_lifespan):
)
@pytest.mark.usefixtures("celery_session_app")
@pytest.mark.usefixtures("celery_session_worker")
@pytest.fixture
async def taskiq_broker():
from reflector.worker.app import taskiq_broker
# Broker is already initialized as InMemoryBroker due to ENVIRONMENT=pytest
await taskiq_broker.startup()
yield taskiq_broker
await taskiq_broker.shutdown()
@pytest.mark.asyncio
async def test_transcript_process(
tmpdir,
@@ -34,7 +44,10 @@ async def test_transcript_process(
dummy_file_diarization,
dummy_storage,
client,
taskiq_broker,
db_session,
):
print("IN TEST", db_session)
# create a transcript
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
@@ -55,18 +68,14 @@ async def test_transcript_process(
assert response.status_code == 200
assert response.json()["status"] == "ok"
# wait for processing to finish (max 1 minute)
timeout_seconds = 60
start_time = time.monotonic()
while (time.monotonic() - start_time) < timeout_seconds:
# fetch the transcript and check if it is ended
resp = await client.get(f"/transcripts/{tid}")
assert resp.status_code == 200
if resp.json()["status"] in ("ended", "error"):
break
await asyncio.sleep(1)
else:
pytest.fail(f"Initial processing timed out after {timeout_seconds} seconds")
# Wait for all tasks to complete since we're using InMemoryBroker
await taskiq_broker.wait_all()
# Ensure it's finished ok
resp = await client.get(f"/transcripts/{tid}")
assert resp.status_code == 200
print(resp.json())
assert resp.json()["status"] in ("ended", "error")
# restart the processing
response = await client.post(
@@ -74,20 +83,15 @@ async def test_transcript_process(
)
assert response.status_code == 200
assert response.json()["status"] == "ok"
await asyncio.sleep(2)
# wait for processing to finish (max 1 minute)
timeout_seconds = 60
start_time = time.monotonic()
while (time.monotonic() - start_time) < timeout_seconds:
# fetch the transcript and check if it is ended
resp = await client.get(f"/transcripts/{tid}")
assert resp.status_code == 200
if resp.json()["status"] in ("ended", "error"):
break
await asyncio.sleep(1)
else:
pytest.fail(f"Restart processing timed out after {timeout_seconds} seconds")
# Wait for all tasks to complete since we're using InMemoryBroker
await taskiq_broker.wait_all()
# Ensure it's finished ok
resp = await client.get(f"/transcripts/{tid}")
assert resp.status_code == 200
print(resp.json())
assert resp.json()["status"] in ("ended", "error")
# check the transcript is ended
transcript = resp.json()

View File

@@ -49,7 +49,7 @@ class ThreadedUvicorn:
@pytest.fixture
def appserver(tmpdir, database, celery_session_app, celery_session_worker):
def appserver(tmpdir, database):
import threading
from reflector.app import app
@@ -111,8 +111,6 @@ def appserver(tmpdir, database, celery_session_app, celery_session_worker):
settings.DATA_DIR = DATA_DIR
@pytest.mark.usefixtures("celery_session_app")
@pytest.mark.usefixtures("celery_session_worker")
@pytest.mark.asyncio
async def test_transcript_rtc_and_websocket(
tmpdir,
@@ -275,8 +273,6 @@ async def test_transcript_rtc_and_websocket(
assert audio_resp.headers["Content-Type"] == "audio/mpeg"
@pytest.mark.usefixtures("celery_session_app")
@pytest.mark.usefixtures("celery_session_worker")
@pytest.mark.asyncio
async def test_transcript_rtc_and_websocket_and_fr(
tmpdir,

View File

@@ -4,8 +4,6 @@ import time
import pytest
@pytest.mark.usefixtures("celery_session_app")
@pytest.mark.usefixtures("celery_session_worker")
@pytest.mark.asyncio
async def test_transcript_upload_file(
tmpdir,