refactor: migrate to SQLAlchemy 2.0 ORM-style patterns

- Replace __table__.join() with ORM-style joins using select_from().outerjoin()
- Replace __table__.delete() with delete(Model) in tests
- Migrate from **row.__dict__ to model_validate() with ConfigDict(from_attributes=True)
- Add ConfigDict(from_attributes=True) to all Pydantic models for proper SQLAlchemy model conversion
- Update all controller methods to use model_validate() instead of dict unpacking

This completes the migration to SQLAlchemy 2.0 recommended patterns while maintaining
backwards compatibility and improving code consistency.
This commit is contained in:
2025-09-23 16:46:37 -06:00
parent a883df0d63
commit e0c71c5548
7 changed files with 51 additions and 40 deletions

View File

@@ -2,7 +2,7 @@ from datetime import datetime, timedelta, timezone
from typing import Any from typing import Any
import sqlalchemy as sa import sqlalchemy as sa
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import delete, select, update from sqlalchemy import delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@@ -11,6 +11,8 @@ from reflector.utils import generate_uuid4
class CalendarEvent(BaseModel): class CalendarEvent(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str = Field(default_factory=generate_uuid4) id: str = Field(default_factory=generate_uuid4)
room_id: str room_id: str
ics_uid: str ics_uid: str
@@ -50,7 +52,7 @@ class CalendarEventController:
) )
result = await session.execute(query) result = await session.execute(query)
return [CalendarEvent(**row.__dict__) for row in result.scalars().all()] return [CalendarEvent.model_validate(row) for row in result.scalars().all()]
async def get_by_id( async def get_by_id(
self, session: AsyncSession, event_id: str self, session: AsyncSession, event_id: str
@@ -60,7 +62,7 @@ class CalendarEventController:
row = result.scalar_one_or_none() row = result.scalar_one_or_none()
if not row: if not row:
return None return None
return CalendarEvent(**row.__dict__) return CalendarEvent.model_validate(row)
async def get_by_ics_uid( async def get_by_ics_uid(
self, session: AsyncSession, room_id: str, ics_uid: str self, session: AsyncSession, room_id: str, ics_uid: str
@@ -75,7 +77,7 @@ class CalendarEventController:
row = result.scalar_one_or_none() row = result.scalar_one_or_none()
if not row: if not row:
return None return None
return CalendarEvent(**row.__dict__) return CalendarEvent.model_validate(row)
async def upsert( async def upsert(
self, session: AsyncSession, event: CalendarEvent self, session: AsyncSession, event: CalendarEvent
@@ -137,7 +139,7 @@ class CalendarEventController:
if not include_deleted: if not include_deleted:
query = query.where(CalendarEventModel.is_deleted == False) query = query.where(CalendarEventModel.is_deleted == False)
result = await session.execute(query) result = await session.execute(query)
return [CalendarEvent(**row.__dict__) for row in result.scalars().all()] return [CalendarEvent.model_validate(row) for row in result.scalars().all()]
async def get_upcoming( async def get_upcoming(
self, session: AsyncSession, room_id: str, minutes_ahead: int = 120 self, session: AsyncSession, room_id: str, minutes_ahead: int = 120
@@ -159,7 +161,7 @@ class CalendarEventController:
) )
result = await session.execute(query) result = await session.execute(query)
return [CalendarEvent(**row.__dict__) for row in result.scalars().all()] return [CalendarEvent.model_validate(row) for row in result.scalars().all()]
async def soft_delete_missing( async def soft_delete_missing(
self, session: AsyncSession, room_id: str, current_ics_uids: list[str] self, session: AsyncSession, room_id: str, current_ics_uids: list[str]

View File

@@ -2,7 +2,7 @@ from datetime import datetime
from typing import Any, Literal from typing import Any, Literal
import sqlalchemy as sa import sqlalchemy as sa
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import select, update from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@@ -12,6 +12,8 @@ from reflector.utils import generate_uuid4
class MeetingConsent(BaseModel): class MeetingConsent(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str = Field(default_factory=generate_uuid4) id: str = Field(default_factory=generate_uuid4)
meeting_id: str meeting_id: str
user_id: str | None = None user_id: str | None = None
@@ -20,6 +22,8 @@ class MeetingConsent(BaseModel):
class Meeting(BaseModel): class Meeting(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str id: str
room_name: str room_name: str
room_url: str room_url: str
@@ -76,7 +80,7 @@ class MeetingController:
async def get_all_active(self, session: AsyncSession) -> list[Meeting]: async def get_all_active(self, session: AsyncSession) -> list[Meeting]:
query = select(MeetingModel).where(MeetingModel.is_active) query = select(MeetingModel).where(MeetingModel.is_active)
result = await session.execute(query) result = await session.execute(query)
return [Meeting(**row.__dict__) for row in result.scalars().all()] return [Meeting.model_validate(row) for row in result.scalars().all()]
async def get_by_room_name( async def get_by_room_name(
self, self,
@@ -96,7 +100,7 @@ class MeetingController:
row = result.scalar_one_or_none() row = result.scalar_one_or_none()
if not row: if not row:
return None return None
return Meeting(**row.__dict__) return Meeting.model_validate(row)
async def get_active( async def get_active(
self, session: AsyncSession, room: Room, current_time: datetime self, session: AsyncSession, room: Room, current_time: datetime
@@ -120,7 +124,7 @@ class MeetingController:
row = result.scalar_one_or_none() row = result.scalar_one_or_none()
if not row: if not row:
return None return None
return Meeting(**row.__dict__) return Meeting.model_validate(row)
async def get_all_active_for_room( async def get_all_active_for_room(
self, session: AsyncSession, room: Room, current_time: datetime self, session: AsyncSession, room: Room, current_time: datetime
@@ -137,7 +141,7 @@ class MeetingController:
.order_by(MeetingModel.end_date.desc()) .order_by(MeetingModel.end_date.desc())
) )
result = await session.execute(query) result = await session.execute(query)
return [Meeting(**row.__dict__) for row in result.scalars().all()] return [Meeting.model_validate(row) for row in result.scalars().all()]
async def get_active_by_calendar_event( async def get_active_by_calendar_event(
self, self,
@@ -161,7 +165,7 @@ class MeetingController:
row = result.scalar_one_or_none() row = result.scalar_one_or_none()
if not row: if not row:
return None return None
return Meeting(**row.__dict__) return Meeting.model_validate(row)
async def get_by_id( async def get_by_id(
self, session: AsyncSession, meeting_id: str, **kwargs self, session: AsyncSession, meeting_id: str, **kwargs
@@ -171,7 +175,7 @@ class MeetingController:
row = result.scalar_one_or_none() row = result.scalar_one_or_none()
if not row: if not row:
return None return None
return Meeting(**row.__dict__) return Meeting.model_validate(row)
async def get_by_calendar_event( async def get_by_calendar_event(
self, session: AsyncSession, calendar_event_id: str self, session: AsyncSession, calendar_event_id: str
@@ -183,7 +187,7 @@ class MeetingController:
row = result.scalar_one_or_none() row = result.scalar_one_or_none()
if not row: if not row:
return None return None
return Meeting(**row.__dict__) return Meeting.model_validate(row)
async def update_meeting(self, session: AsyncSession, meeting_id: str, **kwargs): async def update_meeting(self, session: AsyncSession, meeting_id: str, **kwargs):
query = ( query = (
@@ -201,7 +205,7 @@ class MeetingConsentController:
MeetingConsentModel.meeting_id == meeting_id MeetingConsentModel.meeting_id == meeting_id
) )
result = await session.execute(query) result = await session.execute(query)
return [MeetingConsent(**row.__dict__) for row in result.scalars().all()] return [MeetingConsent.model_validate(row) for row in result.scalars().all()]
async def get_by_meeting_and_user( async def get_by_meeting_and_user(
self, session: AsyncSession, meeting_id: str, user_id: str self, session: AsyncSession, meeting_id: str, user_id: str
@@ -217,7 +221,7 @@ class MeetingConsentController:
row = result.scalar_one_or_none() row = result.scalar_one_or_none()
if row is None: if row is None:
return None return None
return MeetingConsent(**row.__dict__) return MeetingConsent.model_validate(row)
async def upsert( async def upsert(
self, session: AsyncSession, consent: MeetingConsent self, session: AsyncSession, consent: MeetingConsent

View File

@@ -1,6 +1,6 @@
from datetime import datetime, timezone from datetime import datetime, timezone
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import delete, select from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@@ -9,6 +9,8 @@ from reflector.utils import generate_uuid4
class Recording(BaseModel): class Recording(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str = Field(default_factory=generate_uuid4) id: str = Field(default_factory=generate_uuid4)
meeting_id: str meeting_id: str
url: str url: str
@@ -53,7 +55,7 @@ class RecordingController:
row = result.scalar_one_or_none() row = result.scalar_one_or_none()
if not row: if not row:
return None return None
return Recording(**row.__dict__) return Recording.model_validate(row)
async def get_by_meeting_id( async def get_by_meeting_id(
self, session: AsyncSession, meeting_id: str self, session: AsyncSession, meeting_id: str
@@ -63,7 +65,7 @@ class RecordingController:
""" """
query = select(RecordingModel).where(RecordingModel.meeting_id == meeting_id) query = select(RecordingModel).where(RecordingModel.meeting_id == meeting_id)
result = await session.execute(query) result = await session.execute(query)
return [Recording(**row.__dict__) for row in result.scalars().all()] return [Recording.model_validate(row) for row in result.scalars().all()]
async def remove_by_id(self, session: AsyncSession, recording_id: str) -> None: async def remove_by_id(self, session: AsyncSession, recording_id: str) -> None:
""" """

View File

@@ -4,7 +4,7 @@ from sqlite3 import IntegrityError
from typing import Literal from typing import Literal
from fastapi import HTTPException from fastapi import HTTPException
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import delete, select, update from sqlalchemy import delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import or_ from sqlalchemy.sql import or_
@@ -14,6 +14,8 @@ from reflector.utils import generate_uuid4
class Room(BaseModel): class Room(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str = Field(default_factory=generate_uuid4) id: str = Field(default_factory=generate_uuid4)
name: str name: str
user_id: str user_id: str
@@ -70,7 +72,7 @@ class RoomController:
return query return query
result = await session.execute(query) result = await session.execute(query)
return [Room(**row.__dict__) for row in result.scalars().all()] return [Room.model_validate(row) for row in result.scalars().all()]
async def add( async def add(
self, self,
@@ -155,7 +157,7 @@ class RoomController:
row = result.scalars().first() row = result.scalars().first()
if not row: if not row:
return None return None
return Room(**row.__dict__) return Room.model_validate(row)
async def get_by_name( async def get_by_name(
self, session: AsyncSession, room_name: str, **kwargs self, session: AsyncSession, room_name: str, **kwargs
@@ -170,7 +172,7 @@ class RoomController:
row = result.scalars().first() row = result.scalars().first()
if not row: if not row:
return None return None
return Room(**row.__dict__) return Room.model_validate(row)
async def get_by_id_for_http( async def get_by_id_for_http(
self, session: AsyncSession, meeting_id: str, user_id: str | None self, session: AsyncSession, meeting_id: str, user_id: str | None
@@ -186,7 +188,7 @@ class RoomController:
if not row: if not row:
raise HTTPException(status_code=404, detail="Room not found") raise HTTPException(status_code=404, detail="Room not found")
room = Room(**row.__dict__) room = Room.model_validate(row)
return room return room

View File

@@ -369,12 +369,10 @@ class SearchController:
rank_column = sqlalchemy.cast(1.0, sqlalchemy.Float).label("rank") rank_column = sqlalchemy.cast(1.0, sqlalchemy.Float).label("rank")
columns = base_columns + [rank_column] columns = base_columns + [rank_column]
base_query = sqlalchemy.select(*columns).select_from( base_query = (
TranscriptModel.__table__.join( sqlalchemy.select(*columns)
RoomModel.__table__, .select_from(TranscriptModel)
TranscriptModel.room_id == RoomModel.id, .outerjoin(RoomModel, TranscriptModel.room_id == RoomModel.id)
isouter=True,
)
) )
if params.query_text is not None: if params.query_text is not None:

View File

@@ -103,6 +103,8 @@ class TranscriptParticipant(BaseModel):
class Transcript(BaseModel): class Transcript(BaseModel):
"""Full transcript model with all fields.""" """Full transcript model with all fields."""
model_config = ConfigDict(from_attributes=True)
id: str = Field(default_factory=generate_uuid4) id: str = Field(default_factory=generate_uuid4)
user_id: str | None = None user_id: str | None = None
name: str = Field(default_factory=generate_transcript_name) name: str = Field(default_factory=generate_transcript_name)
@@ -317,8 +319,9 @@ class TranscriptController:
query = query.where(TranscriptModel.title.ilike(f"%{search_term}%")) query = query.where(TranscriptModel.title.ilike(f"%{search_term}%"))
# Exclude heavy JSON columns from list queries # Exclude heavy JSON columns from list queries
# Get all ORM column attributes except excluded ones
transcript_columns = [ transcript_columns = [
col getattr(TranscriptModel, col.name)
for col in TranscriptModel.__table__.c for col in TranscriptModel.__table__.c
if col.name not in exclude_columns if col.name not in exclude_columns
] ]
@@ -361,7 +364,7 @@ class TranscriptController:
row = result.scalar_one_or_none() row = result.scalar_one_or_none()
if not row: if not row:
return None return None
return Transcript(**row.__dict__) return Transcript.model_validate(row)
async def get_by_recording_id( async def get_by_recording_id(
self, session: AsyncSession, recording_id: str, **kwargs self, session: AsyncSession, recording_id: str, **kwargs
@@ -378,7 +381,7 @@ class TranscriptController:
row = result.scalar_one_or_none() row = result.scalar_one_or_none()
if not row: if not row:
return None return None
return Transcript(**row.__dict__) return Transcript.model_validate(row)
async def get_by_room_id( async def get_by_room_id(
self, session: AsyncSession, room_id: str, **kwargs self, session: AsyncSession, room_id: str, **kwargs
@@ -396,7 +399,9 @@ class TranscriptController:
field = field.desc() field = field.desc()
query = query.order_by(field) query = query.order_by(field)
results = await session.execute(query) results = await session.execute(query)
return [Transcript(**dict(row)) for row in results.mappings().all()] return [
Transcript.model_validate(dict(row)) for row in results.mappings().all()
]
async def get_by_id_for_http( async def get_by_id_for_http(
self, self,
@@ -420,7 +425,7 @@ class TranscriptController:
raise HTTPException(status_code=404, detail="Transcript not found") raise HTTPException(status_code=404, detail="Transcript not found")
# if the transcript is anonymous, share mode is not checked # if the transcript is anonymous, share mode is not checked
transcript = Transcript(**row.__dict__) transcript = Transcript.model_validate(row)
if transcript.user_id is None: if transcript.user_id is None:
return transcript return transcript

View File

@@ -2,7 +2,7 @@ from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
import pytest import pytest
from sqlalchemy import insert, select, update from sqlalchemy import delete, insert, select, update
from reflector.db.base import ( from reflector.db.base import (
MeetingConsentModel, MeetingConsentModel,
@@ -310,11 +310,9 @@ async def test_meeting_consent_cascade_delete(db_session):
# Delete the transcript and meeting # Delete the transcript and meeting
await db_session.execute( await db_session.execute(
TranscriptModel.__table__.delete().where(TranscriptModel.id == transcript.id) delete(TranscriptModel).where(TranscriptModel.id == transcript.id)
)
await db_session.execute(
MeetingModel.__table__.delete().where(MeetingModel.id == meeting_id)
) )
await db_session.execute(delete(MeetingModel).where(MeetingModel.id == meeting_id))
await db_session.commit() await db_session.commit()
# Verify consent entries were cascade deleted # Verify consent entries were cascade deleted