From e0c71c5548dc5bc3ebfa91356108d63006820156 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Tue, 23 Sep 2025 16:46:37 -0600 Subject: [PATCH] refactor: migrate to SQLAlchemy 2.0 ORM-style patterns - Replace __table__.join() with ORM-style joins using select_from().outerjoin() - Replace __table__.delete() with delete(Model) in tests - Migrate from **row.__dict__ to model_validate() with ConfigDict(from_attributes=True) - Add ConfigDict(from_attributes=True) to all Pydantic models for proper SQLAlchemy model conversion - Update all controller methods to use model_validate() instead of dict unpacking This completes the migration to SQLAlchemy 2.0 recommended patterns while maintaining backwards compatibility and improving code consistency. --- server/reflector/db/calendar_events.py | 14 ++++++++------ server/reflector/db/meetings.py | 24 ++++++++++++++---------- server/reflector/db/recordings.py | 8 +++++--- server/reflector/db/rooms.py | 12 +++++++----- server/reflector/db/search.py | 10 ++++------ server/reflector/db/transcripts.py | 15 ++++++++++----- server/tests/test_cleanup.py | 8 +++----- 7 files changed, 51 insertions(+), 40 deletions(-) diff --git a/server/reflector/db/calendar_events.py b/server/reflector/db/calendar_events.py index 4fbcfa9b..889f18a0 100644 --- a/server/reflector/db/calendar_events.py +++ b/server/reflector/db/calendar_events.py @@ -2,7 +2,7 @@ from datetime import datetime, timedelta, timezone from typing import Any import sqlalchemy as sa -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from sqlalchemy import delete, select, update from sqlalchemy.ext.asyncio import AsyncSession @@ -11,6 +11,8 @@ from reflector.utils import generate_uuid4 class CalendarEvent(BaseModel): + model_config = ConfigDict(from_attributes=True) + id: str = Field(default_factory=generate_uuid4) room_id: str ics_uid: str @@ -50,7 +52,7 @@ class CalendarEventController: ) result = await session.execute(query) - return [CalendarEvent(**row.__dict__) for row in result.scalars().all()] + return [CalendarEvent.model_validate(row) for row in result.scalars().all()] async def get_by_id( self, session: AsyncSession, event_id: str @@ -60,7 +62,7 @@ class CalendarEventController: row = result.scalar_one_or_none() if not row: return None - return CalendarEvent(**row.__dict__) + return CalendarEvent.model_validate(row) async def get_by_ics_uid( self, session: AsyncSession, room_id: str, ics_uid: str @@ -75,7 +77,7 @@ class CalendarEventController: row = result.scalar_one_or_none() if not row: return None - return CalendarEvent(**row.__dict__) + return CalendarEvent.model_validate(row) async def upsert( self, session: AsyncSession, event: CalendarEvent @@ -137,7 +139,7 @@ class CalendarEventController: if not include_deleted: query = query.where(CalendarEventModel.is_deleted == False) result = await session.execute(query) - return [CalendarEvent(**row.__dict__) for row in result.scalars().all()] + return [CalendarEvent.model_validate(row) for row in result.scalars().all()] async def get_upcoming( self, session: AsyncSession, room_id: str, minutes_ahead: int = 120 @@ -159,7 +161,7 @@ class CalendarEventController: ) result = await session.execute(query) - return [CalendarEvent(**row.__dict__) for row in result.scalars().all()] + return [CalendarEvent.model_validate(row) for row in result.scalars().all()] async def soft_delete_missing( self, session: AsyncSession, room_id: str, current_ics_uids: list[str] diff --git a/server/reflector/db/meetings.py b/server/reflector/db/meetings.py index 02a9ecd1..1462a7a1 100644 --- a/server/reflector/db/meetings.py +++ b/server/reflector/db/meetings.py @@ -2,7 +2,7 @@ from datetime import datetime from typing import Any, Literal import sqlalchemy as sa -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession @@ -12,6 +12,8 @@ from reflector.utils import generate_uuid4 class MeetingConsent(BaseModel): + model_config = ConfigDict(from_attributes=True) + id: str = Field(default_factory=generate_uuid4) meeting_id: str user_id: str | None = None @@ -20,6 +22,8 @@ class MeetingConsent(BaseModel): class Meeting(BaseModel): + model_config = ConfigDict(from_attributes=True) + id: str room_name: str room_url: str @@ -76,7 +80,7 @@ class MeetingController: async def get_all_active(self, session: AsyncSession) -> list[Meeting]: query = select(MeetingModel).where(MeetingModel.is_active) result = await session.execute(query) - return [Meeting(**row.__dict__) for row in result.scalars().all()] + return [Meeting.model_validate(row) for row in result.scalars().all()] async def get_by_room_name( self, @@ -96,7 +100,7 @@ class MeetingController: row = result.scalar_one_or_none() if not row: return None - return Meeting(**row.__dict__) + return Meeting.model_validate(row) async def get_active( self, session: AsyncSession, room: Room, current_time: datetime @@ -120,7 +124,7 @@ class MeetingController: row = result.scalar_one_or_none() if not row: return None - return Meeting(**row.__dict__) + return Meeting.model_validate(row) async def get_all_active_for_room( self, session: AsyncSession, room: Room, current_time: datetime @@ -137,7 +141,7 @@ class MeetingController: .order_by(MeetingModel.end_date.desc()) ) result = await session.execute(query) - return [Meeting(**row.__dict__) for row in result.scalars().all()] + return [Meeting.model_validate(row) for row in result.scalars().all()] async def get_active_by_calendar_event( self, @@ -161,7 +165,7 @@ class MeetingController: row = result.scalar_one_or_none() if not row: return None - return Meeting(**row.__dict__) + return Meeting.model_validate(row) async def get_by_id( self, session: AsyncSession, meeting_id: str, **kwargs @@ -171,7 +175,7 @@ class MeetingController: row = result.scalar_one_or_none() if not row: return None - return Meeting(**row.__dict__) + return Meeting.model_validate(row) async def get_by_calendar_event( self, session: AsyncSession, calendar_event_id: str @@ -183,7 +187,7 @@ class MeetingController: row = result.scalar_one_or_none() if not row: return None - return Meeting(**row.__dict__) + return Meeting.model_validate(row) async def update_meeting(self, session: AsyncSession, meeting_id: str, **kwargs): query = ( @@ -201,7 +205,7 @@ class MeetingConsentController: MeetingConsentModel.meeting_id == meeting_id ) result = await session.execute(query) - return [MeetingConsent(**row.__dict__) for row in result.scalars().all()] + return [MeetingConsent.model_validate(row) for row in result.scalars().all()] async def get_by_meeting_and_user( self, session: AsyncSession, meeting_id: str, user_id: str @@ -217,7 +221,7 @@ class MeetingConsentController: row = result.scalar_one_or_none() if row is None: return None - return MeetingConsent(**row.__dict__) + return MeetingConsent.model_validate(row) async def upsert( self, session: AsyncSession, consent: MeetingConsent diff --git a/server/reflector/db/recordings.py b/server/reflector/db/recordings.py index d5cc4030..2ba33280 100644 --- a/server/reflector/db/recordings.py +++ b/server/reflector/db/recordings.py @@ -1,6 +1,6 @@ from datetime import datetime, timezone -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession @@ -9,6 +9,8 @@ from reflector.utils import generate_uuid4 class Recording(BaseModel): + model_config = ConfigDict(from_attributes=True) + id: str = Field(default_factory=generate_uuid4) meeting_id: str url: str @@ -53,7 +55,7 @@ class RecordingController: row = result.scalar_one_or_none() if not row: return None - return Recording(**row.__dict__) + return Recording.model_validate(row) async def get_by_meeting_id( self, session: AsyncSession, meeting_id: str @@ -63,7 +65,7 @@ class RecordingController: """ query = select(RecordingModel).where(RecordingModel.meeting_id == meeting_id) result = await session.execute(query) - return [Recording(**row.__dict__) for row in result.scalars().all()] + return [Recording.model_validate(row) for row in result.scalars().all()] async def remove_by_id(self, session: AsyncSession, recording_id: str) -> None: """ diff --git a/server/reflector/db/rooms.py b/server/reflector/db/rooms.py index 2098d09e..e4f29631 100644 --- a/server/reflector/db/rooms.py +++ b/server/reflector/db/rooms.py @@ -4,7 +4,7 @@ from sqlite3 import IntegrityError from typing import Literal from fastapi import HTTPException -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from sqlalchemy import delete, select, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.sql import or_ @@ -14,6 +14,8 @@ from reflector.utils import generate_uuid4 class Room(BaseModel): + model_config = ConfigDict(from_attributes=True) + id: str = Field(default_factory=generate_uuid4) name: str user_id: str @@ -70,7 +72,7 @@ class RoomController: return query result = await session.execute(query) - return [Room(**row.__dict__) for row in result.scalars().all()] + return [Room.model_validate(row) for row in result.scalars().all()] async def add( self, @@ -155,7 +157,7 @@ class RoomController: row = result.scalars().first() if not row: return None - return Room(**row.__dict__) + return Room.model_validate(row) async def get_by_name( self, session: AsyncSession, room_name: str, **kwargs @@ -170,7 +172,7 @@ class RoomController: row = result.scalars().first() if not row: return None - return Room(**row.__dict__) + return Room.model_validate(row) async def get_by_id_for_http( self, session: AsyncSession, meeting_id: str, user_id: str | None @@ -186,7 +188,7 @@ class RoomController: if not row: raise HTTPException(status_code=404, detail="Room not found") - room = Room(**row.__dict__) + room = Room.model_validate(row) return room diff --git a/server/reflector/db/search.py b/server/reflector/db/search.py index 37c7e7ad..ad8ea174 100644 --- a/server/reflector/db/search.py +++ b/server/reflector/db/search.py @@ -369,12 +369,10 @@ class SearchController: rank_column = sqlalchemy.cast(1.0, sqlalchemy.Float).label("rank") columns = base_columns + [rank_column] - base_query = sqlalchemy.select(*columns).select_from( - TranscriptModel.__table__.join( - RoomModel.__table__, - TranscriptModel.room_id == RoomModel.id, - isouter=True, - ) + base_query = ( + sqlalchemy.select(*columns) + .select_from(TranscriptModel) + .outerjoin(RoomModel, TranscriptModel.room_id == RoomModel.id) ) if params.query_text is not None: diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py index c4da4805..e4fe43a7 100644 --- a/server/reflector/db/transcripts.py +++ b/server/reflector/db/transcripts.py @@ -103,6 +103,8 @@ class TranscriptParticipant(BaseModel): class Transcript(BaseModel): """Full transcript model with all fields.""" + model_config = ConfigDict(from_attributes=True) + id: str = Field(default_factory=generate_uuid4) user_id: str | None = None name: str = Field(default_factory=generate_transcript_name) @@ -317,8 +319,9 @@ class TranscriptController: query = query.where(TranscriptModel.title.ilike(f"%{search_term}%")) # Exclude heavy JSON columns from list queries + # Get all ORM column attributes except excluded ones transcript_columns = [ - col + getattr(TranscriptModel, col.name) for col in TranscriptModel.__table__.c if col.name not in exclude_columns ] @@ -361,7 +364,7 @@ class TranscriptController: row = result.scalar_one_or_none() if not row: return None - return Transcript(**row.__dict__) + return Transcript.model_validate(row) async def get_by_recording_id( self, session: AsyncSession, recording_id: str, **kwargs @@ -378,7 +381,7 @@ class TranscriptController: row = result.scalar_one_or_none() if not row: return None - return Transcript(**row.__dict__) + return Transcript.model_validate(row) async def get_by_room_id( self, session: AsyncSession, room_id: str, **kwargs @@ -396,7 +399,9 @@ class TranscriptController: field = field.desc() query = query.order_by(field) results = await session.execute(query) - return [Transcript(**dict(row)) for row in results.mappings().all()] + return [ + Transcript.model_validate(dict(row)) for row in results.mappings().all() + ] async def get_by_id_for_http( self, @@ -420,7 +425,7 @@ class TranscriptController: raise HTTPException(status_code=404, detail="Transcript not found") # if the transcript is anonymous, share mode is not checked - transcript = Transcript(**row.__dict__) + transcript = Transcript.model_validate(row) if transcript.user_id is None: return transcript diff --git a/server/tests/test_cleanup.py b/server/tests/test_cleanup.py index 5f741771..3d2ccced 100644 --- a/server/tests/test_cleanup.py +++ b/server/tests/test_cleanup.py @@ -2,7 +2,7 @@ from datetime import datetime, timedelta, timezone from unittest.mock import AsyncMock, patch import pytest -from sqlalchemy import insert, select, update +from sqlalchemy import delete, insert, select, update from reflector.db.base import ( MeetingConsentModel, @@ -310,11 +310,9 @@ async def test_meeting_consent_cascade_delete(db_session): # Delete the transcript and meeting await db_session.execute( - TranscriptModel.__table__.delete().where(TranscriptModel.id == transcript.id) - ) - await db_session.execute( - MeetingModel.__table__.delete().where(MeetingModel.id == meeting_id) + delete(TranscriptModel).where(TranscriptModel.id == transcript.id) ) + await db_session.execute(delete(MeetingModel).where(MeetingModel.id == meeting_id)) await db_session.commit() # Verify consent entries were cascade deleted