test: update test suite for SQLAlchemy 2.0 migration

- Add session fixture for async session management
- Update all test files to use session parameter
- Convert Core-style queries to ORM-style in tests
- Fix controller calls to include session parameter
- Remove obsolete get_database() references

Test progress: 108/195 tests passing
This commit is contained in:
2025-09-18 12:35:51 -06:00
parent 06639d4d8f
commit 45d1608950
8 changed files with 163 additions and 131 deletions

View File

@@ -84,6 +84,14 @@ async def setup_database(postgres_service):
await async_engine.dispose() await async_engine.dispose()
@pytest.fixture
async def session():
from reflector.db import get_session_factory
async with get_session_factory()() as session:
yield session
@pytest.fixture @pytest.fixture
def dummy_processors(): def dummy_processors():
with ( with (

View File

@@ -18,6 +18,7 @@ async def test_attendee_parsing_bug():
""" """
# Create a test room # Create a test room
room = await rooms_controller.add( room = await rooms_controller.add(
session,
name="test-room", name="test-room",
user_id="test-user", user_id="test-user",
zulip_auto_post=False, zulip_auto_post=False,

View File

@@ -34,7 +34,7 @@ async def test_cleanup_old_public_data_deletes_old_anonymous_transcripts():
user_id=None, # Anonymous user_id=None, # Anonymous
) )
# Manually update created_at to be old # Manually update created_at to be old
from reflector.db import get_database # Removed get_database import
from reflector.db.transcripts import transcripts from reflector.db.transcripts import transcripts
await get_database().execute( await get_database().execute(
@@ -89,7 +89,7 @@ async def test_cleanup_old_public_data_deletes_old_anonymous_transcripts():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cleanup_deletes_associated_meeting_and_recording(): async def test_cleanup_deletes_associated_meeting_and_recording():
"""Test that meetings and recordings associated with old transcripts are deleted.""" """Test that meetings and recordings associated with old transcripts are deleted."""
from reflector.db import get_database # Removed get_database import
from reflector.db.meetings import meetings from reflector.db.meetings import meetings
from reflector.db.transcripts import transcripts from reflector.db.transcripts import transcripts
@@ -184,7 +184,7 @@ async def test_cleanup_handles_errors_gracefully():
) )
# Update created_at to be old # Update created_at to be old
from reflector.db import get_database # Removed get_database import
from reflector.db.transcripts import transcripts from reflector.db.transcripts import transcripts
for t_id in [transcript1.id, transcript2.id]: for t_id in [transcript1.id, transcript2.id]:
@@ -223,7 +223,7 @@ async def test_cleanup_handles_errors_gracefully():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_meeting_consent_cascade_delete(): async def test_meeting_consent_cascade_delete():
"""Test that meeting_consent records are automatically deleted when meeting is deleted.""" """Test that meeting_consent records are automatically deleted when meeting is deleted."""
from reflector.db import get_database # Removed get_database import
from reflector.db.meetings import ( from reflector.db.meetings import (
meeting_consent, meeting_consent,
meeting_consent_controller, meeting_consent_controller,

View File

@@ -4,9 +4,8 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from icalendar import Calendar, Event from icalendar import Calendar, Event
from reflector.db import get_database
from reflector.db.calendar_events import calendar_events_controller from reflector.db.calendar_events import calendar_events_controller
from reflector.db.rooms import rooms, rooms_controller from reflector.db.rooms import rooms_controller
from reflector.services.ics_sync import ics_sync_service from reflector.services.ics_sync import ics_sync_service
from reflector.worker.ics_sync import ( from reflector.worker.ics_sync import (
_should_sync, _should_sync,
@@ -15,8 +14,9 @@ from reflector.worker.ics_sync import (
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_sync_room_ics_task(): async def test_sync_room_ics_task(session):
room = await rooms_controller.add( room = await rooms_controller.add(
session,
name="task-test-room", name="task-test-room",
user_id="test-user", user_id="test-user",
zulip_auto_post=False, zulip_auto_post=False,
@@ -52,14 +52,15 @@ async def test_sync_room_ics_task():
# Call the service directly instead of the Celery task to avoid event loop issues # Call the service directly instead of the Celery task to avoid event loop issues
await ics_sync_service.sync_room_calendar(room) await ics_sync_service.sync_room_calendar(room)
events = await calendar_events_controller.get_by_room(room.id) events = await calendar_events_controller.get_by_room(session, room.id)
assert len(events) == 1 assert len(events) == 1
assert events[0].ics_uid == "task-event-1" assert events[0].ics_uid == "task-event-1"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_sync_room_ics_disabled(): async def test_sync_room_ics_disabled(session):
room = await rooms_controller.add( room = await rooms_controller.add(
session,
name="disabled-room", name="disabled-room",
user_id="test-user", user_id="test-user",
zulip_auto_post=False, zulip_auto_post=False,
@@ -76,13 +77,14 @@ async def test_sync_room_ics_disabled():
# Test that disabled rooms are skipped by the service # Test that disabled rooms are skipped by the service
result = await ics_sync_service.sync_room_calendar(room) result = await ics_sync_service.sync_room_calendar(room)
events = await calendar_events_controller.get_by_room(room.id) events = await calendar_events_controller.get_by_room(session, room.id)
assert len(events) == 0 assert len(events) == 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_sync_all_ics_calendars(): async def test_sync_all_ics_calendars(session):
room1 = await rooms_controller.add( room1 = await rooms_controller.add(
session,
name="sync-all-1", name="sync-all-1",
user_id="test-user", user_id="test-user",
zulip_auto_post=False, zulip_auto_post=False,
@@ -98,6 +100,7 @@ async def test_sync_all_ics_calendars():
) )
room2 = await rooms_controller.add( room2 = await rooms_controller.add(
session,
name="sync-all-2", name="sync-all-2",
user_id="test-user", user_id="test-user",
zulip_auto_post=False, zulip_auto_post=False,
@@ -113,6 +116,7 @@ async def test_sync_all_ics_calendars():
) )
room3 = await rooms_controller.add( room3 = await rooms_controller.add(
session,
name="sync-all-3", name="sync-all-3",
user_id="test-user", user_id="test-user",
zulip_auto_post=False, zulip_auto_post=False,
@@ -163,10 +167,11 @@ async def test_should_sync_logic():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_sync_respects_fetch_interval(): async def test_sync_respects_fetch_interval(session):
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
room1 = await rooms_controller.add( room1 = await rooms_controller.add(
session,
name="interval-test-1", name="interval-test-1",
user_id="test-user", user_id="test-user",
zulip_auto_post=False, zulip_auto_post=False,
@@ -183,11 +188,13 @@ async def test_sync_respects_fetch_interval():
) )
await rooms_controller.update( await rooms_controller.update(
session,
room1, room1,
{"ics_last_sync": now - timedelta(seconds=100)}, {"ics_last_sync": now - timedelta(seconds=100)},
) )
room2 = await rooms_controller.add( room2 = await rooms_controller.add(
session,
name="interval-test-2", name="interval-test-2",
user_id="test-user", user_id="test-user",
zulip_auto_post=False, zulip_auto_post=False,
@@ -226,8 +233,9 @@ async def test_sync_respects_fetch_interval():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_sync_handles_errors_gracefully(): async def test_sync_handles_errors_gracefully(session):
room = await rooms_controller.add( room = await rooms_controller.add(
session,
name="error-task-room", name="error-task-room",
user_id="test-user", user_id="test-user",
zulip_auto_post=False, zulip_auto_post=False,
@@ -251,5 +259,5 @@ async def test_sync_handles_errors_gracefully():
result = await ics_sync_service.sync_room_calendar(room) result = await ics_sync_service.sync_room_calendar(room)
assert result["status"] == "error" assert result["status"] == "error"
events = await calendar_events_controller.get_by_room(room.id) events = await calendar_events_controller.get_by_room(session, room.id)
assert len(events) == 0 assert len(events) == 0

View File

@@ -5,37 +5,38 @@ from datetime import datetime, timezone
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
import pytest import pytest
from sqlalchemy import delete, insert
from reflector.db import get_database from reflector.db.base import TranscriptModel
from reflector.db.search import ( from reflector.db.search import (
SearchController, SearchController,
SearchParameters, SearchParameters,
SearchResult, SearchResult,
search_controller, search_controller,
) )
from reflector.db.transcripts import SourceKind, transcripts from reflector.db.transcripts import SourceKind
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_search_postgresql_only(): async def test_search_postgresql_only(session):
params = SearchParameters(query_text="any query here") params = SearchParameters(query_text="any query here")
results, total = await search_controller.search_transcripts(params) results, total = await search_controller.search_transcripts(session, params)
assert results == [] assert results == []
assert total == 0 assert total == 0
params_empty = SearchParameters(query_text=None) params_empty = SearchParameters(query_text=None)
results_empty, total_empty = await search_controller.search_transcripts( results_empty, total_empty = await search_controller.search_transcripts(
params_empty session, params_empty
) )
assert isinstance(results_empty, list) assert isinstance(results_empty, list)
assert isinstance(total_empty, int) assert isinstance(total_empty, int)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_search_with_empty_query(): async def test_search_with_empty_query(session):
"""Test that empty query returns all transcripts.""" """Test that empty query returns all transcripts."""
params = SearchParameters(query_text=None) params = SearchParameters(query_text=None)
results, total = await search_controller.search_transcripts(params) results, total = await search_controller.search_transcripts(session, params)
assert isinstance(results, list) assert isinstance(results, list)
assert isinstance(total, int) assert isinstance(total, int)
@@ -45,13 +46,13 @@ async def test_search_with_empty_query():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_empty_transcript_title_only_match(): async def test_empty_transcript_title_only_match(session):
"""Test that transcripts with title-only matches return empty snippets.""" """Test that transcripts with title-only matches return empty snippets."""
test_id = "test-empty-9b3f2a8d" test_id = "test-empty-9b3f2a8d"
try: try:
await get_database().execute( await session.execute(
transcripts.delete().where(transcripts.c.id == test_id) delete(TranscriptModel).where(TranscriptModel.id == test_id)
) )
test_data = { test_data = {
@@ -77,10 +78,11 @@ async def test_empty_transcript_title_only_match():
"user_id": "test-user-1", "user_id": "test-user-1",
} }
await get_database().execute(transcripts.insert().values(**test_data)) await session.execute(insert(TranscriptModel).values(**test_data))
await session.commit()
params = SearchParameters(query_text="empty", user_id="test-user-1") params = SearchParameters(query_text="empty", user_id="test-user-1")
results, total = await search_controller.search_transcripts(params) results, total = await search_controller.search_transcripts(session, params)
assert total >= 1 assert total >= 1
found = next((r for r in results if r.id == test_id), None) found = next((r for r in results if r.id == test_id), None)
@@ -89,20 +91,20 @@ async def test_empty_transcript_title_only_match():
assert found.total_match_count == 0 assert found.total_match_count == 0
finally: finally:
await get_database().execute( await session.execute(
transcripts.delete().where(transcripts.c.id == test_id) delete(TranscriptModel).where(TranscriptModel.id == test_id)
) )
await get_database().disconnect() await session.commit()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_search_with_long_summary(): async def test_search_with_long_summary(session):
"""Test that long_summary content is searchable.""" """Test that long_summary content is searchable."""
test_id = "test-long-summary-8a9f3c2d" test_id = "test-long-summary-8a9f3c2d"
try: try:
await get_database().execute( await session.execute(
transcripts.delete().where(transcripts.c.id == test_id) delete(TranscriptModel).where(TranscriptModel.id == test_id)
) )
test_data = { test_data = {
@@ -131,10 +133,11 @@ Basic meeting content without special keywords.""",
"user_id": "test-user-2", "user_id": "test-user-2",
} }
await get_database().execute(transcripts.insert().values(**test_data)) await session.execute(insert(TranscriptModel).values(**test_data))
await session.commit()
params = SearchParameters(query_text="quantum computing", user_id="test-user-2") params = SearchParameters(query_text="quantum computing", user_id="test-user-2")
results, total = await search_controller.search_transcripts(params) results, total = await search_controller.search_transcripts(session, params)
assert total >= 1 assert total >= 1
found = any(r.id == test_id for r in results) found = any(r.id == test_id for r in results)
@@ -146,19 +149,19 @@ Basic meeting content without special keywords.""",
assert "quantum computing" in test_result.search_snippets[0].lower() assert "quantum computing" in test_result.search_snippets[0].lower()
finally: finally:
await get_database().execute( await session.execute(
transcripts.delete().where(transcripts.c.id == test_id) delete(TranscriptModel).where(TranscriptModel.id == test_id)
) )
await get_database().disconnect() await session.commit()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_postgresql_search_with_data(): async def test_postgresql_search_with_data(session):
test_id = "test-search-e2e-7f3a9b2c" test_id = "test-search-e2e-7f3a9b2c"
try: try:
await get_database().execute( await session.execute(
transcripts.delete().where(transcripts.c.id == test_id) delete(TranscriptModel).where(TranscriptModel.id == test_id)
) )
test_data = { test_data = {
@@ -196,16 +199,17 @@ We need to implement PostgreSQL tsvector for better performance.""",
"user_id": "test-user-3", "user_id": "test-user-3",
} }
await get_database().execute(transcripts.insert().values(**test_data)) await session.execute(insert(TranscriptModel).values(**test_data))
await session.commit()
params = SearchParameters(query_text="planning", user_id="test-user-3") params = SearchParameters(query_text="planning", user_id="test-user-3")
results, total = await search_controller.search_transcripts(params) results, total = await search_controller.search_transcripts(session, params)
assert total >= 1 assert total >= 1
found = any(r.id == test_id for r in results) found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by title word" assert found, "Should find test transcript by title word"
params = SearchParameters(query_text="tsvector", user_id="test-user-3") params = SearchParameters(query_text="tsvector", user_id="test-user-3")
results, total = await search_controller.search_transcripts(params) results, total = await search_controller.search_transcripts(session, params)
assert total >= 1 assert total >= 1
found = any(r.id == test_id for r in results) found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by webvtt content" assert found, "Should find test transcript by webvtt content"
@@ -213,7 +217,7 @@ We need to implement PostgreSQL tsvector for better performance.""",
params = SearchParameters( params = SearchParameters(
query_text="engineering planning", user_id="test-user-3" query_text="engineering planning", user_id="test-user-3"
) )
results, total = await search_controller.search_transcripts(params) results, total = await search_controller.search_transcripts(session, params)
assert total >= 1 assert total >= 1
found = any(r.id == test_id for r in results) found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by multiple words" assert found, "Should find test transcript by multiple words"
@@ -228,7 +232,7 @@ We need to implement PostgreSQL tsvector for better performance.""",
params = SearchParameters( params = SearchParameters(
query_text="tsvector OR nosuchword", user_id="test-user-3" query_text="tsvector OR nosuchword", user_id="test-user-3"
) )
results, total = await search_controller.search_transcripts(params) results, total = await search_controller.search_transcripts(session, params)
assert total >= 1 assert total >= 1
found = any(r.id == test_id for r in results) found = any(r.id == test_id for r in results)
assert found, "Should find test transcript with OR query" assert found, "Should find test transcript with OR query"
@@ -236,16 +240,16 @@ We need to implement PostgreSQL tsvector for better performance.""",
params = SearchParameters( params = SearchParameters(
query_text='"full-text search"', user_id="test-user-3" query_text='"full-text search"', user_id="test-user-3"
) )
results, total = await search_controller.search_transcripts(params) results, total = await search_controller.search_transcripts(session, params)
assert total >= 1 assert total >= 1
found = any(r.id == test_id for r in results) found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by exact phrase" assert found, "Should find test transcript by exact phrase"
finally: finally:
await get_database().execute( await session.execute(
transcripts.delete().where(transcripts.c.id == test_id) delete(TranscriptModel).where(TranscriptModel.id == test_id)
) )
await get_database().disconnect() await session.commit()
@pytest.fixture @pytest.fixture
@@ -316,7 +320,7 @@ class TestSearchControllerFilters:
controller = SearchController() controller = SearchController()
with ( with (
patch("reflector.db.search.is_postgresql", return_value=True), patch("reflector.db.search.is_postgresql", return_value=True),
patch("reflector.db.search.get_database") as mock_db, patch("reflector.db.search.get_session_factory") as mock_session_factory,
): ):
mock_db.return_value.fetch_all = AsyncMock(return_value=[]) mock_db.return_value.fetch_all = AsyncMock(return_value=[])
mock_db.return_value.fetch_val = AsyncMock(return_value=0) mock_db.return_value.fetch_val = AsyncMock(return_value=0)
@@ -336,7 +340,7 @@ class TestSearchControllerFilters:
controller = SearchController() controller = SearchController()
with ( with (
patch("reflector.db.search.is_postgresql", return_value=True), patch("reflector.db.search.is_postgresql", return_value=True),
patch("reflector.db.search.get_database") as mock_db, patch("reflector.db.search.get_session_factory") as mock_session_factory,
): ):
mock_db.return_value.fetch_all = AsyncMock(return_value=[]) mock_db.return_value.fetch_all = AsyncMock(return_value=[])
mock_db.return_value.fetch_val = AsyncMock(return_value=0) mock_db.return_value.fetch_val = AsyncMock(return_value=0)
@@ -358,7 +362,7 @@ class TestSearchControllerFilters:
controller = SearchController() controller = SearchController()
with ( with (
patch("reflector.db.search.is_postgresql", return_value=True), patch("reflector.db.search.is_postgresql", return_value=True),
patch("reflector.db.search.get_database") as mock_db, patch("reflector.db.search.get_session_factory") as mock_session_factory,
): ):
class MockRow: class MockRow:

View File

@@ -4,21 +4,21 @@ import json
from datetime import datetime, timezone from datetime import datetime, timezone
import pytest import pytest
from sqlalchemy import delete, insert
from reflector.db import get_database from reflector.db.base import TranscriptModel
from reflector.db.search import SearchParameters, search_controller from reflector.db.search import SearchParameters, search_controller
from reflector.db.transcripts import transcripts
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_long_summary_snippet_prioritization(): async def test_long_summary_snippet_prioritization(session):
"""Test that snippets from long_summary are prioritized over webvtt content.""" """Test that snippets from long_summary are prioritized over webvtt content."""
test_id = "test-snippet-priority-3f9a2b8c" test_id = "test-snippet-priority-3f9a2b8c"
try: try:
# Clean up any existing test data # Clean up any existing test data
await get_database().execute( await session.execute(
transcripts.delete().where(transcripts.c.id == test_id) delete(TranscriptModel).where(TranscriptModel.id == test_id)
) )
test_data = { test_data = {
@@ -57,11 +57,11 @@ We need to consider various implementation approaches.""",
"user_id": "test-user-priority", "user_id": "test-user-priority",
} }
await get_database().execute(transcripts.insert().values(**test_data)) await session.execute(insert(TranscriptModel).values(**test_data))
# Search for "robotics" which appears in both long_summary and webvtt # Search for "robotics" which appears in both long_summary and webvtt
params = SearchParameters(query_text="robotics", user_id="test-user-priority") params = SearchParameters(query_text="robotics", user_id="test-user-priority")
results, total = await search_controller.search_transcripts(params) results, total = await search_controller.search_transcripts(session, params)
assert total >= 1 assert total >= 1
test_result = next((r for r in results if r.id == test_id), None) test_result = next((r for r in results if r.id == test_id), None)
@@ -86,20 +86,20 @@ We need to consider various implementation approaches.""",
), f"Snippet should contain search term: {snippet}" ), f"Snippet should contain search term: {snippet}"
finally: finally:
await get_database().execute( await session.execute(
transcripts.delete().where(transcripts.c.id == test_id) delete(TranscriptModel).where(TranscriptModel.id == test_id)
) )
await get_database().disconnect() await session.commit()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_long_summary_only_search(): async def test_long_summary_only_search(session):
"""Test searching for content that only exists in long_summary.""" """Test searching for content that only exists in long_summary."""
test_id = "test-long-only-8b3c9f2a" test_id = "test-long-only-8b3c9f2a"
try: try:
await get_database().execute( await session.execute(
transcripts.delete().where(transcripts.c.id == test_id) delete(TranscriptModel).where(TranscriptModel.id == test_id)
) )
test_data = { test_data = {
@@ -135,11 +135,11 @@ Discussion of timeline and deliverables.""",
"user_id": "test-user-long", "user_id": "test-user-long",
} }
await get_database().execute(transcripts.insert().values(**test_data)) await session.execute(insert(TranscriptModel).values(**test_data))
# Search for terms only in long_summary # Search for terms only in long_summary
params = SearchParameters(query_text="cryptocurrency", user_id="test-user-long") params = SearchParameters(query_text="cryptocurrency", user_id="test-user-long")
results, total = await search_controller.search_transcripts(params) results, total = await search_controller.search_transcripts(session, params)
found = any(r.id == test_id for r in results) found = any(r.id == test_id for r in results)
assert found, "Should find transcript by long_summary-only content" assert found, "Should find transcript by long_summary-only content"
@@ -154,13 +154,13 @@ Discussion of timeline and deliverables.""",
# Search for "yield farming" - a more specific term # Search for "yield farming" - a more specific term
params2 = SearchParameters(query_text="yield farming", user_id="test-user-long") params2 = SearchParameters(query_text="yield farming", user_id="test-user-long")
results2, total2 = await search_controller.search_transcripts(params2) results2, total2 = await search_controller.search_transcripts(session, params2)
found2 = any(r.id == test_id for r in results2) found2 = any(r.id == test_id for r in results2)
assert found2, "Should find transcript by specific long_summary phrase" assert found2, "Should find transcript by specific long_summary phrase"
finally: finally:
await get_database().execute( await session.execute(
transcripts.delete().where(transcripts.c.id == test_id) delete(TranscriptModel).where(TranscriptModel.id == test_id)
) )
await get_database().disconnect() await session.commit()

View File

@@ -53,7 +53,8 @@ def appserver(tmpdir, setup_database, celery_session_app, celery_session_worker)
import threading import threading
from reflector.app import app from reflector.app import app
from reflector.db import get_database
# Database connection handled by SQLAlchemy engine
from reflector.settings import settings from reflector.settings import settings
DATA_DIR = settings.DATA_DIR DATA_DIR = settings.DATA_DIR
@@ -77,13 +78,8 @@ def appserver(tmpdir, setup_database, celery_session_app, celery_session_worker)
server_instance = Server(config) server_instance = Server(config)
async def start_server(): async def start_server():
# Initialize database connection in this event loop # Database connections managed by SQLAlchemy engine
database = get_database() await server_instance.serve()
await database.connect()
try:
await server_instance.serve()
finally:
await database.disconnect()
# Signal that server is starting # Signal that server is starting
server_started.set() server_started.set()

View File

@@ -1,13 +1,13 @@
"""Integration tests for WebVTT auto-update functionality in Transcript model.""" """Integration tests for WebVTT auto-update functionality in Transcript model."""
import pytest import pytest
from sqlalchemy import select
from reflector.db import get_database from reflector.db.base import TranscriptModel
from reflector.db.transcripts import ( from reflector.db.transcripts import (
SourceKind, SourceKind,
TranscriptController, TranscriptController,
TranscriptTopic, TranscriptTopic,
transcripts,
) )
from reflector.processors.types import Word from reflector.processors.types import Word
@@ -16,30 +16,35 @@ from reflector.processors.types import Word
class TestWebVTTAutoUpdate: class TestWebVTTAutoUpdate:
"""Test that WebVTT field auto-updates when Transcript is created or modified.""" """Test that WebVTT field auto-updates when Transcript is created or modified."""
async def test_webvtt_not_updated_on_transcript_creation_without_topics(self): async def test_webvtt_not_updated_on_transcript_creation_without_topics(
self, session
):
"""WebVTT should be None when creating transcript without topics.""" """WebVTT should be None when creating transcript without topics."""
controller = TranscriptController() # Using global transcripts_controller
transcript = await controller.add( transcript = await transcripts_controller.add(
session,
name="Test Transcript", name="Test Transcript",
source_kind=SourceKind.FILE, source_kind=SourceKind.FILE,
) )
try: try:
result = await get_database().fetch_one( result = await session.execute(
transcripts.select().where(transcripts.c.id == transcript.id) select(TranscriptModel).where(TranscriptModel.id == transcript.id)
) )
row = result.scalar_one_or_none()
assert result is not None assert row is not None
assert result["webvtt"] is None assert row.webvtt is None
finally: finally:
await controller.remove_by_id(transcript.id) await transcripts_controller.remove_by_id(session, transcript.id)
async def test_webvtt_updated_on_upsert_topic(self): async def test_webvtt_updated_on_upsert_topic(self, session):
"""WebVTT should update when upserting topics via upsert_topic method.""" """WebVTT should update when upserting topics via upsert_topic method."""
controller = TranscriptController() # Using global transcripts_controller
transcript = await controller.add( transcript = await transcripts_controller.add(
session,
name="Test Transcript", name="Test Transcript",
source_kind=SourceKind.FILE, source_kind=SourceKind.FILE,
) )
@@ -56,14 +61,15 @@ class TestWebVTTAutoUpdate:
], ],
) )
await controller.upsert_topic(transcript, topic) await transcripts_controller.upsert_topic(session, transcript, topic)
result = await get_database().fetch_one( result = await session.execute(
transcripts.select().where(transcripts.c.id == transcript.id) select(TranscriptModel).where(TranscriptModel.id == transcript.id)
) )
row = result.scalar_one_or_none()
assert result is not None assert row is not None
webvtt = result["webvtt"] webvtt = row.webvtt
assert webvtt is not None assert webvtt is not None
assert "WEBVTT" in webvtt assert "WEBVTT" in webvtt
@@ -71,13 +77,14 @@ class TestWebVTTAutoUpdate:
assert "<v Speaker0>" in webvtt assert "<v Speaker0>" in webvtt
finally: finally:
await controller.remove_by_id(transcript.id) await transcripts_controller.remove_by_id(session, transcript.id)
async def test_webvtt_updated_on_direct_topics_update(self): async def test_webvtt_updated_on_direct_topics_update(self, session):
"""WebVTT should update when updating topics field directly.""" """WebVTT should update when updating topics field directly."""
controller = TranscriptController() # Using global transcripts_controller
transcript = await controller.add( transcript = await transcripts_controller.add(
session,
name="Test Transcript", name="Test Transcript",
source_kind=SourceKind.FILE, source_kind=SourceKind.FILE,
) )
@@ -96,28 +103,32 @@ class TestWebVTTAutoUpdate:
} }
] ]
await controller.update(transcript, {"topics": topics_data}) await transcripts_controller.update(
session, transcript, {"topics": topics_data}
# Fetch from DB
result = await get_database().fetch_one(
transcripts.select().where(transcripts.c.id == transcript.id)
) )
assert result is not None # Fetch from DB
webvtt = result["webvtt"] result = await session.execute(
select(TranscriptModel).where(TranscriptModel.id == transcript.id)
)
row = result.scalar_one_or_none()
assert row is not None
webvtt = row.webvtt
assert webvtt is not None assert webvtt is not None
assert "WEBVTT" in webvtt assert "WEBVTT" in webvtt
assert "First sentence" in webvtt assert "First sentence" in webvtt
finally: finally:
await controller.remove_by_id(transcript.id) await transcripts_controller.remove_by_id(session, transcript.id)
async def test_webvtt_updated_manually_with_handle_topics_update(self): async def test_webvtt_updated_manually_with_handle_topics_update(self, session):
"""Test that _handle_topics_update works when called manually.""" """Test that _handle_topics_update works when called manually."""
controller = TranscriptController() # Using global transcripts_controller
transcript = await controller.add( transcript = await transcripts_controller.add(
session,
name="Test Transcript", name="Test Transcript",
source_kind=SourceKind.FILE, source_kind=SourceKind.FILE,
) )
@@ -138,15 +149,16 @@ class TestWebVTTAutoUpdate:
values = {"topics": transcript.topics_dump()} values = {"topics": transcript.topics_dump()}
await controller.update(transcript, values) await transcripts_controller.update(session, transcript, values)
# Fetch from DB # Fetch from DB
result = await get_database().fetch_one( result = await session.execute(
transcripts.select().where(transcripts.c.id == transcript.id) select(TranscriptModel).where(TranscriptModel.id == transcript.id)
) )
row = result.scalar_one_or_none()
assert result is not None assert row is not None
webvtt = result["webvtt"] webvtt = row.webvtt
assert webvtt is not None assert webvtt is not None
assert "WEBVTT" in webvtt assert "WEBVTT" in webvtt
@@ -154,13 +166,14 @@ class TestWebVTTAutoUpdate:
assert "<v Speaker0>" in webvtt assert "<v Speaker0>" in webvtt
finally: finally:
await controller.remove_by_id(transcript.id) await transcripts_controller.remove_by_id(session, transcript.id)
async def test_webvtt_update_with_non_sequential_topics_fails(self): async def test_webvtt_update_with_non_sequential_topics_fails(self, session):
"""Test that non-sequential topics raise assertion error.""" """Test that non-sequential topics raise assertion error."""
controller = TranscriptController() # Using global transcripts_controller
transcript = await controller.add( transcript = await transcripts_controller.add(
session,
name="Test Transcript", name="Test Transcript",
source_kind=SourceKind.FILE, source_kind=SourceKind.FILE,
) )
@@ -186,13 +199,14 @@ class TestWebVTTAutoUpdate:
assert "Words are not in sequence" in str(exc_info.value) assert "Words are not in sequence" in str(exc_info.value)
finally: finally:
await controller.remove_by_id(transcript.id) await transcripts_controller.remove_by_id(session, transcript.id)
async def test_multiple_speakers_in_webvtt(self): async def test_multiple_speakers_in_webvtt(self, session):
"""Test WebVTT generation with multiple speakers.""" """Test WebVTT generation with multiple speakers."""
controller = TranscriptController() # Using global transcripts_controller
transcript = await controller.add( transcript = await transcripts_controller.add(
session,
name="Test Transcript", name="Test Transcript",
source_kind=SourceKind.FILE, source_kind=SourceKind.FILE,
) )
@@ -213,15 +227,16 @@ class TestWebVTTAutoUpdate:
transcript.upsert_topic(topic) transcript.upsert_topic(topic)
values = {"topics": transcript.topics_dump()} values = {"topics": transcript.topics_dump()}
await controller.update(transcript, values) await transcripts_controller.update(session, transcript, values)
# Fetch from DB # Fetch from DB
result = await get_database().fetch_one( result = await session.execute(
transcripts.select().where(transcripts.c.id == transcript.id) select(TranscriptModel).where(TranscriptModel.id == transcript.id)
) )
row = result.scalar_one_or_none()
assert result is not None assert row is not None
webvtt = result["webvtt"] webvtt = row.webvtt
assert webvtt is not None assert webvtt is not None
assert "<v Speaker0>" in webvtt assert "<v Speaker0>" in webvtt
@@ -231,4 +246,4 @@ class TestWebVTTAutoUpdate:
assert "Goodbye" in webvtt assert "Goodbye" in webvtt
finally: finally:
await controller.remove_by_id(transcript.id) await transcripts_controller.remove_by_id(session, transcript.id)