feat: migrate to taskiq
@@ -1,7 +1,6 @@
 import asyncio
 import os
 import sys
-from tempfile import NamedTemporaryFile
 from unittest.mock import patch

 import pytest
@@ -322,26 +321,60 @@ async def dummy_storage():
     yield


-@pytest.fixture(scope="session")
-def celery_enable_logging():
-    return True
+# from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
+# from sqlalchemy.orm import sessionmaker


-@pytest.fixture(scope="session")
-def celery_config():
-    with NamedTemporaryFile() as f:
-        yield {
-            "broker_url": "memory://",
-            "result_backend": f"db+sqlite:///{f.name}",
-        }
+# @pytest.fixture()
+# async def db_connection(sqla_engine):
+#     connection = await sqla_engine.connect()
+#     try:
+#         yield connection
+#     finally:
+#         await connection.close()


-@pytest.fixture(scope="session")
-def celery_includes():
-    return [
-        "reflector.pipelines.main_live_pipeline",
-        "reflector.pipelines.main_file_pipeline",
-    ]
+# @pytest.fixture()
+# async def db_session_maker(db_connection):
+#     Session = async_sessionmaker(
+#         db_connection,
+#         expire_on_commit=False,
+#         class_=AsyncSession,
+#     )
+#     yield Session


+# @pytest.fixture()
+# async def db_session(db_session_maker, db_connection):
+#     """
+#     Fixture that returns a SQLAlchemy session with a SAVEPOINT, and the rollback to it
+#     after the test completes.
+#     """
+#     session = db_session_maker(
+#         bind=db_connection,
+#         join_transaction_mode="create_savepoint",
+#     )

+#     try:
+#         yield session
+#     finally:
+#         await session.close()


+# @pytest.fixture(autouse=True)
+# async def ensure_db_session_in_app(db_connection, db_session_maker):
+#     async def mock_get_session():
+#         session = db_session_maker(
+#             bind=db_connection, join_transaction_mode="create_savepoint"
+#         )

+#         try:
+#             yield session
+#         finally:
+#             await session.close()

+#     with patch("reflector.db._get_session", side_effect=mock_get_session):
+#         yield


 @pytest.fixture(autouse=True)
@@ -372,6 +405,18 @@ def fake_mp3_upload():
     yield


+@pytest.fixture
+async def taskiq_broker():
+    from reflector.worker.app import taskiq_broker
+
+    await taskiq_broker.startup()
+
+    try:
+        yield taskiq_broker
+    finally:
+        await taskiq_broker.shutdown()
+
+
 @pytest.fixture
 async def fake_transcript_with_topics(tmpdir, client, db_session):
     import shutil
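For readers new to TaskIQ: the Celery pytest plugin needed celery_config, celery_includes and a session worker, while a TaskIQ test setup is just an async fixture around a broker. The project fixture above imports the already-configured broker from reflector.worker.app; a standalone equivalent, as a minimal sketch using taskiq's InMemoryBroker (fixture name is illustrative, and an async-fixture plugin such as pytest-asyncio is assumed):

import pytest
from taskiq import InMemoryBroker


@pytest.fixture
async def broker():
    # In-memory broker: tasks run in-process and results stay in memory,
    # so no broker URL, result backend, or separate worker process is needed.
    broker = InMemoryBroker()
    await broker.startup()
    try:
        yield broker
    finally:
        await broker.shutdown()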
@@ -130,15 +130,15 @@ async def test_sync_all_ics_calendars(db_session):
         ics_enabled=False,
     )

-    with patch("reflector.worker.ics_sync.sync_room_ics.delay") as mock_delay:
+    with patch("reflector.worker.ics_sync.sync_room_ics.kiq") as mock_kiq:
         ics_enabled_rooms = await rooms_controller.get_ics_enabled(db_session)

         for room in ics_enabled_rooms:
             if room and _should_sync(room):
-                sync_room_ics.delay(room.id)
+                await sync_room_ics.kiq(room.id)

-    assert mock_delay.call_count == 2
-    called_room_ids = [call.args[0] for call in mock_delay.call_args_list]
+    assert mock_kiq.call_count == 2
+    called_room_ids = [call.args[0] for call in mock_kiq.call_args_list]
     assert room1.id in called_room_ids
     assert room2.id in called_room_ids
     assert room3.id not in called_room_ids
@@ -210,15 +210,15 @@ async def test_sync_respects_fetch_interval(db_session):
         {"ics_last_sync": now - timedelta(seconds=100)},
     )

-    with patch("reflector.worker.ics_sync.sync_room_ics.delay") as mock_delay:
+    with patch("reflector.worker.ics_sync.sync_room_ics.kiq") as mock_kiq:
         ics_enabled_rooms = await rooms_controller.get_ics_enabled(db_session)

         for room in ics_enabled_rooms:
             if room and _should_sync(room):
-                sync_room_ics.delay(room.id)
+                await sync_room_ics.kiq(room.id)

-    assert mock_delay.call_count == 1
-    assert mock_delay.call_args[0][0] == room2.id
+    assert mock_kiq.call_count == 1
+    assert mock_kiq.call_args[0][0] == room2.id


 @pytest.mark.asyncio
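For context on the dispatch change above: Celery's task.delay(...) is a plain call returning an AsyncResult, whereas TaskIQ's task.kiq(...) is a coroutine, which is why the tests now await the dispatch and patch ".kiq" instead of ".delay". A minimal sketch with taskiq's InMemoryBroker; the task body here is illustrative, the real task lives in reflector.worker.ics_sync:

import asyncio

from taskiq import InMemoryBroker

broker = InMemoryBroker()


@broker.task
async def sync_room_ics(room_id: str) -> None:
    # Stand-in body; the real task fetches and syncs the room's ICS calendar.
    print(f"syncing calendar for room {room_id}")


async def main() -> None:
    await broker.startup()
    # Celery: sync_room_ics.delay(room_id)   -> no await, goes to a worker.
    # TaskIQ: .kiq() must be awaited; it enqueues the task and returns a handle.
    await sync_room_ics.kiq("room-1")
    await broker.wait_all()  # InMemoryBroker: wait for queued tasks to finish
    await broker.shutdown()


asyncio.run(main())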
@@ -1,9 +1,11 @@
-import asyncio
-import time
+import os

 import pytest
 from httpx import ASGITransport, AsyncClient

+# Set environment for TaskIQ to use InMemoryBroker
+os.environ["ENVIRONMENT"] = "pytest"


 @pytest.fixture
 async def app_lifespan():
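Setting ENVIRONMENT=pytest before the app is imported is what makes reflector.worker.app hand the tests an in-memory broker. The selection logic itself is not part of this diff; a minimal sketch of how such a switch is commonly written with taskiq (the settings name and the production branch are assumptions, not the project's actual code):

import os

from taskiq import InMemoryBroker

if os.environ.get("ENVIRONMENT") == "pytest":
    # Tests: run tasks in-process, no external broker or worker needed.
    taskiq_broker = InMemoryBroker()
else:
    # Production would construct the real broker (Redis, RabbitMQ, ...) here;
    # that wiring is project-specific and intentionally left out of this sketch.
    raise NotImplementedError("configure the production TaskIQ broker here")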
@@ -23,8 +25,16 @@ async def client(app_lifespan):
     )


-@pytest.mark.usefixtures("celery_session_app")
-@pytest.mark.usefixtures("celery_session_worker")
+@pytest.fixture
+async def taskiq_broker():
+    from reflector.worker.app import taskiq_broker
+
+    # Broker is already initialized as InMemoryBroker due to ENVIRONMENT=pytest
+    await taskiq_broker.startup()
+    yield taskiq_broker
+    await taskiq_broker.shutdown()
+
+
 @pytest.mark.asyncio
 async def test_transcript_process(
     tmpdir,
@@ -34,7 +44,10 @@ async def test_transcript_process(
     dummy_file_diarization,
     dummy_storage,
     client,
+    taskiq_broker,
+    db_session,
 ):
+    print("IN TEST", db_session)
     # create a transcript
     response = await client.post("/transcripts", json={"name": "test"})
     assert response.status_code == 200
@@ -55,18 +68,14 @@ async def test_transcript_process(
     assert response.status_code == 200
     assert response.json()["status"] == "ok"

-    # wait for processing to finish (max 1 minute)
-    timeout_seconds = 60
-    start_time = time.monotonic()
-    while (time.monotonic() - start_time) < timeout_seconds:
-        # fetch the transcript and check if it is ended
-        resp = await client.get(f"/transcripts/{tid}")
-        assert resp.status_code == 200
-        if resp.json()["status"] in ("ended", "error"):
-            break
-        await asyncio.sleep(1)
-    else:
-        pytest.fail(f"Initial processing timed out after {timeout_seconds} seconds")
+    # Wait for all tasks to complete since we're using InMemoryBroker
+    await taskiq_broker.wait_all()
+
+    # Ensure it's finished ok
+    resp = await client.get(f"/transcripts/{tid}")
+    assert resp.status_code == 200
+    print(resp.json())
+    assert resp.json()["status"] in ("ended", "error")

     # restart the processing
     response = await client.post(
@@ -74,20 +83,15 @@ async def test_transcript_process(
     )
     assert response.status_code == 200
     assert response.json()["status"] == "ok"
-    await asyncio.sleep(2)

-    # wait for processing to finish (max 1 minute)
-    timeout_seconds = 60
-    start_time = time.monotonic()
-    while (time.monotonic() - start_time) < timeout_seconds:
-        # fetch the transcript and check if it is ended
-        resp = await client.get(f"/transcripts/{tid}")
-        assert resp.status_code == 200
-        if resp.json()["status"] in ("ended", "error"):
-            break
-        await asyncio.sleep(1)
-    else:
-        pytest.fail(f"Restart processing timed out after {timeout_seconds} seconds")
+    # Wait for all tasks to complete since we're using InMemoryBroker
+    await taskiq_broker.wait_all()
+
+    # Ensure it's finished ok
+    resp = await client.get(f"/transcripts/{tid}")
+    assert resp.status_code == 200
+    print(resp.json())
+    assert resp.json()["status"] in ("ended", "error")

     # check the transcript is ended
     transcript = resp.json()
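The same pattern replaces both polling loops in this test: with an in-memory broker every task runs on the test's own event loop, so taskiq_broker.wait_all() returns once all queued work has finished and the transcript status can be asserted immediately. A self-contained sketch of that waiting pattern, assuming pytest-asyncio; the task and test names are illustrative:

import pytest
from taskiq import InMemoryBroker

broker = InMemoryBroker()
processed: list[str] = []


@broker.task
async def process_transcript(transcript_id: str) -> None:
    # Stand-in for the pipeline task kicked off by the /process endpoint.
    processed.append(transcript_id)


@pytest.mark.asyncio
async def test_wait_all_replaces_polling():
    await broker.startup()
    try:
        await process_transcript.kiq("transcript-1")
        # No 60-second polling loop: wait_all() resolves when every queued
        # task has completed, so the assertion can run right away.
        await broker.wait_all()
        assert processed == ["transcript-1"]
    finally:
        await broker.shutdown()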
@@ -49,7 +49,7 @@ class ThreadedUvicorn:


 @pytest.fixture
-def appserver(tmpdir, database, celery_session_app, celery_session_worker):
+def appserver(tmpdir, database):
     import threading

     from reflector.app import app

@@ -111,8 +111,6 @@ def appserver(tmpdir, database, celery_session_app, celery_session_worker):
     settings.DATA_DIR = DATA_DIR


-@pytest.mark.usefixtures("celery_session_app")
-@pytest.mark.usefixtures("celery_session_worker")
 @pytest.mark.asyncio
 async def test_transcript_rtc_and_websocket(
     tmpdir,

@@ -275,8 +273,6 @@ async def test_transcript_rtc_and_websocket(
     assert audio_resp.headers["Content-Type"] == "audio/mpeg"


-@pytest.mark.usefixtures("celery_session_app")
-@pytest.mark.usefixtures("celery_session_worker")
 @pytest.mark.asyncio
 async def test_transcript_rtc_and_websocket_and_fr(
     tmpdir,
@@ -4,8 +4,6 @@ import time
 import pytest


-@pytest.mark.usefixtures("celery_session_app")
-@pytest.mark.usefixtures("celery_session_worker")
 @pytest.mark.asyncio
 async def test_transcript_upload_file(
     tmpdir,