feat: migrate SQLAlchemy from 1.4 to 2.0 with ORM style

- Remove encode/databases dependency, use native SQLAlchemy 2.0 async
- Convert all table definitions to Declarative Mapping pattern
- Update all controllers to accept session parameter (dependency injection)
- Convert all queries from Core style to ORM style
- Remove PostgreSQL compatibility checks (PostgreSQL only now)
- Add proper typing for engine and session factories
This commit is contained in:
2025-09-18 12:19:53 -06:00
parent 2b723da08b
commit 06639d4d8f
18 changed files with 911 additions and 750 deletions

View File

@@ -7,17 +7,14 @@ from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Literal
import sqlalchemy
from fastapi import HTTPException
from pydantic import BaseModel, ConfigDict, Field, field_serializer
from sqlalchemy import Enum
from sqlalchemy.dialects.postgresql import TSVECTOR
from sqlalchemy.sql import false, or_
from sqlalchemy import delete, insert, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import or_
from reflector.db import get_database, metadata
from reflector.db.base import RoomModel, TranscriptModel
from reflector.db.recordings import recordings_controller
from reflector.db.rooms import rooms
from reflector.db.utils import is_postgresql
from reflector.logger import logger
from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings
@@ -32,91 +29,6 @@ class SourceKind(enum.StrEnum):
FILE = enum.auto()
transcripts = sqlalchemy.Table(
"transcript",
metadata,
sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
sqlalchemy.Column("name", sqlalchemy.String),
sqlalchemy.Column("status", sqlalchemy.String),
sqlalchemy.Column("locked", sqlalchemy.Boolean),
sqlalchemy.Column("duration", sqlalchemy.Float),
sqlalchemy.Column("created_at", sqlalchemy.DateTime(timezone=True)),
sqlalchemy.Column("title", sqlalchemy.String),
sqlalchemy.Column("short_summary", sqlalchemy.String),
sqlalchemy.Column("long_summary", sqlalchemy.String),
sqlalchemy.Column("topics", sqlalchemy.JSON),
sqlalchemy.Column("events", sqlalchemy.JSON),
sqlalchemy.Column("participants", sqlalchemy.JSON),
sqlalchemy.Column("source_language", sqlalchemy.String),
sqlalchemy.Column("target_language", sqlalchemy.String),
sqlalchemy.Column(
"reviewed", sqlalchemy.Boolean, nullable=False, server_default=false()
),
sqlalchemy.Column(
"audio_location",
sqlalchemy.String,
nullable=False,
server_default="local",
),
# with user attached, optional
sqlalchemy.Column("user_id", sqlalchemy.String),
sqlalchemy.Column(
"share_mode",
sqlalchemy.String,
nullable=False,
server_default="private",
),
sqlalchemy.Column(
"meeting_id",
sqlalchemy.String,
),
sqlalchemy.Column("recording_id", sqlalchemy.String),
sqlalchemy.Column("zulip_message_id", sqlalchemy.Integer),
sqlalchemy.Column(
"source_kind",
Enum(SourceKind, values_callable=lambda obj: [e.value for e in obj]),
nullable=False,
),
# indicative field: whether associated audio is deleted
# the main "audio deleted" is the presence of the audio itself / consents not-given
# same field could've been in recording/meeting, and it's maybe even ok to dupe it at need
sqlalchemy.Column("audio_deleted", sqlalchemy.Boolean),
sqlalchemy.Column("room_id", sqlalchemy.String),
sqlalchemy.Column("webvtt", sqlalchemy.Text),
sqlalchemy.Index("idx_transcript_recording_id", "recording_id"),
sqlalchemy.Index("idx_transcript_user_id", "user_id"),
sqlalchemy.Index("idx_transcript_created_at", "created_at"),
sqlalchemy.Index("idx_transcript_user_id_recording_id", "user_id", "recording_id"),
sqlalchemy.Index("idx_transcript_room_id", "room_id"),
sqlalchemy.Index("idx_transcript_source_kind", "source_kind"),
sqlalchemy.Index("idx_transcript_room_id_created_at", "room_id", "created_at"),
)
# Add PostgreSQL-specific full-text search column
# This matches the migration in migrations/versions/116b2f287eab_add_full_text_search.py
if is_postgresql():
transcripts.append_column(
sqlalchemy.Column(
"search_vector_en",
TSVECTOR,
sqlalchemy.Computed(
"setweight(to_tsvector('english', coalesce(title, '')), 'A') || "
"setweight(to_tsvector('english', coalesce(long_summary, '')), 'B') || "
"setweight(to_tsvector('english', coalesce(webvtt, '')), 'C')",
persisted=True,
),
)
)
# Add GIN index for the search vector
transcripts.append_constraint(
sqlalchemy.Index(
"idx_transcript_search_vector_en",
"search_vector_en",
postgresql_using="gin",
)
)
def generate_transcript_name() -> str:
now = datetime.now(timezone.utc)
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
@@ -359,6 +271,7 @@ class Transcript(BaseModel):
class TranscriptController:
async def get_all(
self,
session: AsyncSession,
user_id: str | None = None,
order_by: str | None = None,
filter_empty: bool | None = False,
@@ -383,102 +296,111 @@ class TranscriptController:
- `search_term`: filter transcripts by search term
"""
query = transcripts.select().join(
rooms, transcripts.c.room_id == rooms.c.id, isouter=True
query = select(TranscriptModel).join(
RoomModel, TranscriptModel.room_id == RoomModel.id, isouter=True
)
if user_id:
query = query.where(
or_(transcripts.c.user_id == user_id, rooms.c.is_shared)
or_(TranscriptModel.user_id == user_id, RoomModel.is_shared)
)
else:
query = query.where(rooms.c.is_shared)
query = query.where(RoomModel.is_shared)
if source_kind:
query = query.where(transcripts.c.source_kind == source_kind)
query = query.where(TranscriptModel.source_kind == source_kind)
if room_id:
query = query.where(transcripts.c.room_id == room_id)
query = query.where(TranscriptModel.room_id == room_id)
if search_term:
query = query.where(transcripts.c.title.ilike(f"%{search_term}%"))
query = query.where(TranscriptModel.title.ilike(f"%{search_term}%"))
# Exclude heavy JSON columns from list queries
transcript_columns = [
col for col in transcripts.c if col.name not in exclude_columns
col
for col in TranscriptModel.__table__.c
if col.name not in exclude_columns
]
query = query.with_only_columns(
transcript_columns
+ [
rooms.c.name.label("room_name"),
]
*transcript_columns,
RoomModel.name.label("room_name"),
)
if order_by is not None:
field = getattr(transcripts.c, order_by[1:])
field = getattr(TranscriptModel, order_by[1:])
if order_by.startswith("-"):
field = field.desc()
query = query.order_by(field)
if filter_empty:
query = query.filter(transcripts.c.status != "idle")
query = query.filter(TranscriptModel.status != "idle")
if filter_recording:
query = query.filter(transcripts.c.status != "recording")
query = query.filter(TranscriptModel.status != "recording")
# print(query.compile(compile_kwargs={"literal_binds": True}))
if return_query:
return query
results = await get_database().fetch_all(query)
return results
result = await session.execute(query)
return [dict(row) for row in result.mappings().all()]
async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
async def get_by_id(
self, session: AsyncSession, transcript_id: str, **kwargs
) -> Transcript | None:
"""
Get a transcript by id
"""
query = transcripts.select().where(transcripts.c.id == transcript_id)
query = select(TranscriptModel).where(TranscriptModel.id == transcript_id)
if "user_id" in kwargs:
query = query.where(transcripts.c.user_id == kwargs["user_id"])
result = await get_database().fetch_one(query)
if not result:
query = query.where(TranscriptModel.user_id == kwargs["user_id"])
result = await session.execute(query)
row = result.scalar_one_or_none()
if not row:
return None
return Transcript(**result)
return Transcript(**row.__dict__)
async def get_by_recording_id(
self, recording_id: str, **kwargs
self, session: AsyncSession, recording_id: str, **kwargs
) -> Transcript | None:
"""
Get a transcript by recording_id
"""
query = transcripts.select().where(transcripts.c.recording_id == recording_id)
query = select(TranscriptModel).where(
TranscriptModel.recording_id == recording_id
)
if "user_id" in kwargs:
query = query.where(transcripts.c.user_id == kwargs["user_id"])
result = await get_database().fetch_one(query)
if not result:
query = query.where(TranscriptModel.user_id == kwargs["user_id"])
result = await session.execute(query)
row = result.scalar_one_or_none()
if not row:
return None
return Transcript(**result)
return Transcript(**row.__dict__)
async def get_by_room_id(self, room_id: str, **kwargs) -> list[Transcript]:
async def get_by_room_id(
self, session: AsyncSession, room_id: str, **kwargs
) -> list[Transcript]:
"""
Get transcripts by room_id (direct access without joins)
"""
query = transcripts.select().where(transcripts.c.room_id == room_id)
query = select(TranscriptModel).where(TranscriptModel.room_id == room_id)
if "user_id" in kwargs:
query = query.where(transcripts.c.user_id == kwargs["user_id"])
query = query.where(TranscriptModel.user_id == kwargs["user_id"])
if "order_by" in kwargs:
order_by = kwargs["order_by"]
field = getattr(transcripts.c, order_by[1:])
field = getattr(TranscriptModel, order_by[1:])
if order_by.startswith("-"):
field = field.desc()
query = query.order_by(field)
results = await get_database().fetch_all(query)
return [Transcript(**result) for result in results]
results = await session.execute(query)
return [Transcript(**dict(row)) for row in results.mappings().all()]
async def get_by_id_for_http(
self,
session: AsyncSession,
transcript_id: str,
user_id: str | None,
) -> Transcript:
@@ -491,13 +413,14 @@ class TranscriptController:
This method checks the share mode of the transcript and the user_id
to determine if the user can access the transcript.
"""
query = transcripts.select().where(transcripts.c.id == transcript_id)
result = await get_database().fetch_one(query)
if not result:
query = select(TranscriptModel).where(TranscriptModel.id == transcript_id)
result = await session.execute(query)
row = result.scalar_one_or_none()
if not row:
raise HTTPException(status_code=404, detail="Transcript not found")
# if the transcript is anonymous, share mode is not checked
transcript = Transcript(**result)
transcript = Transcript(**row.__dict__)
if transcript.user_id is None:
return transcript
@@ -520,6 +443,7 @@ class TranscriptController:
async def add(
self,
session: AsyncSession,
name: str,
source_kind: SourceKind,
source_language: str = "en",
@@ -544,14 +468,15 @@ class TranscriptController:
meeting_id=meeting_id,
room_id=room_id,
)
query = transcripts.insert().values(**transcript.model_dump())
await get_database().execute(query)
query = insert(TranscriptModel).values(**transcript.model_dump())
await session.execute(query)
await session.commit()
return transcript
# TODO investigate why mutate= is used. it's used in one place currently, maybe because of ORM field updates.
# using mutate=True is discouraged
async def update(
self, transcript: Transcript, values: dict, mutate=False
self, session: AsyncSession, transcript: Transcript, values: dict, mutate=False
) -> Transcript:
"""
Update a transcript fields with key/values in values.
@@ -560,11 +485,12 @@ class TranscriptController:
values = TranscriptController._handle_topics_update(values)
query = (
transcripts.update()
.where(transcripts.c.id == transcript.id)
update(TranscriptModel)
.where(TranscriptModel.id == transcript.id)
.values(**values)
)
await get_database().execute(query)
await session.execute(query)
await session.commit()
if mutate:
for key, value in values.items():
setattr(transcript, key, value)
@@ -593,13 +519,14 @@ class TranscriptController:
async def remove_by_id(
self,
session: AsyncSession,
transcript_id: str,
user_id: str | None = None,
) -> None:
"""
Remove a transcript by id
"""
transcript = await self.get_by_id(transcript_id)
transcript = await self.get_by_id(session, transcript_id)
if not transcript:
return
if user_id is not None and transcript.user_id != user_id:
@@ -619,7 +546,7 @@ class TranscriptController:
if transcript.recording_id:
try:
recording = await recordings_controller.get_by_id(
transcript.recording_id
session, transcript.recording_id
)
if recording:
try:
@@ -630,33 +557,40 @@ class TranscriptController:
exc_info=e,
recording_id=transcript.recording_id,
)
await recordings_controller.remove_by_id(transcript.recording_id)
await recordings_controller.remove_by_id(
session, transcript.recording_id
)
except Exception as e:
logger.warning(
"Failed to delete recording row",
exc_info=e,
recording_id=transcript.recording_id,
)
query = transcripts.delete().where(transcripts.c.id == transcript_id)
await get_database().execute(query)
query = delete(TranscriptModel).where(TranscriptModel.id == transcript_id)
await session.execute(query)
await session.commit()
async def remove_by_recording_id(self, recording_id: str):
async def remove_by_recording_id(self, session: AsyncSession, recording_id: str):
"""
Remove a transcript by recording_id
"""
query = transcripts.delete().where(transcripts.c.recording_id == recording_id)
await get_database().execute(query)
query = delete(TranscriptModel).where(
TranscriptModel.recording_id == recording_id
)
await session.execute(query)
await session.commit()
@asynccontextmanager
async def transaction(self):
async def transaction(self, session: AsyncSession):
"""
A context manager for database transaction
"""
async with get_database().transaction(isolation="serializable"):
async with session.begin():
yield
async def append_event(
self,
session: AsyncSession,
transcript: Transcript,
event: str,
data: Any,
@@ -665,11 +599,12 @@ class TranscriptController:
Append an event to a transcript
"""
resp = transcript.add_event(event=event, data=data)
await self.update(transcript, {"events": transcript.events_dump()})
await self.update(session, transcript, {"events": transcript.events_dump()})
return resp
async def upsert_topic(
self,
session: AsyncSession,
transcript: Transcript,
topic: TranscriptTopic,
) -> TranscriptEvent:
@@ -677,9 +612,9 @@ class TranscriptController:
Upsert topics to a transcript
"""
transcript.upsert_topic(topic)
await self.update(transcript, {"topics": transcript.topics_dump()})
await self.update(session, transcript, {"topics": transcript.topics_dump()})
async def move_mp3_to_storage(self, transcript: Transcript):
async def move_mp3_to_storage(self, session: AsyncSession, transcript: Transcript):
"""
Move mp3 file to storage
"""
@@ -703,12 +638,16 @@ class TranscriptController:
# indicate on the transcript that the audio is now on storage
# mutates transcript argument
await self.update(transcript, {"audio_location": "storage"}, mutate=True)
await self.update(
session, transcript, {"audio_location": "storage"}, mutate=True
)
# unlink the local file
transcript.audio_mp3_filename.unlink(missing_ok=True)
async def download_mp3_from_storage(self, transcript: Transcript):
async def download_mp3_from_storage(
self, session: AsyncSession, transcript: Transcript
):
"""
Download audio from storage
"""
@@ -720,6 +659,7 @@ class TranscriptController:
async def upsert_participant(
self,
session: AsyncSession,
transcript: Transcript,
participant: TranscriptParticipant,
) -> TranscriptParticipant:
@@ -727,11 +667,14 @@ class TranscriptController:
Add/update a participant to a transcript
"""
result = transcript.upsert_participant(participant)
await self.update(transcript, {"participants": transcript.participants_dump()})
await self.update(
session, transcript, {"participants": transcript.participants_dump()}
)
return result
async def delete_participant(
self,
session: AsyncSession,
transcript: Transcript,
participant_id: str,
):
@@ -739,28 +682,31 @@ class TranscriptController:
Delete a participant from a transcript
"""
transcript.delete_participant(participant_id)
await self.update(transcript, {"participants": transcript.participants_dump()})
await self.update(
session, transcript, {"participants": transcript.participants_dump()}
)
async def set_status(
self, transcript_id: str, status: TranscriptStatus
self, session: AsyncSession, transcript_id: str, status: TranscriptStatus
) -> TranscriptEvent | None:
"""
Update the status of a transcript
Will add an event STATUS + update the status field of transcript
"""
async with self.transaction():
transcript = await self.get_by_id(transcript_id)
async with self.transaction(session):
transcript = await self.get_by_id(session, transcript_id)
if not transcript:
raise Exception(f"Transcript {transcript_id} not found")
if transcript.status == status:
return
resp = await self.append_event(
session,
transcript=transcript,
event="STATUS",
data=StrValue(value=status),
)
await self.update(transcript, {"status": status})
await self.update(session, transcript, {"status": status})
return resp