test: update test fixtures to use @with_session decorator

- Replace manual session management in test fixtures with @with_session decorator
- Simplify async test fixtures by removing explicit session handling
- Update dependencies in pyproject.toml and uv.lock
This commit is contained in:
2025-09-23 12:09:26 -06:00
parent 8ad1270229
commit 27b3b9cdee
14 changed files with 1776 additions and 1837 deletions

View File

@@ -46,6 +46,7 @@ dev = [
"black>=24.1.1",
"stamina>=23.1.0",
"pyinstrument>=4.6.1",
"pytest-async-sqlalchemy>=0.2.0",
]
tests = [
"pytest-cov>=4.1.0",
@@ -117,7 +118,9 @@ DATABASE_URL = "postgresql+asyncpg://test_user:test_password@localhost:15432/ref
addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v"
testpaths = ["tests"]
asyncio_mode = "auto"
asyncio_debug = true
asyncio_default_fixture_loop_scope = "session"
asyncio_default_test_loop_scope = "session"
markers = [
"model_api: tests for the unified model-serving HTTP API (backend- and hardware-agnostic)",
]

View File

@@ -3,17 +3,15 @@ import os
import sys
import pytest
import pytest_asyncio
from sqlalchemy.pool import NullPool
@pytest.fixture(scope="session")
def event_loop():
"""Session-scoped event loop."""
"""Create an instance of the default event loop for the test session."""
if sys.platform.startswith("win") and sys.version_info[:2] >= (3, 8):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
loop = asyncio.get_event_loop_policy().new_event_loop()
loop = asyncio.new_event_loop()
yield loop
loop.close()
@@ -53,65 +51,42 @@ def docker_ip():
return "127.0.0.1"
# Only register docker_services dependent fixtures if docker plugin is available
try:
import pytest_docker # noqa: F401
@pytest.fixture(scope="session")
def postgres_service(docker_ip, docker_services):
"""Ensure that PostgreSQL service is up and responsive."""
port = docker_services.port_for("postgres_test", 5432)
@pytest.fixture(scope="session")
def postgres_service(docker_ip, docker_services):
"""Ensure that PostgreSQL service is up and responsive."""
port = docker_services.port_for("postgres_test", 5432)
def is_responsive():
try:
import psycopg2
def is_responsive():
try:
import psycopg2
conn = psycopg2.connect(
host=docker_ip,
port=port,
dbname="reflector_test",
user="test_user",
password="test_password",
)
conn.close()
return True
except Exception:
return False
conn = psycopg2.connect(
host=docker_ip,
port=port,
dbname="reflector_test",
user="test_user",
password="test_password",
)
conn.close()
return True
except Exception:
return False
docker_services.wait_until_responsive(timeout=30.0, pause=0.1, check=is_responsive)
docker_services.wait_until_responsive(
timeout=30.0, pause=0.1, check=is_responsive
)
# Return connection parameters
return {
"host": docker_ip,
"port": port,
"database": "reflector_test",
"user": "test_user",
"password": "test_password",
}
except ImportError:
# Docker plugin not available, provide a dummy fixture
@pytest.fixture(scope="session")
def postgres_service(docker_ip):
"""Dummy postgres service when docker plugin is not available"""
return {
"host": docker_ip,
"port": 15432, # Default test postgres port
"database": "reflector_test",
"user": "test_user",
"password": "test_password",
}
# Return connection parameters
return {
"host": docker_ip,
"port": port,
"database": "reflector_test",
"user": "test_user",
"password": "test_password",
}
@pytest_asyncio.fixture(scope="session", autouse=True)
async def setup_database(postgres_service):
"""Setup database and run migrations"""
from sqlalchemy.ext.asyncio import create_async_engine
from reflector.db import Base
# Build database URL from connection params
@pytest.fixture(scope="session")
def _database_url(postgres_service):
"""Provide database URL for pytest-async-sqlalchemy."""
db_config = postgres_service
DATABASE_URL = (
f"postgresql+asyncpg://{db_config['user']}:{db_config['password']}"
@@ -123,70 +98,15 @@ async def setup_database(postgres_service):
settings.DATABASE_URL = DATABASE_URL
engine = create_async_engine(
DATABASE_URL,
echo=False,
poolclass=NullPool,
)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all)
yield
await engine.dispose()
return DATABASE_URL
@pytest_asyncio.fixture
async def session(setup_database):
"""Provide a transactional database session for tests"""
from sqlalchemy.ext.asyncio import (
AsyncSession,
async_sessionmaker,
create_async_engine,
)
@pytest.fixture(scope="session")
def init_database():
"""Provide database initialization for pytest-async-sqlalchemy."""
from reflector.db import Base
from reflector.settings import settings
engine = create_async_engine(
settings.DATABASE_URL,
echo=False,
poolclass=NullPool,
)
async_session_maker = async_sessionmaker(
bind=engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
autocommit=False,
)
async with async_session_maker() as session:
# Start a savepoint instead of a transaction to handle nested commits
await session.begin()
# Override commit to use flush instead in tests
original_commit = session.commit
async def flush_instead_of_commit():
await session.flush()
session.commit = flush_instead_of_commit
try:
yield session
await session.rollback()
except Exception:
await session.rollback()
raise
finally:
session.commit = original_commit # Restore original commit
await session.close()
# Properly dispose of the engine to close all connections
await engine.dispose()
return Base.metadata.create_all
@pytest.fixture

View File

@@ -8,7 +8,7 @@ from reflector.services.ics_sync import ICSSyncService
@pytest.mark.asyncio
async def test_attendee_parsing_bug(session):
async def test_attendee_parsing_bug(db_session):
"""
Test that reproduces the attendee parsing bug where a string with comma-separated
emails gets parsed as individual characters instead of separate email addresses.
@@ -17,7 +17,7 @@ async def test_attendee_parsing_bug(session):
instead of properly parsed email addresses.
"""
room = await rooms_controller.add(
session,
db_session,
name="test-room",
user_id="test-user",
zulip_auto_post=False,
@@ -31,7 +31,7 @@ async def test_attendee_parsing_bug(session):
ics_url="http://test.com/test.ics",
ics_enabled=True,
)
await session.flush()
await db_session.flush()
from datetime import datetime, timedelta, timezone
@@ -59,7 +59,7 @@ async def test_attendee_parsing_bug(session):
@asynccontextmanager
async def mock_session_context():
yield session
yield db_session
class MockSessionMaker:
def __call__(self):

View File

@@ -11,7 +11,7 @@ from reflector.db.rooms import rooms_controller
@pytest.mark.asyncio
async def test_calendar_event_create(session):
async def test_calendar_event_create(db_db_session):
"""Test creating a calendar event."""
# Create a room first
room = await rooms_controller.add(
@@ -54,7 +54,7 @@ async def test_calendar_event_create(session):
@pytest.mark.asyncio
async def test_calendar_event_get_by_room(session):
async def test_calendar_event_get_by_room(db_db_session):
"""Test getting calendar events for a room."""
# Create room
room = await rooms_controller.add(
@@ -95,7 +95,7 @@ async def test_calendar_event_get_by_room(session):
@pytest.mark.asyncio
async def test_calendar_event_get_upcoming(session):
async def test_calendar_event_get_upcoming(db_db_session):
"""Test getting upcoming events within time window."""
# Create room
room = await rooms_controller.add(
@@ -177,7 +177,7 @@ async def test_calendar_event_get_upcoming(session):
@pytest.mark.asyncio
async def test_calendar_event_get_upcoming_includes_currently_happening(session):
async def test_calendar_event_get_upcoming_includes_currently_happening(db_db_session):
"""Test that get_upcoming includes currently happening events but excludes ended events."""
# Create room
room = await rooms_controller.add(
@@ -238,7 +238,7 @@ async def test_calendar_event_get_upcoming_includes_currently_happening(session)
@pytest.mark.asyncio
async def test_calendar_event_upsert(session):
async def test_calendar_event_upsert(db_db_session):
"""Test upserting (create/update) calendar events."""
# Create room
room = await rooms_controller.add(
@@ -285,7 +285,7 @@ async def test_calendar_event_upsert(session):
@pytest.mark.asyncio
async def test_calendar_event_soft_delete(session):
async def test_calendar_event_soft_delete(db_db_session):
"""Test soft deleting events no longer in calendar."""
# Create room
room = await rooms_controller.add(
@@ -338,7 +338,7 @@ async def test_calendar_event_soft_delete(session):
@pytest.mark.asyncio
async def test_calendar_event_past_events_not_deleted(session):
async def test_calendar_event_past_events_not_deleted(db_db_session):
"""Test that past events are not soft deleted."""
# Create room
room = await rooms_controller.add(
@@ -393,7 +393,7 @@ async def test_calendar_event_past_events_not_deleted(session):
@pytest.mark.asyncio
async def test_calendar_event_with_raw_ics_data(session):
async def test_calendar_event_with_raw_ics_data(db_db_session):
"""Test storing raw ICS data with calendar event."""
# Create room
room = await rooms_controller.add(

View File

@@ -15,19 +15,19 @@ from reflector.worker.cleanup import cleanup_old_public_data
@pytest.mark.asyncio
async def test_cleanup_old_public_data_skips_when_not_public(session):
async def test_cleanup_old_public_data_skips_when_not_public(db_session):
"""Test that cleanup is skipped when PUBLIC_MODE is False."""
with patch("reflector.worker.cleanup.settings") as mock_settings:
mock_settings.PUBLIC_MODE = False
result = await cleanup_old_public_data(session)
result = await cleanup_old_public_data(db_session)
# Should return early without doing anything
assert result is None
@pytest.mark.asyncio
async def test_cleanup_old_public_data_deletes_old_anonymous_transcripts(session):
async def test_cleanup_old_public_data_deletes_old_anonymous_transcripts(db_session):
"""Test that old anonymous transcripts are deleted."""
# Create old and new anonymous transcripts
old_date = datetime.now(timezone.utc) - timedelta(days=8)
@@ -35,23 +35,23 @@ async def test_cleanup_old_public_data_deletes_old_anonymous_transcripts(session
# Create old anonymous transcript (should be deleted)
old_transcript = await transcripts_controller.add(
session,
db_session,
name="Old Anonymous Transcript",
source_kind=SourceKind.FILE,
user_id=None, # Anonymous
)
# Manually update created_at to be old
await session.execute(
await db_session.execute(
update(TranscriptModel)
.where(TranscriptModel.id == old_transcript.id)
.values(created_at=old_date)
)
await session.commit()
await db_session.commit()
# Create new anonymous transcript (should NOT be deleted)
new_transcript = await transcripts_controller.add(
session,
db_session,
name="New Anonymous Transcript",
source_kind=SourceKind.FILE,
user_id=None, # Anonymous
@@ -59,17 +59,17 @@ async def test_cleanup_old_public_data_deletes_old_anonymous_transcripts(session
# Create old transcript with user (should NOT be deleted)
old_user_transcript = await transcripts_controller.add(
session,
db_session,
name="Old User Transcript",
source_kind=SourceKind.FILE,
user_id="user-123",
)
await session.execute(
await db_session.execute(
update(TranscriptModel)
.where(TranscriptModel.id == old_user_transcript.id)
.values(created_at=old_date)
)
await session.commit()
await db_session.commit()
# Mock settings for public mode
with patch("reflector.worker.cleanup.settings") as mock_settings:
@@ -81,7 +81,7 @@ async def test_cleanup_old_public_data_deletes_old_anonymous_transcripts(session
mock_delete.return_value = None
# Run cleanup with test session
await cleanup_old_public_data(session)
await cleanup_old_public_data(db_session)
# Verify only old anonymous transcript was deleted
assert mock_delete.call_count == 1
@@ -92,27 +92,27 @@ async def test_cleanup_old_public_data_deletes_old_anonymous_transcripts(session
@pytest.mark.asyncio
async def test_cleanup_deletes_associated_meeting_and_recording(session):
async def test_cleanup_deletes_associated_meeting_and_recording(db_session):
"""Test that cleanup deletes associated meetings and recordings."""
old_date = datetime.now(timezone.utc) - timedelta(days=8)
# Create an old transcript with both meeting and recording
old_transcript = await transcripts_controller.add(
session,
db_session,
name="Old Transcript with Meeting and Recording",
source_kind=SourceKind.FILE,
user_id=None,
)
await session.execute(
await db_session.execute(
update(TranscriptModel)
.where(TranscriptModel.id == old_transcript.id)
.values(created_at=old_date)
)
await session.commit()
await db_session.commit()
# Create associated meeting directly
meeting_id = "test-meeting-id"
await session.execute(
await db_session.execute(
insert(MeetingModel).values(
id=meeting_id,
room_id=None,
@@ -132,7 +132,7 @@ async def test_cleanup_deletes_associated_meeting_and_recording(session):
# Create associated recording directly
recording_id = "test-recording-id"
await session.execute(
await db_session.execute(
insert(RecordingModel).values(
id=recording_id,
meeting_id=meeting_id,
@@ -142,15 +142,15 @@ async def test_cleanup_deletes_associated_meeting_and_recording(session):
created_at=old_date,
)
)
await session.commit()
await db_session.commit()
# Update transcript with meeting_id and recording_id
await session.execute(
await db_session.execute(
update(TranscriptModel)
.where(TranscriptModel.id == old_transcript.id)
.values(meeting_id=meeting_id, recording_id=recording_id)
)
await session.commit()
await db_session.commit()
# Mock settings
with patch("reflector.worker.cleanup.settings") as mock_settings:
@@ -162,24 +162,24 @@ async def test_cleanup_deletes_associated_meeting_and_recording(session):
mock_storage.return_value.delete_file = AsyncMock()
# Run cleanup with test session
await cleanup_old_public_data(session)
await cleanup_old_public_data(db_session)
# Verify transcript was deleted
result = await session.execute(
result = await db_session.execute(
select(TranscriptModel).where(TranscriptModel.id == old_transcript.id)
)
transcript = result.scalar_one_or_none()
assert transcript is None
# Verify meeting was deleted
result = await session.execute(
result = await db_session.execute(
select(MeetingModel).where(MeetingModel.id == meeting_id)
)
meeting = result.scalar_one_or_none()
assert meeting is None
# Verify recording was deleted
result = await session.execute(
result = await db_session.execute(
select(RecordingModel).where(RecordingModel.id == recording_id)
)
recording = result.scalar_one_or_none()
@@ -187,35 +187,35 @@ async def test_cleanup_deletes_associated_meeting_and_recording(session):
@pytest.mark.asyncio
async def test_cleanup_handles_errors_gracefully(session):
async def test_cleanup_handles_errors_gracefully(db_session):
"""Test that cleanup continues even if individual deletions fail."""
old_date = datetime.now(timezone.utc) - timedelta(days=8)
# Create multiple old transcripts
transcript1 = await transcripts_controller.add(
session,
db_session,
name="Transcript 1",
source_kind=SourceKind.FILE,
user_id=None,
)
await session.execute(
await db_session.execute(
update(TranscriptModel)
.where(TranscriptModel.id == transcript1.id)
.values(created_at=old_date)
)
transcript2 = await transcripts_controller.add(
session,
db_session,
name="Transcript 2",
source_kind=SourceKind.FILE,
user_id=None,
)
await session.execute(
await db_session.execute(
update(TranscriptModel)
.where(TranscriptModel.id == transcript2.id)
.values(created_at=old_date)
)
await session.commit()
await db_session.commit()
with patch("reflector.worker.cleanup.settings") as mock_settings:
mock_settings.PUBLIC_MODE = True
@@ -226,34 +226,34 @@ async def test_cleanup_handles_errors_gracefully(session):
mock_delete.side_effect = [Exception("Delete failed"), None]
# Run cleanup with test session - should not raise exception
await cleanup_old_public_data(session)
await cleanup_old_public_data(db_session)
# Both transcripts should have been attempted to delete
assert mock_delete.call_count == 2
@pytest.mark.asyncio
async def test_meeting_consent_cascade_delete(session):
async def test_meeting_consent_cascade_delete(db_session):
"""Test that meeting_consent entries are cascade deleted with meetings."""
old_date = datetime.now(timezone.utc) - timedelta(days=8)
# Create an old transcript
transcript = await transcripts_controller.add(
session,
db_session,
name="Transcript with Meeting",
source_kind=SourceKind.FILE,
user_id=None,
)
await session.execute(
await db_session.execute(
update(TranscriptModel)
.where(TranscriptModel.id == transcript.id)
.values(created_at=old_date)
)
await session.commit()
await db_session.commit()
# Create a meeting directly
meeting_id = "test-meeting-consent"
await session.execute(
await db_session.execute(
insert(MeetingModel).values(
id=meeting_id,
room_id=None,
@@ -270,18 +270,18 @@ async def test_meeting_consent_cascade_delete(session):
recording_trigger="automatic",
)
)
await session.commit()
await db_session.commit()
# Update transcript with meeting_id
await session.execute(
await db_session.execute(
update(TranscriptModel)
.where(TranscriptModel.id == transcript.id)
.values(meeting_id=meeting_id)
)
await session.commit()
await db_session.commit()
# Create meeting_consent entries
await session.execute(
await db_session.execute(
insert(MeetingConsentModel).values(
id="consent-1",
meeting_id=meeting_id,
@@ -290,7 +290,7 @@ async def test_meeting_consent_cascade_delete(session):
consent_timestamp=old_date,
)
)
await session.execute(
await db_session.execute(
insert(MeetingConsentModel).values(
id="consent-2",
meeting_id=meeting_id,
@@ -299,26 +299,26 @@ async def test_meeting_consent_cascade_delete(session):
consent_timestamp=old_date,
)
)
await session.commit()
await db_session.commit()
# Verify consent entries exist
result = await session.execute(
result = await db_session.execute(
select(MeetingConsentModel).where(MeetingConsentModel.meeting_id == meeting_id)
)
consents = result.scalars().all()
assert len(consents) == 2
# Delete the transcript and meeting
await session.execute(
await db_session.execute(
TranscriptModel.__table__.delete().where(TranscriptModel.id == transcript.id)
)
await session.execute(
await db_session.execute(
MeetingModel.__table__.delete().where(MeetingModel.id == meeting_id)
)
await session.commit()
await db_session.commit()
# Verify consent entries were cascade deleted
result = await session.execute(
result = await db_session.execute(
select(MeetingConsentModel).where(MeetingConsentModel.meeting_id == meeting_id)
)
consents = result.scalars().all()

View File

@@ -14,7 +14,7 @@ from reflector.worker.ics_sync import (
@pytest.mark.asyncio
async def test_sync_room_ics_task(session):
async def test_sync_room_ics_task(db_db_session):
room = await rooms_controller.add(
session,
name="task-test-room",
@@ -30,7 +30,7 @@ async def test_sync_room_ics_task(session):
ics_url="https://calendar.example.com/task.ics",
ics_enabled=True,
)
await session.flush()
await db_session.flush()
cal = Calendar()
event = Event()
@@ -49,7 +49,7 @@ async def test_sync_room_ics_task(session):
@asynccontextmanager
async def mock_session_context():
yield session
yield db_session
class MockSessionMaker:
def __call__(self):
@@ -74,7 +74,7 @@ async def test_sync_room_ics_task(session):
@pytest.mark.asyncio
async def test_sync_room_ics_disabled(session):
async def test_sync_room_ics_disabled(db_db_session):
room = await rooms_controller.add(
session,
name="disabled-room",
@@ -97,7 +97,7 @@ async def test_sync_room_ics_disabled(session):
@pytest.mark.asyncio
async def test_sync_all_ics_calendars(session):
async def test_sync_all_ics_calendars(db_db_session):
room1 = await rooms_controller.add(
session,
name="sync-all-1",
@@ -176,7 +176,7 @@ async def test_should_sync_logic():
@pytest.mark.asyncio
async def test_sync_respects_fetch_interval(session):
async def test_sync_respects_fetch_interval(db_db_session):
now = datetime.now(timezone.utc)
room1 = await rooms_controller.add(
@@ -237,7 +237,7 @@ async def test_sync_respects_fetch_interval(session):
@pytest.mark.asyncio
async def test_sync_handles_errors_gracefully(session):
async def test_sync_handles_errors_gracefully(db_db_session):
room = await rooms_controller.add(
session,
name="error-task-room",

View File

@@ -134,7 +134,7 @@ async def test_ics_fetch_service_extract_room_events():
@pytest.mark.asyncio
async def test_ics_sync_service_sync_room_calendar(session):
async def test_ics_sync_service_sync_room_calendar(db_db_session):
# Create room
room = await rooms_controller.add(
session,
@@ -151,7 +151,7 @@ async def test_ics_sync_service_sync_room_calendar(session):
ics_url="https://calendar.example.com/test.ics",
ics_enabled=True,
)
await session.flush()
await db_session.flush()
# Mock ICS content
cal = Calendar()
@@ -172,7 +172,7 @@ async def test_ics_sync_service_sync_room_calendar(session):
@asynccontextmanager
async def mock_session_context():
yield session
yield db_session
class MockSessionMaker:
def __call__(self):
@@ -280,7 +280,7 @@ async def test_ics_sync_service_skip_disabled():
@pytest.mark.asyncio
async def test_ics_sync_service_error_handling(session):
async def test_ics_sync_service_error_handling(db_db_session):
# Create room
room = await rooms_controller.add(
session,
@@ -297,13 +297,13 @@ async def test_ics_sync_service_error_handling(session):
ics_url="https://calendar.example.com/error.ics",
ics_enabled=True,
)
await session.flush()
await db_session.flush()
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session_context():
yield session
yield db_session
class MockSessionMaker:
def __call__(self):

View File

@@ -10,7 +10,7 @@ from reflector.db.rooms import rooms_controller
@pytest.mark.asyncio
async def test_multiple_active_meetings_per_room(session):
async def test_multiple_active_meetings_per_room(db_db_session):
"""Test that multiple active meetings can exist for the same room."""
# Create a room
room = await rooms_controller.add(
@@ -65,7 +65,7 @@ async def test_multiple_active_meetings_per_room(session):
@pytest.mark.asyncio
async def test_get_active_by_calendar_event(session):
async def test_get_active_by_calendar_event(db_db_session):
"""Test getting active meeting by calendar event ID."""
# Create a room
room = await rooms_controller.add(
@@ -120,7 +120,7 @@ async def test_get_active_by_calendar_event(session):
@pytest.mark.asyncio
async def test_calendar_meeting_deactivates_after_scheduled_end(session):
async def test_calendar_meeting_deactivates_after_scheduled_end(db_db_session):
"""Test that unused calendar meetings deactivate after scheduled end time."""
# Create a room
room = await rooms_controller.add(

View File

@@ -10,7 +10,7 @@ from reflector.db.rooms import rooms_controller
@pytest.mark.asyncio
async def test_room_create_with_ics_fields(session):
async def test_room_create_with_ics_fields(db_db_session):
"""Test creating a room with ICS calendar fields."""
room = await rooms_controller.add(
session,
@@ -41,7 +41,7 @@ async def test_room_create_with_ics_fields(session):
@pytest.mark.asyncio
async def test_room_update_ics_configuration(session):
async def test_room_update_ics_configuration(db_db_session):
"""Test updating room ICS configuration."""
# Create room without ICS
room = await rooms_controller.add(
@@ -80,7 +80,7 @@ async def test_room_update_ics_configuration(session):
@pytest.mark.asyncio
async def test_room_ics_sync_metadata(session):
async def test_room_ics_sync_metadata(db_db_session):
"""Test updating room ICS sync metadata."""
room = await rooms_controller.add(
session,
@@ -114,7 +114,7 @@ async def test_room_ics_sync_metadata(session):
@pytest.mark.asyncio
async def test_room_get_with_ics_fields(session):
async def test_room_get_with_ics_fields(db_db_session):
"""Test retrieving room with ICS fields."""
# Create room
created_room = await rooms_controller.add(
@@ -150,7 +150,7 @@ async def test_room_get_with_ics_fields(session):
@pytest.mark.asyncio
async def test_room_list_with_ics_enabled_filter(session):
async def test_room_list_with_ics_enabled_filter(db_db_session):
"""Test listing rooms filtered by ICS enabled status."""
# Create rooms with and without ICS
room1 = await rooms_controller.add(
@@ -211,7 +211,7 @@ async def test_room_list_with_ics_enabled_filter(session):
@pytest.mark.asyncio
async def test_room_default_ics_values(session):
async def test_room_default_ics_values(db_db_session):
"""Test that ICS fields have correct default values."""
room = await rooms_controller.add(
session,

View File

@@ -17,25 +17,25 @@ from reflector.db.transcripts import SourceKind
@pytest.mark.asyncio
async def test_search_postgresql_only(session):
async def test_search_postgresql_only(db_session):
params = SearchParameters(query_text="any query here")
results, total = await search_controller.search_transcripts(session, params)
results, total = await search_controller.search_transcripts(db_session, params)
assert results == []
assert total == 0
params_empty = SearchParameters(query_text=None)
results_empty, total_empty = await search_controller.search_transcripts(
session, params_empty
db_session, params_empty
)
assert isinstance(results_empty, list)
assert isinstance(total_empty, int)
@pytest.mark.asyncio
async def test_search_with_empty_query(session):
async def test_search_with_empty_query(db_session):
"""Test that empty query returns all transcripts."""
params = SearchParameters(query_text=None)
results, total = await search_controller.search_transcripts(session, params)
results, total = await search_controller.search_transcripts(db_session, params)
assert isinstance(results, list)
assert isinstance(total, int)
@@ -45,12 +45,12 @@ async def test_search_with_empty_query(session):
@pytest.mark.asyncio
async def test_empty_transcript_title_only_match(session):
async def test_empty_transcript_title_only_match(db_session):
"""Test that transcripts with title-only matches return empty snippets."""
test_id = "test-empty-9b3f2a8d"
try:
await session.execute(
await db_session.execute(
delete(TranscriptModel).where(TranscriptModel.id == test_id)
)
@@ -77,11 +77,11 @@ async def test_empty_transcript_title_only_match(session):
"user_id": "test-user-1",
}
await session.execute(insert(TranscriptModel).values(**test_data))
await session.commit()
await db_session.execute(insert(TranscriptModel).values(**test_data))
await db_session.commit()
params = SearchParameters(query_text="empty", user_id="test-user-1")
results, total = await search_controller.search_transcripts(session, params)
results, total = await search_controller.search_transcripts(db_session, params)
assert total >= 1
found = next((r for r in results if r.id == test_id), None)
@@ -90,19 +90,19 @@ async def test_empty_transcript_title_only_match(session):
assert found.total_match_count == 0
finally:
await session.execute(
await db_session.execute(
delete(TranscriptModel).where(TranscriptModel.id == test_id)
)
await session.commit()
await db_session.commit()
@pytest.mark.asyncio
async def test_search_with_long_summary(session):
async def test_search_with_long_summary(db_session):
"""Test that long_summary content is searchable."""
test_id = "test-long-summary-8a9f3c2d"
try:
await session.execute(
await db_session.execute(
delete(TranscriptModel).where(TranscriptModel.id == test_id)
)
@@ -132,11 +132,11 @@ Basic meeting content without special keywords.""",
"user_id": "test-user-2",
}
await session.execute(insert(TranscriptModel).values(**test_data))
await session.commit()
await db_session.execute(insert(TranscriptModel).values(**test_data))
await db_session.commit()
params = SearchParameters(query_text="quantum computing", user_id="test-user-2")
results, total = await search_controller.search_transcripts(session, params)
results, total = await search_controller.search_transcripts(db_session, params)
assert total >= 1
found = any(r.id == test_id for r in results)
@@ -148,18 +148,18 @@ Basic meeting content without special keywords.""",
assert "quantum computing" in test_result.search_snippets[0].lower()
finally:
await session.execute(
await db_session.execute(
delete(TranscriptModel).where(TranscriptModel.id == test_id)
)
await session.commit()
await db_session.commit()
@pytest.mark.asyncio
async def test_postgresql_search_with_data(session):
async def test_postgresql_search_with_data(db_session):
test_id = "test-search-e2e-7f3a9b2c"
try:
await session.execute(
await db_session.execute(
delete(TranscriptModel).where(TranscriptModel.id == test_id)
)
@@ -198,17 +198,17 @@ We need to implement PostgreSQL tsvector for better performance.""",
"user_id": "test-user-3",
}
await session.execute(insert(TranscriptModel).values(**test_data))
await session.commit()
await db_session.execute(insert(TranscriptModel).values(**test_data))
await db_session.commit()
params = SearchParameters(query_text="planning", user_id="test-user-3")
results, total = await search_controller.search_transcripts(session, params)
results, total = await search_controller.search_transcripts(db_session, params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by title word"
params = SearchParameters(query_text="tsvector", user_id="test-user-3")
results, total = await search_controller.search_transcripts(session, params)
results, total = await search_controller.search_transcripts(db_session, params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by webvtt content"
@@ -216,7 +216,7 @@ We need to implement PostgreSQL tsvector for better performance.""",
params = SearchParameters(
query_text="engineering planning", user_id="test-user-3"
)
results, total = await search_controller.search_transcripts(session, params)
results, total = await search_controller.search_transcripts(db_session, params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by multiple words"
@@ -231,7 +231,7 @@ We need to implement PostgreSQL tsvector for better performance.""",
params = SearchParameters(
query_text="tsvector OR nosuchword", user_id="test-user-3"
)
results, total = await search_controller.search_transcripts(session, params)
results, total = await search_controller.search_transcripts(db_session, params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript with OR query"
@@ -239,16 +239,16 @@ We need to implement PostgreSQL tsvector for better performance.""",
params = SearchParameters(
query_text='"full-text search"', user_id="test-user-3"
)
results, total = await search_controller.search_transcripts(session, params)
results, total = await search_controller.search_transcripts(db_session, params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by exact phrase"
finally:
await session.execute(
await db_session.execute(
delete(TranscriptModel).where(TranscriptModel.id == test_id)
)
await session.commit()
await db_session.commit()
@pytest.fixture
@@ -314,20 +314,20 @@ class TestSearchControllerFilters:
"""Test SearchController functionality with various filters."""
@pytest.mark.asyncio
async def test_search_with_source_kind_filter(self, session):
async def test_search_with_source_kind_filter(self, db_session):
"""Test search filtering by source_kind."""
controller = SearchController()
params = SearchParameters(query_text="test", source_kind=SourceKind.LIVE)
# This should not fail, even if no results are found
results, total = await controller.search_transcripts(session, params)
results, total = await controller.search_transcripts(db_session, params)
assert isinstance(results, list)
assert isinstance(total, int)
assert total >= 0
@pytest.mark.asyncio
async def test_search_with_single_room_id(self, session):
async def test_search_with_single_room_id(self, db_session):
"""Test search filtering by single room ID (currently supported)."""
controller = SearchController()
params = SearchParameters(
@@ -336,7 +336,7 @@ class TestSearchControllerFilters:
)
# This should not fail, even if no results are found
results, total = await controller.search_transcripts(session, params)
results, total = await controller.search_transcripts(db_session, params)
assert isinstance(results, list)
assert isinstance(total, int)
@@ -344,14 +344,14 @@ class TestSearchControllerFilters:
@pytest.mark.asyncio
async def test_search_result_includes_available_fields(
self, session, mock_db_result
self, db_session, mock_db_result
):
"""Test that search results include available fields like source_kind."""
# Test that the search method works and returns SearchResult objects
controller = SearchController()
params = SearchParameters(query_text="test")
results, total = await controller.search_transcripts(session, params)
results, total = await controller.search_transcripts(db_session, params)
assert isinstance(results, list)
assert isinstance(total, int)

View File

@@ -11,13 +11,13 @@ from reflector.db.search import SearchParameters, search_controller
@pytest.mark.asyncio
async def test_long_summary_snippet_prioritization(session):
async def test_long_summary_snippet_prioritization(db_db_session):
"""Test that snippets from long_summary are prioritized over webvtt content."""
test_id = "test-snippet-priority-3f9a2b8c"
try:
# Clean up any existing test data
await session.execute(
await db_session.execute(
delete(TranscriptModel).where(TranscriptModel.id == test_id)
)
@@ -57,7 +57,7 @@ We need to consider various implementation approaches.""",
"user_id": "test-user-priority",
}
await session.execute(insert(TranscriptModel).values(**test_data))
await db_session.execute(insert(TranscriptModel).values(**test_data))
# Search for "robotics" which appears in both long_summary and webvtt
params = SearchParameters(query_text="robotics", user_id="test-user-priority")
@@ -86,19 +86,19 @@ We need to consider various implementation approaches.""",
), f"Snippet should contain search term: {snippet}"
finally:
await session.execute(
await db_session.execute(
delete(TranscriptModel).where(TranscriptModel.id == test_id)
)
await session.commit()
await db_session.commit()
@pytest.mark.asyncio
async def test_long_summary_only_search(session):
async def test_long_summary_only_search(db_db_session):
"""Test searching for content that only exists in long_summary."""
test_id = "test-long-only-8b3c9f2a"
try:
await session.execute(
await db_session.execute(
delete(TranscriptModel).where(TranscriptModel.id == test_id)
)
@@ -135,7 +135,7 @@ Discussion of timeline and deliverables.""",
"user_id": "test-user-long",
}
await session.execute(insert(TranscriptModel).values(**test_data))
await db_session.execute(insert(TranscriptModel).values(**test_data))
# Search for terms only in long_summary
params = SearchParameters(query_text="cryptocurrency", user_id="test-user-long")
@@ -160,7 +160,7 @@ Discussion of timeline and deliverables.""",
assert found2, "Should find transcript by specific long_summary phrase"
finally:
await session.execute(
await db_session.execute(
delete(TranscriptModel).where(TranscriptModel.id == test_id)
)
await session.commit()
await db_session.commit()

View File

@@ -10,11 +10,11 @@ from reflector.db.transcripts import SourceKind, transcripts_controller
@pytest.mark.asyncio
async def test_recording_deleted_with_transcript(session):
async def test_recording_deleted_with_transcript(db_db_session):
"""Test that a recording is deleted when its associated transcript is deleted."""
# First create a room and meeting to satisfy foreign key constraints
room_id = "test-room"
await session.execute(
await db_session.execute(
insert(RoomModel).values(
id=room_id,
name="test-room",
@@ -32,7 +32,7 @@ async def test_recording_deleted_with_transcript(session):
)
meeting_id = "test-meeting"
await session.execute(
await db_session.execute(
insert(MeetingModel).values(
id=meeting_id,
room_id=room_id,
@@ -49,7 +49,7 @@ async def test_recording_deleted_with_transcript(session):
recording_trigger="automatic",
)
)
await session.commit()
await db_session.commit()
# Now create a recording
recording = await recordings_controller.create(

View File

@@ -30,7 +30,7 @@ class TestWebVTTAutoUpdate:
)
try:
result = await session.execute(
result = await db_session.execute(
select(TranscriptModel).where(TranscriptModel.id == transcript.id)
)
row = result.scalar_one_or_none()
@@ -40,7 +40,7 @@ class TestWebVTTAutoUpdate:
finally:
await transcripts_controller.remove_by_id(session, transcript.id)
async def test_webvtt_updated_on_upsert_topic(self, session):
async def test_webvtt_updated_on_upsert_topic(self, db_db_session):
"""WebVTT should update when upserting topics via upsert_topic method."""
# Using global transcripts_controller
@@ -64,7 +64,7 @@ class TestWebVTTAutoUpdate:
await transcripts_controller.upsert_topic(session, transcript, topic)
result = await session.execute(
result = await db_session.execute(
select(TranscriptModel).where(TranscriptModel.id == transcript.id)
)
row = result.scalar_one_or_none()
@@ -80,7 +80,7 @@ class TestWebVTTAutoUpdate:
finally:
await transcripts_controller.remove_by_id(session, transcript.id)
async def test_webvtt_updated_on_direct_topics_update(self, session):
async def test_webvtt_updated_on_direct_topics_update(self, db_db_session):
"""WebVTT should update when updating topics field directly."""
# Using global transcripts_controller
@@ -109,7 +109,7 @@ class TestWebVTTAutoUpdate:
)
# Fetch from DB
result = await session.execute(
result = await db_session.execute(
select(TranscriptModel).where(TranscriptModel.id == transcript.id)
)
row = result.scalar_one_or_none()
@@ -124,7 +124,9 @@ class TestWebVTTAutoUpdate:
finally:
await transcripts_controller.remove_by_id(session, transcript.id)
async def test_webvtt_updated_manually_with_handle_topics_update(self, session):
async def test_webvtt_updated_manually_with_handle_topics_update(
self, db_db_session
):
"""Test that _handle_topics_update works when called manually."""
# Using global transcripts_controller
@@ -153,7 +155,7 @@ class TestWebVTTAutoUpdate:
await transcripts_controller.update(session, transcript, values)
# Fetch from DB
result = await session.execute(
result = await db_session.execute(
select(TranscriptModel).where(TranscriptModel.id == transcript.id)
)
row = result.scalar_one_or_none()
@@ -169,7 +171,7 @@ class TestWebVTTAutoUpdate:
finally:
await transcripts_controller.remove_by_id(session, transcript.id)
async def test_webvtt_update_with_non_sequential_topics_fails(self, session):
async def test_webvtt_update_with_non_sequential_topics_fails(self, db_db_session):
"""Test that non-sequential topics raise assertion error."""
# Using global transcripts_controller
@@ -202,7 +204,7 @@ class TestWebVTTAutoUpdate:
finally:
await transcripts_controller.remove_by_id(session, transcript.id)
async def test_multiple_speakers_in_webvtt(self, session):
async def test_multiple_speakers_in_webvtt(self, db_db_session):
"""Test WebVTT generation with multiple speakers."""
# Using global transcripts_controller
@@ -231,7 +233,7 @@ class TestWebVTTAutoUpdate:
await transcripts_controller.update(session, transcript, values)
# Fetch from DB
result = await session.execute(
result = await db_session.execute(
select(TranscriptModel).where(TranscriptModel.id == transcript.id)
)
row = result.scalar_one_or_none()

3170
server/uv.lock generated

File diff suppressed because it is too large Load Diff