mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
refactor: migrate to SQLAlchemy 2.0 ORM-style patterns
- Replace __table__.join() with ORM-style joins using select_from().outerjoin() - Replace __table__.delete() with delete(Model) in tests - Migrate from **row.__dict__ to model_validate() with ConfigDict(from_attributes=True) - Add ConfigDict(from_attributes=True) to all Pydantic models for proper SQLAlchemy model conversion - Update all controller methods to use model_validate() instead of dict unpacking This completes the migration to SQLAlchemy 2.0 recommended patterns while maintaining backwards compatibility and improving code consistency.
This commit is contained in:
@@ -2,7 +2,7 @@ from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
import sqlalchemy as sa
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import delete, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
@@ -11,6 +11,8 @@ from reflector.utils import generate_uuid4
|
||||
|
||||
|
||||
class CalendarEvent(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str = Field(default_factory=generate_uuid4)
|
||||
room_id: str
|
||||
ics_uid: str
|
||||
@@ -50,7 +52,7 @@ class CalendarEventController:
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return [CalendarEvent(**row.__dict__) for row in result.scalars().all()]
|
||||
return [CalendarEvent.model_validate(row) for row in result.scalars().all()]
|
||||
|
||||
async def get_by_id(
|
||||
self, session: AsyncSession, event_id: str
|
||||
@@ -60,7 +62,7 @@ class CalendarEventController:
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
return CalendarEvent(**row.__dict__)
|
||||
return CalendarEvent.model_validate(row)
|
||||
|
||||
async def get_by_ics_uid(
|
||||
self, session: AsyncSession, room_id: str, ics_uid: str
|
||||
@@ -75,7 +77,7 @@ class CalendarEventController:
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
return CalendarEvent(**row.__dict__)
|
||||
return CalendarEvent.model_validate(row)
|
||||
|
||||
async def upsert(
|
||||
self, session: AsyncSession, event: CalendarEvent
|
||||
@@ -137,7 +139,7 @@ class CalendarEventController:
|
||||
if not include_deleted:
|
||||
query = query.where(CalendarEventModel.is_deleted == False)
|
||||
result = await session.execute(query)
|
||||
return [CalendarEvent(**row.__dict__) for row in result.scalars().all()]
|
||||
return [CalendarEvent.model_validate(row) for row in result.scalars().all()]
|
||||
|
||||
async def get_upcoming(
|
||||
self, session: AsyncSession, room_id: str, minutes_ahead: int = 120
|
||||
@@ -159,7 +161,7 @@ class CalendarEventController:
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return [CalendarEvent(**row.__dict__) for row in result.scalars().all()]
|
||||
return [CalendarEvent.model_validate(row) for row in result.scalars().all()]
|
||||
|
||||
async def soft_delete_missing(
|
||||
self, session: AsyncSession, room_id: str, current_ics_uids: list[str]
|
||||
|
||||
@@ -2,7 +2,7 @@ from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
|
||||
import sqlalchemy as sa
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
@@ -12,6 +12,8 @@ from reflector.utils import generate_uuid4
|
||||
|
||||
|
||||
class MeetingConsent(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str = Field(default_factory=generate_uuid4)
|
||||
meeting_id: str
|
||||
user_id: str | None = None
|
||||
@@ -20,6 +22,8 @@ class MeetingConsent(BaseModel):
|
||||
|
||||
|
||||
class Meeting(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str
|
||||
room_name: str
|
||||
room_url: str
|
||||
@@ -76,7 +80,7 @@ class MeetingController:
|
||||
async def get_all_active(self, session: AsyncSession) -> list[Meeting]:
|
||||
query = select(MeetingModel).where(MeetingModel.is_active)
|
||||
result = await session.execute(query)
|
||||
return [Meeting(**row.__dict__) for row in result.scalars().all()]
|
||||
return [Meeting.model_validate(row) for row in result.scalars().all()]
|
||||
|
||||
async def get_by_room_name(
|
||||
self,
|
||||
@@ -96,7 +100,7 @@ class MeetingController:
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
return Meeting(**row.__dict__)
|
||||
return Meeting.model_validate(row)
|
||||
|
||||
async def get_active(
|
||||
self, session: AsyncSession, room: Room, current_time: datetime
|
||||
@@ -120,7 +124,7 @@ class MeetingController:
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
return Meeting(**row.__dict__)
|
||||
return Meeting.model_validate(row)
|
||||
|
||||
async def get_all_active_for_room(
|
||||
self, session: AsyncSession, room: Room, current_time: datetime
|
||||
@@ -137,7 +141,7 @@ class MeetingController:
|
||||
.order_by(MeetingModel.end_date.desc())
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return [Meeting(**row.__dict__) for row in result.scalars().all()]
|
||||
return [Meeting.model_validate(row) for row in result.scalars().all()]
|
||||
|
||||
async def get_active_by_calendar_event(
|
||||
self,
|
||||
@@ -161,7 +165,7 @@ class MeetingController:
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
return Meeting(**row.__dict__)
|
||||
return Meeting.model_validate(row)
|
||||
|
||||
async def get_by_id(
|
||||
self, session: AsyncSession, meeting_id: str, **kwargs
|
||||
@@ -171,7 +175,7 @@ class MeetingController:
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
return Meeting(**row.__dict__)
|
||||
return Meeting.model_validate(row)
|
||||
|
||||
async def get_by_calendar_event(
|
||||
self, session: AsyncSession, calendar_event_id: str
|
||||
@@ -183,7 +187,7 @@ class MeetingController:
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
return Meeting(**row.__dict__)
|
||||
return Meeting.model_validate(row)
|
||||
|
||||
async def update_meeting(self, session: AsyncSession, meeting_id: str, **kwargs):
|
||||
query = (
|
||||
@@ -201,7 +205,7 @@ class MeetingConsentController:
|
||||
MeetingConsentModel.meeting_id == meeting_id
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return [MeetingConsent(**row.__dict__) for row in result.scalars().all()]
|
||||
return [MeetingConsent.model_validate(row) for row in result.scalars().all()]
|
||||
|
||||
async def get_by_meeting_and_user(
|
||||
self, session: AsyncSession, meeting_id: str, user_id: str
|
||||
@@ -217,7 +221,7 @@ class MeetingConsentController:
|
||||
row = result.scalar_one_or_none()
|
||||
if row is None:
|
||||
return None
|
||||
return MeetingConsent(**row.__dict__)
|
||||
return MeetingConsent.model_validate(row)
|
||||
|
||||
async def upsert(
|
||||
self, session: AsyncSession, consent: MeetingConsent
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
@@ -9,6 +9,8 @@ from reflector.utils import generate_uuid4
|
||||
|
||||
|
||||
class Recording(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str = Field(default_factory=generate_uuid4)
|
||||
meeting_id: str
|
||||
url: str
|
||||
@@ -53,7 +55,7 @@ class RecordingController:
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
return Recording(**row.__dict__)
|
||||
return Recording.model_validate(row)
|
||||
|
||||
async def get_by_meeting_id(
|
||||
self, session: AsyncSession, meeting_id: str
|
||||
@@ -63,7 +65,7 @@ class RecordingController:
|
||||
"""
|
||||
query = select(RecordingModel).where(RecordingModel.meeting_id == meeting_id)
|
||||
result = await session.execute(query)
|
||||
return [Recording(**row.__dict__) for row in result.scalars().all()]
|
||||
return [Recording.model_validate(row) for row in result.scalars().all()]
|
||||
|
||||
async def remove_by_id(self, session: AsyncSession, recording_id: str) -> None:
|
||||
"""
|
||||
|
||||
@@ -4,7 +4,7 @@ from sqlite3 import IntegrityError
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import delete, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.sql import or_
|
||||
@@ -14,6 +14,8 @@ from reflector.utils import generate_uuid4
|
||||
|
||||
|
||||
class Room(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str = Field(default_factory=generate_uuid4)
|
||||
name: str
|
||||
user_id: str
|
||||
@@ -70,7 +72,7 @@ class RoomController:
|
||||
return query
|
||||
|
||||
result = await session.execute(query)
|
||||
return [Room(**row.__dict__) for row in result.scalars().all()]
|
||||
return [Room.model_validate(row) for row in result.scalars().all()]
|
||||
|
||||
async def add(
|
||||
self,
|
||||
@@ -155,7 +157,7 @@ class RoomController:
|
||||
row = result.scalars().first()
|
||||
if not row:
|
||||
return None
|
||||
return Room(**row.__dict__)
|
||||
return Room.model_validate(row)
|
||||
|
||||
async def get_by_name(
|
||||
self, session: AsyncSession, room_name: str, **kwargs
|
||||
@@ -170,7 +172,7 @@ class RoomController:
|
||||
row = result.scalars().first()
|
||||
if not row:
|
||||
return None
|
||||
return Room(**row.__dict__)
|
||||
return Room.model_validate(row)
|
||||
|
||||
async def get_by_id_for_http(
|
||||
self, session: AsyncSession, meeting_id: str, user_id: str | None
|
||||
@@ -186,7 +188,7 @@ class RoomController:
|
||||
if not row:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
room = Room(**row.__dict__)
|
||||
room = Room.model_validate(row)
|
||||
|
||||
return room
|
||||
|
||||
|
||||
@@ -369,12 +369,10 @@ class SearchController:
|
||||
rank_column = sqlalchemy.cast(1.0, sqlalchemy.Float).label("rank")
|
||||
|
||||
columns = base_columns + [rank_column]
|
||||
base_query = sqlalchemy.select(*columns).select_from(
|
||||
TranscriptModel.__table__.join(
|
||||
RoomModel.__table__,
|
||||
TranscriptModel.room_id == RoomModel.id,
|
||||
isouter=True,
|
||||
)
|
||||
base_query = (
|
||||
sqlalchemy.select(*columns)
|
||||
.select_from(TranscriptModel)
|
||||
.outerjoin(RoomModel, TranscriptModel.room_id == RoomModel.id)
|
||||
)
|
||||
|
||||
if params.query_text is not None:
|
||||
|
||||
@@ -103,6 +103,8 @@ class TranscriptParticipant(BaseModel):
|
||||
class Transcript(BaseModel):
|
||||
"""Full transcript model with all fields."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str = Field(default_factory=generate_uuid4)
|
||||
user_id: str | None = None
|
||||
name: str = Field(default_factory=generate_transcript_name)
|
||||
@@ -317,8 +319,9 @@ class TranscriptController:
|
||||
query = query.where(TranscriptModel.title.ilike(f"%{search_term}%"))
|
||||
|
||||
# Exclude heavy JSON columns from list queries
|
||||
# Get all ORM column attributes except excluded ones
|
||||
transcript_columns = [
|
||||
col
|
||||
getattr(TranscriptModel, col.name)
|
||||
for col in TranscriptModel.__table__.c
|
||||
if col.name not in exclude_columns
|
||||
]
|
||||
@@ -361,7 +364,7 @@ class TranscriptController:
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
return Transcript(**row.__dict__)
|
||||
return Transcript.model_validate(row)
|
||||
|
||||
async def get_by_recording_id(
|
||||
self, session: AsyncSession, recording_id: str, **kwargs
|
||||
@@ -378,7 +381,7 @@ class TranscriptController:
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
return Transcript(**row.__dict__)
|
||||
return Transcript.model_validate(row)
|
||||
|
||||
async def get_by_room_id(
|
||||
self, session: AsyncSession, room_id: str, **kwargs
|
||||
@@ -396,7 +399,9 @@ class TranscriptController:
|
||||
field = field.desc()
|
||||
query = query.order_by(field)
|
||||
results = await session.execute(query)
|
||||
return [Transcript(**dict(row)) for row in results.mappings().all()]
|
||||
return [
|
||||
Transcript.model_validate(dict(row)) for row in results.mappings().all()
|
||||
]
|
||||
|
||||
async def get_by_id_for_http(
|
||||
self,
|
||||
@@ -420,7 +425,7 @@ class TranscriptController:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
|
||||
# if the transcript is anonymous, share mode is not checked
|
||||
transcript = Transcript(**row.__dict__)
|
||||
transcript = Transcript.model_validate(row)
|
||||
if transcript.user_id is None:
|
||||
return transcript
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import insert, select, update
|
||||
from sqlalchemy import delete, insert, select, update
|
||||
|
||||
from reflector.db.base import (
|
||||
MeetingConsentModel,
|
||||
@@ -310,11 +310,9 @@ async def test_meeting_consent_cascade_delete(db_session):
|
||||
|
||||
# Delete the transcript and meeting
|
||||
await db_session.execute(
|
||||
TranscriptModel.__table__.delete().where(TranscriptModel.id == transcript.id)
|
||||
)
|
||||
await db_session.execute(
|
||||
MeetingModel.__table__.delete().where(MeetingModel.id == meeting_id)
|
||||
delete(TranscriptModel).where(TranscriptModel.id == transcript.id)
|
||||
)
|
||||
await db_session.execute(delete(MeetingModel).where(MeetingModel.id == meeting_id))
|
||||
await db_session.commit()
|
||||
|
||||
# Verify consent entries were cascade deleted
|
||||
|
||||
Reference in New Issue
Block a user