diff --git a/server/reflector/db/__init__.py b/server/reflector/db/__init__.py index 7f2a58a3..d482c404 100644 --- a/server/reflector/db/__init__.py +++ b/server/reflector/db/__init__.py @@ -38,11 +38,17 @@ def get_session_factory() -> async_sessionmaker[AsyncSession]: return _session_factory -async def get_session() -> AsyncGenerator[AsyncSession, None]: +async def _get_session() -> AsyncGenerator[AsyncSession, None]: + # necessary implementation to ease mocking on pytest async with get_session_factory()() as session: yield session +async def get_session() -> AsyncGenerator[AsyncSession, None]: + async for session in _get_session(): + yield session + + import reflector.db.calendar_events # noqa import reflector.db.meetings # noqa import reflector.db.recordings # noqa diff --git a/server/reflector/pipelines/main_file_pipeline.py b/server/reflector/pipelines/main_file_pipeline.py index cf586350..99ade57e 100644 --- a/server/reflector/pipelines/main_file_pipeline.py +++ b/server/reflector/pipelines/main_file_pipeline.py @@ -13,6 +13,7 @@ from pathlib import Path import av import structlog from celery import chain, shared_task +from sqlalchemy.ext.asyncio import AsyncSession from reflector.asynctask import asynctask from reflector.db import get_session_factory @@ -317,12 +318,9 @@ class PipelineMainFile(PipelineMainBase): self.logger.error(f"Diarization failed: {e}") return None - async def generate_waveform(self, audio_path: Path): + async def generate_waveform(self, session: AsyncSession, audio_path: Path): """Generate and save waveform""" - async with get_session_factory()() as session: - transcript = await transcripts_controller.get_by_id( - session, self.transcript_id - ) + transcript = await transcripts_controller.get_by_id(session, self.transcript_id) processor = AudioWaveformProcessor( audio_path=audio_path, diff --git a/server/tests/conftest.py b/server/tests/conftest.py index 84d4d3ec..2915eccf 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -344,6 +344,15 @@ def celery_includes(): ] +@pytest.fixture(autouse=True) +async def ensure_db_session_in_app(db_session): + async def mock_get_session(): + yield db_session + + with patch("reflector.db._get_session", side_effect=mock_get_session): + yield + + @pytest.fixture async def client(db_session): from httpx import AsyncClient