feat: migrate SQLAlchemy from 1.4 to 2.0 with ORM style

- Remove encode/databases dependency, use native SQLAlchemy 2.0 async
- Convert all table definitions to Declarative Mapping pattern
- Update all controllers to accept an AsyncSession parameter (dependency injection)
- Convert all queries from Core style to ORM style
- Remove is_postgresql() guards (PostgreSQL is now the only supported database)
- Add proper typing for engine and session factories
This commit is contained in:
2025-09-18 12:19:53 -06:00
parent 2b723da08b
commit 06639d4d8f
18 changed files with 911 additions and 750 deletions

View File

@@ -8,7 +8,6 @@ from typing import Annotated, Any, Dict, Iterator
import sqlalchemy
import webvtt
from databases.interfaces import Record as DbRecord
from fastapi import HTTPException
from pydantic import (
BaseModel,
@@ -20,11 +19,10 @@ from pydantic import (
constr,
field_serializer,
)
from sqlalchemy.ext.asyncio import AsyncSession
from reflector.db import get_database
from reflector.db.rooms import rooms
from reflector.db.transcripts import SourceKind, TranscriptStatus, transcripts
from reflector.db.utils import is_postgresql
from reflector.db.base import RoomModel, TranscriptModel
from reflector.db.transcripts import SourceKind, TranscriptStatus
from reflector.logger import logger
from reflector.utils.string import NonEmptyString, try_parse_non_empty_string
@@ -331,36 +329,30 @@ class SearchController:
@classmethod
async def search_transcripts(
cls, params: SearchParameters
cls, session: AsyncSession, params: SearchParameters
) -> tuple[list[SearchResult], int]:
"""
Full-text search for transcripts using PostgreSQL tsvector.
Returns (results, total_count).
"""
if not is_postgresql():
logger.warning(
"Full-text search requires PostgreSQL. Returning empty results."
)
return [], 0
base_columns = [
transcripts.c.id,
transcripts.c.title,
transcripts.c.created_at,
transcripts.c.duration,
transcripts.c.status,
transcripts.c.user_id,
transcripts.c.room_id,
transcripts.c.source_kind,
transcripts.c.webvtt,
transcripts.c.long_summary,
TranscriptModel.id,
TranscriptModel.title,
TranscriptModel.created_at,
TranscriptModel.duration,
TranscriptModel.status,
TranscriptModel.user_id,
TranscriptModel.room_id,
TranscriptModel.source_kind,
TranscriptModel.webvtt,
TranscriptModel.long_summary,
sqlalchemy.case(
(
transcripts.c.room_id.isnot(None) & rooms.c.id.is_(None),
TranscriptModel.room_id.isnot(None) & RoomModel.id.is_(None),
"Deleted Room",
),
else_=rooms.c.name,
else_=RoomModel.name,
).label("room_name"),
]
search_query = None
@@ -369,7 +361,7 @@ class SearchController:
"english", params.query_text
)
rank_column = sqlalchemy.func.ts_rank(
transcripts.c.search_vector_en,
TranscriptModel.search_vector_en,
search_query,
32, # normalization flag: rank/(rank+1) for 0-1 range
).label("rank")
@@ -378,46 +370,52 @@ class SearchController:
columns = base_columns + [rank_column]
base_query = sqlalchemy.select(columns).select_from(
transcripts.join(rooms, transcripts.c.room_id == rooms.c.id, isouter=True)
TranscriptModel.__table__.join(
RoomModel.__table__,
TranscriptModel.room_id == RoomModel.id,
isouter=True,
)
)
if params.query_text is not None:
# because already initialized based on params.query_text presence above
assert search_query is not None
base_query = base_query.where(
transcripts.c.search_vector_en.op("@@")(search_query)
TranscriptModel.search_vector_en.op("@@")(search_query)
)
if params.user_id:
base_query = base_query.where(
sqlalchemy.or_(
transcripts.c.user_id == params.user_id, rooms.c.is_shared
TranscriptModel.user_id == params.user_id, RoomModel.is_shared
)
)
else:
base_query = base_query.where(rooms.c.is_shared)
base_query = base_query.where(RoomModel.is_shared)
if params.room_id:
base_query = base_query.where(transcripts.c.room_id == params.room_id)
base_query = base_query.where(TranscriptModel.room_id == params.room_id)
if params.source_kind:
base_query = base_query.where(
transcripts.c.source_kind == params.source_kind
TranscriptModel.source_kind == params.source_kind
)
if params.query_text is not None:
order_by = sqlalchemy.desc(sqlalchemy.text("rank"))
else:
order_by = sqlalchemy.desc(transcripts.c.created_at)
order_by = sqlalchemy.desc(TranscriptModel.created_at)
query = base_query.order_by(order_by).limit(params.limit).offset(params.offset)
rs = await get_database().fetch_all(query)
result = await session.execute(query)
rs = result.mappings().all()
count_query = sqlalchemy.select([sqlalchemy.func.count()]).select_from(
base_query.alias("search_results")
)
total = await get_database().fetch_val(count_query)
count_result = await session.execute(count_query)
total = count_result.scalar()
def _process_result(r: DbRecord) -> SearchResult:
def _process_result(r: dict) -> SearchResult:
r_dict: Dict[str, Any] = dict(r)
webvtt_raw: str | None = r_dict.pop("webvtt", None)