test: update test suite for SQLAlchemy 2.0 migration

- Add session fixture for async session management
- Update all test files to use session parameter
- Convert Core-style queries to ORM-style in tests
- Fix controller calls to include session parameter
- Remove obsolete get_database() references

Test progress: 108/195 tests passing
This commit is contained in:
2025-09-18 12:35:51 -06:00
parent 06639d4d8f
commit 45d1608950
8 changed files with 163 additions and 131 deletions

View File

@@ -5,37 +5,38 @@ from datetime import datetime, timezone
from unittest.mock import AsyncMock, patch
import pytest
from sqlalchemy import delete, insert
from reflector.db import get_database
from reflector.db.base import TranscriptModel
from reflector.db.search import (
SearchController,
SearchParameters,
SearchResult,
search_controller,
)
from reflector.db.transcripts import SourceKind, transcripts
from reflector.db.transcripts import SourceKind
@pytest.mark.asyncio
async def test_search_postgresql_only():
async def test_search_postgresql_only(session):
params = SearchParameters(query_text="any query here")
results, total = await search_controller.search_transcripts(params)
results, total = await search_controller.search_transcripts(session, params)
assert results == []
assert total == 0
params_empty = SearchParameters(query_text=None)
results_empty, total_empty = await search_controller.search_transcripts(
params_empty
session, params_empty
)
assert isinstance(results_empty, list)
assert isinstance(total_empty, int)
@pytest.mark.asyncio
async def test_search_with_empty_query():
async def test_search_with_empty_query(session):
"""Test that empty query returns all transcripts."""
params = SearchParameters(query_text=None)
results, total = await search_controller.search_transcripts(params)
results, total = await search_controller.search_transcripts(session, params)
assert isinstance(results, list)
assert isinstance(total, int)
@@ -45,13 +46,13 @@ async def test_search_with_empty_query():
@pytest.mark.asyncio
async def test_empty_transcript_title_only_match():
async def test_empty_transcript_title_only_match(session):
"""Test that transcripts with title-only matches return empty snippets."""
test_id = "test-empty-9b3f2a8d"
try:
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
await session.execute(
delete(TranscriptModel).where(TranscriptModel.id == test_id)
)
test_data = {
@@ -77,10 +78,11 @@ async def test_empty_transcript_title_only_match():
"user_id": "test-user-1",
}
await get_database().execute(transcripts.insert().values(**test_data))
await session.execute(insert(TranscriptModel).values(**test_data))
await session.commit()
params = SearchParameters(query_text="empty", user_id="test-user-1")
results, total = await search_controller.search_transcripts(params)
results, total = await search_controller.search_transcripts(session, params)
assert total >= 1
found = next((r for r in results if r.id == test_id), None)
@@ -89,20 +91,20 @@ async def test_empty_transcript_title_only_match():
assert found.total_match_count == 0
finally:
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
await session.execute(
delete(TranscriptModel).where(TranscriptModel.id == test_id)
)
await get_database().disconnect()
await session.commit()
@pytest.mark.asyncio
async def test_search_with_long_summary():
async def test_search_with_long_summary(session):
"""Test that long_summary content is searchable."""
test_id = "test-long-summary-8a9f3c2d"
try:
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
await session.execute(
delete(TranscriptModel).where(TranscriptModel.id == test_id)
)
test_data = {
@@ -131,10 +133,11 @@ Basic meeting content without special keywords.""",
"user_id": "test-user-2",
}
await get_database().execute(transcripts.insert().values(**test_data))
await session.execute(insert(TranscriptModel).values(**test_data))
await session.commit()
params = SearchParameters(query_text="quantum computing", user_id="test-user-2")
results, total = await search_controller.search_transcripts(params)
results, total = await search_controller.search_transcripts(session, params)
assert total >= 1
found = any(r.id == test_id for r in results)
@@ -146,19 +149,19 @@ Basic meeting content without special keywords.""",
assert "quantum computing" in test_result.search_snippets[0].lower()
finally:
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
await session.execute(
delete(TranscriptModel).where(TranscriptModel.id == test_id)
)
await get_database().disconnect()
await session.commit()
@pytest.mark.asyncio
async def test_postgresql_search_with_data():
async def test_postgresql_search_with_data(session):
test_id = "test-search-e2e-7f3a9b2c"
try:
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
await session.execute(
delete(TranscriptModel).where(TranscriptModel.id == test_id)
)
test_data = {
@@ -196,16 +199,17 @@ We need to implement PostgreSQL tsvector for better performance.""",
"user_id": "test-user-3",
}
await get_database().execute(transcripts.insert().values(**test_data))
await session.execute(insert(TranscriptModel).values(**test_data))
await session.commit()
params = SearchParameters(query_text="planning", user_id="test-user-3")
results, total = await search_controller.search_transcripts(params)
results, total = await search_controller.search_transcripts(session, params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by title word"
params = SearchParameters(query_text="tsvector", user_id="test-user-3")
results, total = await search_controller.search_transcripts(params)
results, total = await search_controller.search_transcripts(session, params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by webvtt content"
@@ -213,7 +217,7 @@ We need to implement PostgreSQL tsvector for better performance.""",
params = SearchParameters(
query_text="engineering planning", user_id="test-user-3"
)
results, total = await search_controller.search_transcripts(params)
results, total = await search_controller.search_transcripts(session, params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by multiple words"
@@ -228,7 +232,7 @@ We need to implement PostgreSQL tsvector for better performance.""",
params = SearchParameters(
query_text="tsvector OR nosuchword", user_id="test-user-3"
)
results, total = await search_controller.search_transcripts(params)
results, total = await search_controller.search_transcripts(session, params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript with OR query"
@@ -236,16 +240,16 @@ We need to implement PostgreSQL tsvector for better performance.""",
params = SearchParameters(
query_text='"full-text search"', user_id="test-user-3"
)
results, total = await search_controller.search_transcripts(params)
results, total = await search_controller.search_transcripts(session, params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by exact phrase"
finally:
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
await session.execute(
delete(TranscriptModel).where(TranscriptModel.id == test_id)
)
await get_database().disconnect()
await session.commit()
@pytest.fixture
@@ -316,7 +320,7 @@ class TestSearchControllerFilters:
controller = SearchController()
with (
patch("reflector.db.search.is_postgresql", return_value=True),
patch("reflector.db.search.get_database") as mock_db,
patch("reflector.db.search.get_session_factory") as mock_session_factory,
):
mock_db.return_value.fetch_all = AsyncMock(return_value=[])
mock_db.return_value.fetch_val = AsyncMock(return_value=0)
@@ -336,7 +340,7 @@ class TestSearchControllerFilters:
controller = SearchController()
with (
patch("reflector.db.search.is_postgresql", return_value=True),
patch("reflector.db.search.get_database") as mock_db,
patch("reflector.db.search.get_session_factory") as mock_session_factory,
):
mock_db.return_value.fetch_all = AsyncMock(return_value=[])
mock_db.return_value.fetch_val = AsyncMock(return_value=0)
@@ -358,7 +362,7 @@ class TestSearchControllerFilters:
controller = SearchController()
with (
patch("reflector.db.search.is_postgresql", return_value=True),
patch("reflector.db.search.get_database") as mock_db,
patch("reflector.db.search.get_session_factory") as mock_session_factory,
):
class MockRow: