mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-21 04:39:06 +00:00
test: update test suite for SQLAlchemy 2.0 migration
- Add session fixture for async session management - Update all test files to use session parameter - Convert Core-style queries to ORM-style in tests - Fix controller calls to include session parameter - Remove obsolete get_database() references Test progress: 108/195 tests passing
This commit is contained in:
@@ -1,13 +1,13 @@
|
||||
"""Integration tests for WebVTT auto-update functionality in Transcript model."""
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from reflector.db import get_database
|
||||
from reflector.db.base import TranscriptModel
|
||||
from reflector.db.transcripts import (
|
||||
SourceKind,
|
||||
TranscriptController,
|
||||
TranscriptTopic,
|
||||
transcripts,
|
||||
)
|
||||
from reflector.processors.types import Word
|
||||
|
||||
@@ -16,30 +16,35 @@ from reflector.processors.types import Word
|
||||
class TestWebVTTAutoUpdate:
|
||||
"""Test that WebVTT field auto-updates when Transcript is created or modified."""
|
||||
|
||||
async def test_webvtt_not_updated_on_transcript_creation_without_topics(self):
|
||||
async def test_webvtt_not_updated_on_transcript_creation_without_topics(
|
||||
self, session
|
||||
):
|
||||
"""WebVTT should be None when creating transcript without topics."""
|
||||
controller = TranscriptController()
|
||||
# Using global transcripts_controller
|
||||
|
||||
transcript = await controller.add(
|
||||
transcript = await transcripts_controller.add(
|
||||
session,
|
||||
name="Test Transcript",
|
||||
source_kind=SourceKind.FILE,
|
||||
)
|
||||
|
||||
try:
|
||||
result = await get_database().fetch_one(
|
||||
transcripts.select().where(transcripts.c.id == transcript.id)
|
||||
result = await session.execute(
|
||||
select(TranscriptModel).where(TranscriptModel.id == transcript.id)
|
||||
)
|
||||
row = result.scalar_one_or_none()
|
||||
|
||||
assert result is not None
|
||||
assert result["webvtt"] is None
|
||||
assert row is not None
|
||||
assert row.webvtt is None
|
||||
finally:
|
||||
await controller.remove_by_id(transcript.id)
|
||||
await transcripts_controller.remove_by_id(session, transcript.id)
|
||||
|
||||
async def test_webvtt_updated_on_upsert_topic(self):
|
||||
async def test_webvtt_updated_on_upsert_topic(self, session):
|
||||
"""WebVTT should update when upserting topics via upsert_topic method."""
|
||||
controller = TranscriptController()
|
||||
# Using global transcripts_controller
|
||||
|
||||
transcript = await controller.add(
|
||||
transcript = await transcripts_controller.add(
|
||||
session,
|
||||
name="Test Transcript",
|
||||
source_kind=SourceKind.FILE,
|
||||
)
|
||||
@@ -56,14 +61,15 @@ class TestWebVTTAutoUpdate:
|
||||
],
|
||||
)
|
||||
|
||||
await controller.upsert_topic(transcript, topic)
|
||||
await transcripts_controller.upsert_topic(session, transcript, topic)
|
||||
|
||||
result = await get_database().fetch_one(
|
||||
transcripts.select().where(transcripts.c.id == transcript.id)
|
||||
result = await session.execute(
|
||||
select(TranscriptModel).where(TranscriptModel.id == transcript.id)
|
||||
)
|
||||
row = result.scalar_one_or_none()
|
||||
|
||||
assert result is not None
|
||||
webvtt = result["webvtt"]
|
||||
assert row is not None
|
||||
webvtt = row.webvtt
|
||||
|
||||
assert webvtt is not None
|
||||
assert "WEBVTT" in webvtt
|
||||
@@ -71,13 +77,14 @@ class TestWebVTTAutoUpdate:
|
||||
assert "<v Speaker0>" in webvtt
|
||||
|
||||
finally:
|
||||
await controller.remove_by_id(transcript.id)
|
||||
await transcripts_controller.remove_by_id(session, transcript.id)
|
||||
|
||||
async def test_webvtt_updated_on_direct_topics_update(self):
|
||||
async def test_webvtt_updated_on_direct_topics_update(self, session):
|
||||
"""WebVTT should update when updating topics field directly."""
|
||||
controller = TranscriptController()
|
||||
# Using global transcripts_controller
|
||||
|
||||
transcript = await controller.add(
|
||||
transcript = await transcripts_controller.add(
|
||||
session,
|
||||
name="Test Transcript",
|
||||
source_kind=SourceKind.FILE,
|
||||
)
|
||||
@@ -96,28 +103,32 @@ class TestWebVTTAutoUpdate:
|
||||
}
|
||||
]
|
||||
|
||||
await controller.update(transcript, {"topics": topics_data})
|
||||
|
||||
# Fetch from DB
|
||||
result = await get_database().fetch_one(
|
||||
transcripts.select().where(transcripts.c.id == transcript.id)
|
||||
await transcripts_controller.update(
|
||||
session, transcript, {"topics": topics_data}
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
webvtt = result["webvtt"]
|
||||
# Fetch from DB
|
||||
result = await session.execute(
|
||||
select(TranscriptModel).where(TranscriptModel.id == transcript.id)
|
||||
)
|
||||
row = result.scalar_one_or_none()
|
||||
|
||||
assert row is not None
|
||||
webvtt = row.webvtt
|
||||
|
||||
assert webvtt is not None
|
||||
assert "WEBVTT" in webvtt
|
||||
assert "First sentence" in webvtt
|
||||
|
||||
finally:
|
||||
await controller.remove_by_id(transcript.id)
|
||||
await transcripts_controller.remove_by_id(session, transcript.id)
|
||||
|
||||
async def test_webvtt_updated_manually_with_handle_topics_update(self):
|
||||
async def test_webvtt_updated_manually_with_handle_topics_update(self, session):
|
||||
"""Test that _handle_topics_update works when called manually."""
|
||||
controller = TranscriptController()
|
||||
# Using global transcripts_controller
|
||||
|
||||
transcript = await controller.add(
|
||||
transcript = await transcripts_controller.add(
|
||||
session,
|
||||
name="Test Transcript",
|
||||
source_kind=SourceKind.FILE,
|
||||
)
|
||||
@@ -138,15 +149,16 @@ class TestWebVTTAutoUpdate:
|
||||
|
||||
values = {"topics": transcript.topics_dump()}
|
||||
|
||||
await controller.update(transcript, values)
|
||||
await transcripts_controller.update(session, transcript, values)
|
||||
|
||||
# Fetch from DB
|
||||
result = await get_database().fetch_one(
|
||||
transcripts.select().where(transcripts.c.id == transcript.id)
|
||||
result = await session.execute(
|
||||
select(TranscriptModel).where(TranscriptModel.id == transcript.id)
|
||||
)
|
||||
row = result.scalar_one_or_none()
|
||||
|
||||
assert result is not None
|
||||
webvtt = result["webvtt"]
|
||||
assert row is not None
|
||||
webvtt = row.webvtt
|
||||
|
||||
assert webvtt is not None
|
||||
assert "WEBVTT" in webvtt
|
||||
@@ -154,13 +166,14 @@ class TestWebVTTAutoUpdate:
|
||||
assert "<v Speaker0>" in webvtt
|
||||
|
||||
finally:
|
||||
await controller.remove_by_id(transcript.id)
|
||||
await transcripts_controller.remove_by_id(session, transcript.id)
|
||||
|
||||
async def test_webvtt_update_with_non_sequential_topics_fails(self):
|
||||
async def test_webvtt_update_with_non_sequential_topics_fails(self, session):
|
||||
"""Test that non-sequential topics raise assertion error."""
|
||||
controller = TranscriptController()
|
||||
# Using global transcripts_controller
|
||||
|
||||
transcript = await controller.add(
|
||||
transcript = await transcripts_controller.add(
|
||||
session,
|
||||
name="Test Transcript",
|
||||
source_kind=SourceKind.FILE,
|
||||
)
|
||||
@@ -186,13 +199,14 @@ class TestWebVTTAutoUpdate:
|
||||
assert "Words are not in sequence" in str(exc_info.value)
|
||||
|
||||
finally:
|
||||
await controller.remove_by_id(transcript.id)
|
||||
await transcripts_controller.remove_by_id(session, transcript.id)
|
||||
|
||||
async def test_multiple_speakers_in_webvtt(self):
|
||||
async def test_multiple_speakers_in_webvtt(self, session):
|
||||
"""Test WebVTT generation with multiple speakers."""
|
||||
controller = TranscriptController()
|
||||
# Using global transcripts_controller
|
||||
|
||||
transcript = await controller.add(
|
||||
transcript = await transcripts_controller.add(
|
||||
session,
|
||||
name="Test Transcript",
|
||||
source_kind=SourceKind.FILE,
|
||||
)
|
||||
@@ -213,15 +227,16 @@ class TestWebVTTAutoUpdate:
|
||||
transcript.upsert_topic(topic)
|
||||
values = {"topics": transcript.topics_dump()}
|
||||
|
||||
await controller.update(transcript, values)
|
||||
await transcripts_controller.update(session, transcript, values)
|
||||
|
||||
# Fetch from DB
|
||||
result = await get_database().fetch_one(
|
||||
transcripts.select().where(transcripts.c.id == transcript.id)
|
||||
result = await session.execute(
|
||||
select(TranscriptModel).where(TranscriptModel.id == transcript.id)
|
||||
)
|
||||
row = result.scalar_one_or_none()
|
||||
|
||||
assert result is not None
|
||||
webvtt = result["webvtt"]
|
||||
assert row is not None
|
||||
webvtt = row.webvtt
|
||||
|
||||
assert webvtt is not None
|
||||
assert "<v Speaker0>" in webvtt
|
||||
@@ -231,4 +246,4 @@ class TestWebVTTAutoUpdate:
|
||||
assert "Goodbye" in webvtt
|
||||
|
||||
finally:
|
||||
await controller.remove_by_id(transcript.id)
|
||||
await transcripts_controller.remove_by_id(session, transcript.id)
|
||||
|
||||
Reference in New Issue
Block a user