mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-21 04:39:06 +00:00
refactor: migrate to SQLAlchemy 2.0 ORM-style patterns
- Replace __table__.join() with ORM-style joins using select_from().outerjoin() - Replace __table__.delete() with delete(Model) in tests - Migrate from **row.__dict__ to model_validate() with ConfigDict(from_attributes=True) - Add ConfigDict(from_attributes=True) to all Pydantic models for proper SQLAlchemy model conversion - Update all controller methods to use model_validate() instead of dict unpacking This completes the migration to SQLAlchemy 2.0 recommended patterns while maintaining backwards compatibility and improving code consistency.
This commit is contained in:
@@ -2,7 +2,7 @@ from datetime import datetime, timedelta, timezone
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
from sqlalchemy import delete, select, update
|
from sqlalchemy import delete, select, update
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
@@ -11,6 +11,8 @@ from reflector.utils import generate_uuid4
|
|||||||
|
|
||||||
|
|
||||||
class CalendarEvent(BaseModel):
|
class CalendarEvent(BaseModel):
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
id: str = Field(default_factory=generate_uuid4)
|
id: str = Field(default_factory=generate_uuid4)
|
||||||
room_id: str
|
room_id: str
|
||||||
ics_uid: str
|
ics_uid: str
|
||||||
@@ -50,7 +52,7 @@ class CalendarEventController:
|
|||||||
)
|
)
|
||||||
|
|
||||||
result = await session.execute(query)
|
result = await session.execute(query)
|
||||||
return [CalendarEvent(**row.__dict__) for row in result.scalars().all()]
|
return [CalendarEvent.model_validate(row) for row in result.scalars().all()]
|
||||||
|
|
||||||
async def get_by_id(
|
async def get_by_id(
|
||||||
self, session: AsyncSession, event_id: str
|
self, session: AsyncSession, event_id: str
|
||||||
@@ -60,7 +62,7 @@ class CalendarEventController:
|
|||||||
row = result.scalar_one_or_none()
|
row = result.scalar_one_or_none()
|
||||||
if not row:
|
if not row:
|
||||||
return None
|
return None
|
||||||
return CalendarEvent(**row.__dict__)
|
return CalendarEvent.model_validate(row)
|
||||||
|
|
||||||
async def get_by_ics_uid(
|
async def get_by_ics_uid(
|
||||||
self, session: AsyncSession, room_id: str, ics_uid: str
|
self, session: AsyncSession, room_id: str, ics_uid: str
|
||||||
@@ -75,7 +77,7 @@ class CalendarEventController:
|
|||||||
row = result.scalar_one_or_none()
|
row = result.scalar_one_or_none()
|
||||||
if not row:
|
if not row:
|
||||||
return None
|
return None
|
||||||
return CalendarEvent(**row.__dict__)
|
return CalendarEvent.model_validate(row)
|
||||||
|
|
||||||
async def upsert(
|
async def upsert(
|
||||||
self, session: AsyncSession, event: CalendarEvent
|
self, session: AsyncSession, event: CalendarEvent
|
||||||
@@ -137,7 +139,7 @@ class CalendarEventController:
|
|||||||
if not include_deleted:
|
if not include_deleted:
|
||||||
query = query.where(CalendarEventModel.is_deleted == False)
|
query = query.where(CalendarEventModel.is_deleted == False)
|
||||||
result = await session.execute(query)
|
result = await session.execute(query)
|
||||||
return [CalendarEvent(**row.__dict__) for row in result.scalars().all()]
|
return [CalendarEvent.model_validate(row) for row in result.scalars().all()]
|
||||||
|
|
||||||
async def get_upcoming(
|
async def get_upcoming(
|
||||||
self, session: AsyncSession, room_id: str, minutes_ahead: int = 120
|
self, session: AsyncSession, room_id: str, minutes_ahead: int = 120
|
||||||
@@ -159,7 +161,7 @@ class CalendarEventController:
|
|||||||
)
|
)
|
||||||
|
|
||||||
result = await session.execute(query)
|
result = await session.execute(query)
|
||||||
return [CalendarEvent(**row.__dict__) for row in result.scalars().all()]
|
return [CalendarEvent.model_validate(row) for row in result.scalars().all()]
|
||||||
|
|
||||||
async def soft_delete_missing(
|
async def soft_delete_missing(
|
||||||
self, session: AsyncSession, room_id: str, current_ics_uids: list[str]
|
self, session: AsyncSession, room_id: str, current_ics_uids: list[str]
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from datetime import datetime
|
|||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
from sqlalchemy import select, update
|
from sqlalchemy import select, update
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
@@ -12,6 +12,8 @@ from reflector.utils import generate_uuid4
|
|||||||
|
|
||||||
|
|
||||||
class MeetingConsent(BaseModel):
|
class MeetingConsent(BaseModel):
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
id: str = Field(default_factory=generate_uuid4)
|
id: str = Field(default_factory=generate_uuid4)
|
||||||
meeting_id: str
|
meeting_id: str
|
||||||
user_id: str | None = None
|
user_id: str | None = None
|
||||||
@@ -20,6 +22,8 @@ class MeetingConsent(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class Meeting(BaseModel):
|
class Meeting(BaseModel):
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
room_name: str
|
room_name: str
|
||||||
room_url: str
|
room_url: str
|
||||||
@@ -76,7 +80,7 @@ class MeetingController:
|
|||||||
async def get_all_active(self, session: AsyncSession) -> list[Meeting]:
|
async def get_all_active(self, session: AsyncSession) -> list[Meeting]:
|
||||||
query = select(MeetingModel).where(MeetingModel.is_active)
|
query = select(MeetingModel).where(MeetingModel.is_active)
|
||||||
result = await session.execute(query)
|
result = await session.execute(query)
|
||||||
return [Meeting(**row.__dict__) for row in result.scalars().all()]
|
return [Meeting.model_validate(row) for row in result.scalars().all()]
|
||||||
|
|
||||||
async def get_by_room_name(
|
async def get_by_room_name(
|
||||||
self,
|
self,
|
||||||
@@ -96,7 +100,7 @@ class MeetingController:
|
|||||||
row = result.scalar_one_or_none()
|
row = result.scalar_one_or_none()
|
||||||
if not row:
|
if not row:
|
||||||
return None
|
return None
|
||||||
return Meeting(**row.__dict__)
|
return Meeting.model_validate(row)
|
||||||
|
|
||||||
async def get_active(
|
async def get_active(
|
||||||
self, session: AsyncSession, room: Room, current_time: datetime
|
self, session: AsyncSession, room: Room, current_time: datetime
|
||||||
@@ -120,7 +124,7 @@ class MeetingController:
|
|||||||
row = result.scalar_one_or_none()
|
row = result.scalar_one_or_none()
|
||||||
if not row:
|
if not row:
|
||||||
return None
|
return None
|
||||||
return Meeting(**row.__dict__)
|
return Meeting.model_validate(row)
|
||||||
|
|
||||||
async def get_all_active_for_room(
|
async def get_all_active_for_room(
|
||||||
self, session: AsyncSession, room: Room, current_time: datetime
|
self, session: AsyncSession, room: Room, current_time: datetime
|
||||||
@@ -137,7 +141,7 @@ class MeetingController:
|
|||||||
.order_by(MeetingModel.end_date.desc())
|
.order_by(MeetingModel.end_date.desc())
|
||||||
)
|
)
|
||||||
result = await session.execute(query)
|
result = await session.execute(query)
|
||||||
return [Meeting(**row.__dict__) for row in result.scalars().all()]
|
return [Meeting.model_validate(row) for row in result.scalars().all()]
|
||||||
|
|
||||||
async def get_active_by_calendar_event(
|
async def get_active_by_calendar_event(
|
||||||
self,
|
self,
|
||||||
@@ -161,7 +165,7 @@ class MeetingController:
|
|||||||
row = result.scalar_one_or_none()
|
row = result.scalar_one_or_none()
|
||||||
if not row:
|
if not row:
|
||||||
return None
|
return None
|
||||||
return Meeting(**row.__dict__)
|
return Meeting.model_validate(row)
|
||||||
|
|
||||||
async def get_by_id(
|
async def get_by_id(
|
||||||
self, session: AsyncSession, meeting_id: str, **kwargs
|
self, session: AsyncSession, meeting_id: str, **kwargs
|
||||||
@@ -171,7 +175,7 @@ class MeetingController:
|
|||||||
row = result.scalar_one_or_none()
|
row = result.scalar_one_or_none()
|
||||||
if not row:
|
if not row:
|
||||||
return None
|
return None
|
||||||
return Meeting(**row.__dict__)
|
return Meeting.model_validate(row)
|
||||||
|
|
||||||
async def get_by_calendar_event(
|
async def get_by_calendar_event(
|
||||||
self, session: AsyncSession, calendar_event_id: str
|
self, session: AsyncSession, calendar_event_id: str
|
||||||
@@ -183,7 +187,7 @@ class MeetingController:
|
|||||||
row = result.scalar_one_or_none()
|
row = result.scalar_one_or_none()
|
||||||
if not row:
|
if not row:
|
||||||
return None
|
return None
|
||||||
return Meeting(**row.__dict__)
|
return Meeting.model_validate(row)
|
||||||
|
|
||||||
async def update_meeting(self, session: AsyncSession, meeting_id: str, **kwargs):
|
async def update_meeting(self, session: AsyncSession, meeting_id: str, **kwargs):
|
||||||
query = (
|
query = (
|
||||||
@@ -201,7 +205,7 @@ class MeetingConsentController:
|
|||||||
MeetingConsentModel.meeting_id == meeting_id
|
MeetingConsentModel.meeting_id == meeting_id
|
||||||
)
|
)
|
||||||
result = await session.execute(query)
|
result = await session.execute(query)
|
||||||
return [MeetingConsent(**row.__dict__) for row in result.scalars().all()]
|
return [MeetingConsent.model_validate(row) for row in result.scalars().all()]
|
||||||
|
|
||||||
async def get_by_meeting_and_user(
|
async def get_by_meeting_and_user(
|
||||||
self, session: AsyncSession, meeting_id: str, user_id: str
|
self, session: AsyncSession, meeting_id: str, user_id: str
|
||||||
@@ -217,7 +221,7 @@ class MeetingConsentController:
|
|||||||
row = result.scalar_one_or_none()
|
row = result.scalar_one_or_none()
|
||||||
if row is None:
|
if row is None:
|
||||||
return None
|
return None
|
||||||
return MeetingConsent(**row.__dict__)
|
return MeetingConsent.model_validate(row)
|
||||||
|
|
||||||
async def upsert(
|
async def upsert(
|
||||||
self, session: AsyncSession, consent: MeetingConsent
|
self, session: AsyncSession, consent: MeetingConsent
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
from sqlalchemy import delete, select
|
from sqlalchemy import delete, select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
@@ -9,6 +9,8 @@ from reflector.utils import generate_uuid4
|
|||||||
|
|
||||||
|
|
||||||
class Recording(BaseModel):
|
class Recording(BaseModel):
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
id: str = Field(default_factory=generate_uuid4)
|
id: str = Field(default_factory=generate_uuid4)
|
||||||
meeting_id: str
|
meeting_id: str
|
||||||
url: str
|
url: str
|
||||||
@@ -53,7 +55,7 @@ class RecordingController:
|
|||||||
row = result.scalar_one_or_none()
|
row = result.scalar_one_or_none()
|
||||||
if not row:
|
if not row:
|
||||||
return None
|
return None
|
||||||
return Recording(**row.__dict__)
|
return Recording.model_validate(row)
|
||||||
|
|
||||||
async def get_by_meeting_id(
|
async def get_by_meeting_id(
|
||||||
self, session: AsyncSession, meeting_id: str
|
self, session: AsyncSession, meeting_id: str
|
||||||
@@ -63,7 +65,7 @@ class RecordingController:
|
|||||||
"""
|
"""
|
||||||
query = select(RecordingModel).where(RecordingModel.meeting_id == meeting_id)
|
query = select(RecordingModel).where(RecordingModel.meeting_id == meeting_id)
|
||||||
result = await session.execute(query)
|
result = await session.execute(query)
|
||||||
return [Recording(**row.__dict__) for row in result.scalars().all()]
|
return [Recording.model_validate(row) for row in result.scalars().all()]
|
||||||
|
|
||||||
async def remove_by_id(self, session: AsyncSession, recording_id: str) -> None:
|
async def remove_by_id(self, session: AsyncSession, recording_id: str) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from sqlite3 import IntegrityError
|
|||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
from sqlalchemy import delete, select, update
|
from sqlalchemy import delete, select, update
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.sql import or_
|
from sqlalchemy.sql import or_
|
||||||
@@ -14,6 +14,8 @@ from reflector.utils import generate_uuid4
|
|||||||
|
|
||||||
|
|
||||||
class Room(BaseModel):
|
class Room(BaseModel):
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
id: str = Field(default_factory=generate_uuid4)
|
id: str = Field(default_factory=generate_uuid4)
|
||||||
name: str
|
name: str
|
||||||
user_id: str
|
user_id: str
|
||||||
@@ -70,7 +72,7 @@ class RoomController:
|
|||||||
return query
|
return query
|
||||||
|
|
||||||
result = await session.execute(query)
|
result = await session.execute(query)
|
||||||
return [Room(**row.__dict__) for row in result.scalars().all()]
|
return [Room.model_validate(row) for row in result.scalars().all()]
|
||||||
|
|
||||||
async def add(
|
async def add(
|
||||||
self,
|
self,
|
||||||
@@ -155,7 +157,7 @@ class RoomController:
|
|||||||
row = result.scalars().first()
|
row = result.scalars().first()
|
||||||
if not row:
|
if not row:
|
||||||
return None
|
return None
|
||||||
return Room(**row.__dict__)
|
return Room.model_validate(row)
|
||||||
|
|
||||||
async def get_by_name(
|
async def get_by_name(
|
||||||
self, session: AsyncSession, room_name: str, **kwargs
|
self, session: AsyncSession, room_name: str, **kwargs
|
||||||
@@ -170,7 +172,7 @@ class RoomController:
|
|||||||
row = result.scalars().first()
|
row = result.scalars().first()
|
||||||
if not row:
|
if not row:
|
||||||
return None
|
return None
|
||||||
return Room(**row.__dict__)
|
return Room.model_validate(row)
|
||||||
|
|
||||||
async def get_by_id_for_http(
|
async def get_by_id_for_http(
|
||||||
self, session: AsyncSession, meeting_id: str, user_id: str | None
|
self, session: AsyncSession, meeting_id: str, user_id: str | None
|
||||||
@@ -186,7 +188,7 @@ class RoomController:
|
|||||||
if not row:
|
if not row:
|
||||||
raise HTTPException(status_code=404, detail="Room not found")
|
raise HTTPException(status_code=404, detail="Room not found")
|
||||||
|
|
||||||
room = Room(**row.__dict__)
|
room = Room.model_validate(row)
|
||||||
|
|
||||||
return room
|
return room
|
||||||
|
|
||||||
|
|||||||
@@ -369,12 +369,10 @@ class SearchController:
|
|||||||
rank_column = sqlalchemy.cast(1.0, sqlalchemy.Float).label("rank")
|
rank_column = sqlalchemy.cast(1.0, sqlalchemy.Float).label("rank")
|
||||||
|
|
||||||
columns = base_columns + [rank_column]
|
columns = base_columns + [rank_column]
|
||||||
base_query = sqlalchemy.select(*columns).select_from(
|
base_query = (
|
||||||
TranscriptModel.__table__.join(
|
sqlalchemy.select(*columns)
|
||||||
RoomModel.__table__,
|
.select_from(TranscriptModel)
|
||||||
TranscriptModel.room_id == RoomModel.id,
|
.outerjoin(RoomModel, TranscriptModel.room_id == RoomModel.id)
|
||||||
isouter=True,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.query_text is not None:
|
if params.query_text is not None:
|
||||||
|
|||||||
@@ -103,6 +103,8 @@ class TranscriptParticipant(BaseModel):
|
|||||||
class Transcript(BaseModel):
|
class Transcript(BaseModel):
|
||||||
"""Full transcript model with all fields."""
|
"""Full transcript model with all fields."""
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
id: str = Field(default_factory=generate_uuid4)
|
id: str = Field(default_factory=generate_uuid4)
|
||||||
user_id: str | None = None
|
user_id: str | None = None
|
||||||
name: str = Field(default_factory=generate_transcript_name)
|
name: str = Field(default_factory=generate_transcript_name)
|
||||||
@@ -317,8 +319,9 @@ class TranscriptController:
|
|||||||
query = query.where(TranscriptModel.title.ilike(f"%{search_term}%"))
|
query = query.where(TranscriptModel.title.ilike(f"%{search_term}%"))
|
||||||
|
|
||||||
# Exclude heavy JSON columns from list queries
|
# Exclude heavy JSON columns from list queries
|
||||||
|
# Get all ORM column attributes except excluded ones
|
||||||
transcript_columns = [
|
transcript_columns = [
|
||||||
col
|
getattr(TranscriptModel, col.name)
|
||||||
for col in TranscriptModel.__table__.c
|
for col in TranscriptModel.__table__.c
|
||||||
if col.name not in exclude_columns
|
if col.name not in exclude_columns
|
||||||
]
|
]
|
||||||
@@ -361,7 +364,7 @@ class TranscriptController:
|
|||||||
row = result.scalar_one_or_none()
|
row = result.scalar_one_or_none()
|
||||||
if not row:
|
if not row:
|
||||||
return None
|
return None
|
||||||
return Transcript(**row.__dict__)
|
return Transcript.model_validate(row)
|
||||||
|
|
||||||
async def get_by_recording_id(
|
async def get_by_recording_id(
|
||||||
self, session: AsyncSession, recording_id: str, **kwargs
|
self, session: AsyncSession, recording_id: str, **kwargs
|
||||||
@@ -378,7 +381,7 @@ class TranscriptController:
|
|||||||
row = result.scalar_one_or_none()
|
row = result.scalar_one_or_none()
|
||||||
if not row:
|
if not row:
|
||||||
return None
|
return None
|
||||||
return Transcript(**row.__dict__)
|
return Transcript.model_validate(row)
|
||||||
|
|
||||||
async def get_by_room_id(
|
async def get_by_room_id(
|
||||||
self, session: AsyncSession, room_id: str, **kwargs
|
self, session: AsyncSession, room_id: str, **kwargs
|
||||||
@@ -396,7 +399,9 @@ class TranscriptController:
|
|||||||
field = field.desc()
|
field = field.desc()
|
||||||
query = query.order_by(field)
|
query = query.order_by(field)
|
||||||
results = await session.execute(query)
|
results = await session.execute(query)
|
||||||
return [Transcript(**dict(row)) for row in results.mappings().all()]
|
return [
|
||||||
|
Transcript.model_validate(dict(row)) for row in results.mappings().all()
|
||||||
|
]
|
||||||
|
|
||||||
async def get_by_id_for_http(
|
async def get_by_id_for_http(
|
||||||
self,
|
self,
|
||||||
@@ -420,7 +425,7 @@ class TranscriptController:
|
|||||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||||
|
|
||||||
# if the transcript is anonymous, share mode is not checked
|
# if the transcript is anonymous, share mode is not checked
|
||||||
transcript = Transcript(**row.__dict__)
|
transcript = Transcript.model_validate(row)
|
||||||
if transcript.user_id is None:
|
if transcript.user_id is None:
|
||||||
return transcript
|
return transcript
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from datetime import datetime, timedelta, timezone
|
|||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy import insert, select, update
|
from sqlalchemy import delete, insert, select, update
|
||||||
|
|
||||||
from reflector.db.base import (
|
from reflector.db.base import (
|
||||||
MeetingConsentModel,
|
MeetingConsentModel,
|
||||||
@@ -310,11 +310,9 @@ async def test_meeting_consent_cascade_delete(db_session):
|
|||||||
|
|
||||||
# Delete the transcript and meeting
|
# Delete the transcript and meeting
|
||||||
await db_session.execute(
|
await db_session.execute(
|
||||||
TranscriptModel.__table__.delete().where(TranscriptModel.id == transcript.id)
|
delete(TranscriptModel).where(TranscriptModel.id == transcript.id)
|
||||||
)
|
|
||||||
await db_session.execute(
|
|
||||||
MeetingModel.__table__.delete().where(MeetingModel.id == meeting_id)
|
|
||||||
)
|
)
|
||||||
|
await db_session.execute(delete(MeetingModel).where(MeetingModel.id == meeting_id))
|
||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
|
|
||||||
# Verify consent entries were cascade deleted
|
# Verify consent entries were cascade deleted
|
||||||
|
|||||||
Reference in New Issue
Block a user