diff --git a/server/reflector/views/rooms.py b/server/reflector/views/rooms.py index ffb8cce8..c0b6223a 100644 --- a/server/reflector/views/rooms.py +++ b/server/reflector/views/rooms.py @@ -11,7 +11,7 @@ from redis.exceptions import LockError from sqlalchemy.ext.asyncio import AsyncSession import reflector.auth as auth -from reflector.db import get_session, get_session_factory +from reflector.db import get_session from reflector.db.calendar_events import calendar_events_controller from reflector.db.meetings import meetings_controller from reflector.db.rooms import rooms_controller @@ -177,18 +177,17 @@ def parse_datetime_with_timezone(iso_string: str) -> datetime: @router.get("/rooms", response_model=Page[RoomDetails]) async def rooms_list( user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], + session: AsyncSession = Depends(get_session), ) -> list[RoomDetails]: if not user and not settings.PUBLIC_MODE: raise HTTPException(status_code=401, detail="Not authenticated") user_id = user["sub"] if user else None - session_factory = get_session_factory() - async with session_factory() as session: - query = await rooms_controller.get_all( - session, user_id=user_id, order_by="-created_at", return_query=True - ) - return await paginate(session, query) + query = await rooms_controller.get_all( + session, user_id=user_id, order_by="-created_at", return_query=True + ) + return await paginate(session, query) @router.get("/rooms/{room_id}", response_model=RoomDetails) diff --git a/server/tests/test_pipeline_main_file.py b/server/tests/test_pipeline_main_file.py index 49c2d22c..1d7f1ade 100644 --- a/server/tests/test_pipeline_main_file.py +++ b/server/tests/test_pipeline_main_file.py @@ -624,10 +624,11 @@ async def test_pipeline_file_process_no_transcript(): # Should raise an exception for missing transcript when get_transcript is called with pytest.raises(Exception, match="Transcript not found"): - from reflector.db import get_session_factory + # Use a mock session - the controller is mocked to return None anyway + from unittest.mock import MagicMock - async with get_session_factory()() as session: - await pipeline.get_transcript(session) + mock_session = MagicMock() + await pipeline.get_transcript(mock_session) @pytest.mark.asyncio