mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 12:19:06 +00:00
fix: improve session management and testing infrastructure
- Split get_session into _get_session and get_session to facilitate test mocking - Add autouse fixture to ensure db_session is properly injected in tests - Fix generate_waveform method to accept session parameter explicitly
This commit is contained in:
@@ -38,11 +38,17 @@ def get_session_factory() -> async_sessionmaker[AsyncSession]:
|
||||
return _session_factory
|
||||
|
||||
|
||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
async def _get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
# necessary implementation to ease mocking on pytest
|
||||
async with get_session_factory()() as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
async for session in _get_session():
|
||||
yield session
|
||||
|
||||
|
||||
import reflector.db.calendar_events # noqa
|
||||
import reflector.db.meetings # noqa
|
||||
import reflector.db.recordings # noqa
|
||||
|
||||
@@ -13,6 +13,7 @@ from pathlib import Path
|
||||
import av
|
||||
import structlog
|
||||
from celery import chain, shared_task
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from reflector.asynctask import asynctask
|
||||
from reflector.db import get_session_factory
|
||||
@@ -317,12 +318,9 @@ class PipelineMainFile(PipelineMainBase):
|
||||
self.logger.error(f"Diarization failed: {e}")
|
||||
return None
|
||||
|
||||
async def generate_waveform(self, audio_path: Path):
|
||||
async def generate_waveform(self, session: AsyncSession, audio_path: Path):
|
||||
"""Generate and save waveform"""
|
||||
async with get_session_factory()() as session:
|
||||
transcript = await transcripts_controller.get_by_id(
|
||||
session, self.transcript_id
|
||||
)
|
||||
transcript = await transcripts_controller.get_by_id(session, self.transcript_id)
|
||||
|
||||
processor = AudioWaveformProcessor(
|
||||
audio_path=audio_path,
|
||||
|
||||
@@ -344,6 +344,15 @@ def celery_includes():
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def ensure_db_session_in_app(db_session):
|
||||
async def mock_get_session():
|
||||
yield db_session
|
||||
|
||||
with patch("reflector.db._get_session", side_effect=mock_get_session):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client(db_session):
|
||||
from httpx import AsyncClient
|
||||
|
||||
Reference in New Issue
Block a user