diff --git a/server/migrations/env.py b/server/migrations/env.py index 7a592836..0c532787 100644 --- a/server/migrations/env.py +++ b/server/migrations/env.py @@ -3,7 +3,7 @@ from logging.config import fileConfig from alembic import context from sqlalchemy import engine_from_config, pool -from reflector.db import metadata +from reflector.db.base import metadata from reflector.settings import settings # this is the Alembic Config object, which provides @@ -25,8 +25,7 @@ target_metadata = metadata # ... etc. -# don't use asyncpg for the moment -settings.DATABASE_URL = settings.DATABASE_URL.replace("+asyncpg", "") +# No need to modify URL, using sync engine from db module def run_migrations_offline() -> None: diff --git a/server/pyproject.toml b/server/pyproject.toml index f63947c8..269f389f 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -19,8 +19,8 @@ dependencies = [ "sentry-sdk[fastapi]>=1.29.2", "httpx>=0.24.1", "fastapi-pagination>=0.12.6", - "databases[aiosqlite, asyncpg]>=0.7.0", - "sqlalchemy<1.5", + "sqlalchemy>=2.0.0", + "asyncpg>=0.29.0", "alembic>=1.11.3", "nltk>=3.8.1", "prometheus-fastapi-instrumentator>=6.1.0", @@ -111,7 +111,7 @@ source = ["reflector"] [tool.pytest_env] ENVIRONMENT = "pytest" -DATABASE_URL = "postgresql://test_user:test_password@localhost:15432/reflector_test" +DATABASE_URL = "postgresql+asyncpg://test_user:test_password@localhost:15432/reflector_test" [tool.pytest.ini_options] addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v" diff --git a/server/reflector/asynctask.py b/server/reflector/asynctask.py index 61523a6f..409a04f0 100644 --- a/server/reflector/asynctask.py +++ b/server/reflector/asynctask.py @@ -1,21 +1,14 @@ import asyncio import functools -from reflector.db import get_database - def asynctask(f): @functools.wraps(f) def wrapper(*args, **kwargs): - async def run_with_db(): - database = get_database() - await database.connect() - try: - return await f(*args, **kwargs) - finally: - 
from typing import AsyncGenerator

from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)

from reflector.db.base import Base as Base
from reflector.db.base import metadata as metadata
from reflector.events import subscribers_shutdown, subscribers_startup
from reflector.settings import settings

# Process-wide lazily-created engine and session factory.
_engine: AsyncEngine | None = None
_session_factory: async_sessionmaker[AsyncSession] | None = None


def get_engine() -> AsyncEngine:
    """Return the shared async engine, creating it on first use."""
    global _engine
    if _engine is None:
        _engine = create_async_engine(
            settings.DATABASE_URL,
            echo=False,
            # Validate pooled connections before handing them out.
            pool_pre_ping=True,
        )
    return _engine


def get_session_factory() -> async_sessionmaker[AsyncSession]:
    """Return the shared session factory bound to the shared engine."""
    global _session_factory
    if _session_factory is None:
        _session_factory = async_sessionmaker(
            get_engine(),
            class_=AsyncSession,
            # Keep ORM objects usable after commit (controllers read
            # attributes off committed rows to build Pydantic models).
            expire_on_commit=False,
        )
    return _session_factory


async def get_session() -> AsyncGenerator[AsyncSession, None]:
    """FastAPI-style dependency yielding one session per request."""
    async with get_session_factory()() as session:
        yield session


# Import models so their tables register on Base.metadata.
import reflector.db.calendar_events  # noqa
import reflector.db.meetings  # noqa
import reflector.db.recordings  # noqa
import reflector.db.rooms  # noqa
import reflector.db.transcripts  # noqa


@subscribers_startup.append
async def database_connect(_):
    # Eagerly build the engine so connection errors surface at startup.
    get_engine()


@subscribers_shutdown.append
async def database_disconnect(_):
    global _engine, _session_factory
    if _engine:
        await _engine.dispose()
        _engine = None
    # BUG FIX: also drop the session factory. It keeps a reference to the
    # disposed engine, so after a shutdown/startup cycle sessions would be
    # created against dead connections instead of the fresh engine.
    _session_factory = None
class RoomModel(Base):
    """A persistent meeting room plus its recording and ICS configuration."""

    __tablename__ = "room"

    id: Mapped[str] = mapped_column(String, primary_key=True)
    # Room names are globally unique (enforced at the DB level).
    name: Mapped[str] = mapped_column(String, nullable=False, unique=True)
    user_id: Mapped[str] = mapped_column(String, nullable=False)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )

    # Zulip integration settings.
    zulip_auto_post: Mapped[bool] = mapped_column(
        Boolean, nullable=False, server_default=text("false")
    )
    zulip_stream: Mapped[Optional[str]] = mapped_column(String)
    zulip_topic: Mapped[Optional[str]] = mapped_column(String)

    is_locked: Mapped[bool] = mapped_column(
        Boolean, nullable=False, server_default=text("false")
    )
    room_mode: Mapped[str] = mapped_column(
        String, nullable=False, server_default="normal"
    )

    # Recording behaviour defaults.
    recording_type: Mapped[str] = mapped_column(
        String, nullable=False, server_default="cloud"
    )
    recording_trigger: Mapped[str] = mapped_column(
        String, nullable=False, server_default="automatic-2nd-participant"
    )
    is_shared: Mapped[bool] = mapped_column(
        Boolean, nullable=False, server_default=text("false")
    )

    # Outgoing webhook configuration.
    webhook_url: Mapped[Optional[str]] = mapped_column(String)
    webhook_secret: Mapped[Optional[str]] = mapped_column(String)

    # ICS calendar sync configuration and last-sync state.
    ics_url: Mapped[Optional[str]] = mapped_column(Text)
    ics_fetch_interval: Mapped[Optional[int]] = mapped_column(
        Integer, server_default=text("300")
    )
    ics_enabled: Mapped[bool] = mapped_column(
        Boolean, nullable=False, server_default=text("false")
    )
    ics_last_sync: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True))
    ics_last_etag: Mapped[Optional[str]] = mapped_column(Text)

    __table_args__ = (
        Index("idx_room_is_shared", "is_shared"),
        Index("idx_room_ics_enabled", "ics_enabled"),
    )
class MeetingConsentModel(Base):
    """Per-participant recording-consent record attached to a meeting."""

    __tablename__ = "meeting_consent"

    id: Mapped[str] = mapped_column(String, primary_key=True)
    # Consent rows are deleted together with their meeting (CASCADE).
    meeting_id: Mapped[str] = mapped_column(
        String, ForeignKey("meeting.id", ondelete="CASCADE"), nullable=False
    )
    # Nullable: presumably unauthenticated participants have no user id —
    # see MeetingConsentController.upsert, which only dedupes when set.
    user_id: Mapped[Optional[str]] = mapped_column(String)
    consent_given: Mapped[bool] = mapped_column(Boolean, nullable=False)
    consent_timestamp: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )
from sqlalchemy import UniqueConstraint


class CalendarEventModel(Base):
    """A calendar event synced from a room's ICS feed."""

    __tablename__ = "calendar_event"

    id: Mapped[str] = mapped_column(String, primary_key=True)
    room_id: Mapped[str] = mapped_column(
        String, ForeignKey("room.id", ondelete="CASCADE"), nullable=False
    )
    # Stable UID from the ICS feed; used to match events across syncs.
    ics_uid: Mapped[str] = mapped_column(Text, nullable=False)
    title: Mapped[Optional[str]] = mapped_column(Text)
    description: Mapped[Optional[str]] = mapped_column(Text)
    start_time: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )
    end_time: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
    attendees: Mapped[Optional[dict]] = mapped_column(JSONB)
    location: Mapped[Optional[str]] = mapped_column(Text)
    ics_raw_data: Mapped[Optional[str]] = mapped_column(Text)
    last_synced: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )
    # Soft-delete flag: events dropped from the feed are marked, not removed.
    is_deleted: Mapped[bool] = mapped_column(
        Boolean, nullable=False, server_default=text("false")
    )
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )

    __table_args__ = (
        # BUG FIX: restore the (room_id, ics_uid) uniqueness that the
        # previous Core table declared ("uq_room_calendar_event").
        # CalendarEventController.upsert matches rows by ics_uid and relies
        # on this constraint to prevent duplicate events per room.
        UniqueConstraint("room_id", "ics_uid", name="uq_room_calendar_event"),
        Index("idx_calendar_event_room_start", "room_id", "start_time"),
    )
get_database, metadata +from reflector.db.base import CalendarEventModel from reflector.utils import generate_uuid4 -calendar_events = sa.Table( - "calendar_event", - metadata, - sa.Column("id", sa.String, primary_key=True), - sa.Column( - "room_id", - sa.String, - sa.ForeignKey("room.id", ondelete="CASCADE", name="fk_calendar_event_room_id"), - nullable=False, - ), - sa.Column("ics_uid", sa.Text, nullable=False), - sa.Column("title", sa.Text), - sa.Column("description", sa.Text), - sa.Column("start_time", sa.DateTime(timezone=True), nullable=False), - sa.Column("end_time", sa.DateTime(timezone=True), nullable=False), - sa.Column("attendees", JSONB), - sa.Column("location", sa.Text), - sa.Column("ics_raw_data", sa.Text), - sa.Column("last_synced", sa.DateTime(timezone=True), nullable=False), - sa.Column("is_deleted", sa.Boolean, nullable=False, server_default=sa.false()), - sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), - sa.UniqueConstraint("room_id", "ics_uid", name="uq_room_calendar_event"), - sa.Index("idx_calendar_event_room_start", "room_id", "start_time"), - sa.Index( - "idx_calendar_event_deleted", - "is_deleted", - postgresql_where=sa.text("NOT is_deleted"), - ), -) - class CalendarEvent(BaseModel): id: str = Field(default_factory=generate_uuid4) @@ -58,124 +28,157 @@ class CalendarEvent(BaseModel): class CalendarEventController: - async def get_by_room( + async def get_upcoming_events( self, + session: AsyncSession, room_id: str, - include_deleted: bool = False, - start_after: datetime | None = None, - end_before: datetime | None = None, + current_time: datetime, + buffer_minutes: int = 15, ) -> list[CalendarEvent]: - query = calendar_events.select().where(calendar_events.c.room_id == room_id) - - if not include_deleted: - query = query.where(calendar_events.c.is_deleted == False) - - if start_after: - query = query.where(calendar_events.c.start_time >= 
start_after) - - if end_before: - query = query.where(calendar_events.c.end_time <= end_before) - - query = query.order_by(calendar_events.c.start_time.asc()) - - results = await get_database().fetch_all(query) - return [CalendarEvent(**result) for result in results] - - async def get_upcoming( - self, room_id: str, minutes_ahead: int = 120 - ) -> list[CalendarEvent]: - """Get upcoming events for a room within the specified minutes, including currently happening events.""" - now = datetime.now(timezone.utc) - future_time = now + timedelta(minutes=minutes_ahead) + buffer_time = current_time + timedelta(minutes=buffer_minutes) query = ( - calendar_events.select() + select(CalendarEventModel) .where( sa.and_( - calendar_events.c.room_id == room_id, - calendar_events.c.is_deleted == False, - calendar_events.c.start_time <= future_time, - calendar_events.c.end_time >= now, + CalendarEventModel.room_id == room_id, + CalendarEventModel.start_time <= buffer_time, + CalendarEventModel.end_time > current_time, ) ) - .order_by(calendar_events.c.start_time.asc()) + .order_by(CalendarEventModel.start_time) ) - results = await get_database().fetch_all(query) - return [CalendarEvent(**result) for result in results] + result = await session.execute(query) + return [CalendarEvent(**row.__dict__) for row in result.scalars().all()] - async def get_by_ics_uid(self, room_id: str, ics_uid: str) -> CalendarEvent | None: - query = calendar_events.select().where( + async def get_by_id( + self, session: AsyncSession, event_id: str + ) -> CalendarEvent | None: + query = select(CalendarEventModel).where(CalendarEventModel.id == event_id) + result = await session.execute(query) + row = result.scalar_one_or_none() + if not row: + return None + return CalendarEvent(**row.__dict__) + + async def get_by_ics_uid( + self, session: AsyncSession, room_id: str, ics_uid: str + ) -> CalendarEvent | None: + query = select(CalendarEventModel).where( sa.and_( - calendar_events.c.room_id == room_id, - 
calendar_events.c.ics_uid == ics_uid, + CalendarEventModel.room_id == room_id, + CalendarEventModel.ics_uid == ics_uid, ) ) - result = await get_database().fetch_one(query) - return CalendarEvent(**result) if result else None + result = await session.execute(query) + row = result.scalar_one_or_none() + if not row: + return None + return CalendarEvent(**row.__dict__) - async def upsert(self, event: CalendarEvent) -> CalendarEvent: - existing = await self.get_by_ics_uid(event.room_id, event.ics_uid) + async def upsert( + self, session: AsyncSession, event: CalendarEvent + ) -> CalendarEvent: + existing = await self.get_by_ics_uid(session, event.room_id, event.ics_uid) if existing: - event.id = existing.id - event.created_at = existing.created_at event.updated_at = datetime.now(timezone.utc) - query = ( - calendar_events.update() - .where(calendar_events.c.id == existing.id) - .values(**event.model_dump()) + update(CalendarEventModel) + .where(CalendarEventModel.id == existing.id) + .values(**event.model_dump(exclude={"id"})) ) + await session.execute(query) + await session.commit() + return event else: - query = calendar_events.insert().values(**event.model_dump()) + new_event = CalendarEventModel(**event.model_dump()) + session.add(new_event) + await session.commit() + return event - await get_database().execute(query) - return event - - async def soft_delete_missing( - self, room_id: str, current_ics_uids: list[str] + async def delete_old_events( + self, session: AsyncSession, room_id: str, cutoff_date: datetime ) -> int: - """Soft delete future events that are no longer in the calendar.""" - now = datetime.now(timezone.utc) - - select_query = calendar_events.select().where( + query = delete(CalendarEventModel).where( sa.and_( - calendar_events.c.room_id == room_id, - calendar_events.c.start_time > now, - calendar_events.c.is_deleted == False, - calendar_events.c.ics_uid.notin_(current_ics_uids) - if current_ics_uids - else True, + CalendarEventModel.room_id == 
room_id, + CalendarEventModel.end_time < cutoff_date, ) ) + result = await session.execute(query) + await session.commit() + return result.rowcount - to_delete = await get_database().fetch_all(select_query) - delete_count = len(to_delete) - - if delete_count > 0: - update_query = ( - calendar_events.update() - .where( - sa.and_( - calendar_events.c.room_id == room_id, - calendar_events.c.start_time > now, - calendar_events.c.is_deleted == False, - calendar_events.c.ics_uid.notin_(current_ics_uids) - if current_ics_uids - else True, - ) + async def delete_events_not_in_list( + self, session: AsyncSession, room_id: str, keep_ics_uids: list[str] + ) -> int: + if not keep_ics_uids: + query = delete(CalendarEventModel).where( + CalendarEventModel.room_id == room_id + ) + else: + query = delete(CalendarEventModel).where( + sa.and_( + CalendarEventModel.room_id == room_id, + CalendarEventModel.ics_uid.notin_(keep_ics_uids), ) - .values(is_deleted=True, updated_at=now) ) - await get_database().execute(update_query) + result = await session.execute(query) + await session.commit() + return result.rowcount - return delete_count + async def get_by_room( + self, session: AsyncSession, room_id: str, include_deleted: bool = True + ) -> list[CalendarEvent]: + query = select(CalendarEventModel).where(CalendarEventModel.room_id == room_id) + if not include_deleted: + query = query.where(CalendarEventModel.is_deleted == False) + result = await session.execute(query) + return [CalendarEvent(**row.__dict__) for row in result.scalars().all()] - async def delete_by_room(self, room_id: str) -> int: - query = calendar_events.delete().where(calendar_events.c.room_id == room_id) - result = await get_database().execute(query) + async def get_upcoming( + self, session: AsyncSession, room_id: str, minutes_ahead: int = 120 + ) -> list[CalendarEvent]: + now = datetime.now(timezone.utc) + buffer_time = now + timedelta(minutes=minutes_ahead) + + query = ( + select(CalendarEventModel) + .where( + 
sa.and_( + CalendarEventModel.room_id == room_id, + CalendarEventModel.start_time <= buffer_time, + CalendarEventModel.end_time > now, + CalendarEventModel.is_deleted == False, + ) + ) + .order_by(CalendarEventModel.start_time) + ) + + result = await session.execute(query) + return [CalendarEvent(**row.__dict__) for row in result.scalars().all()] + + async def soft_delete_missing( + self, session: AsyncSession, room_id: str, current_ics_uids: list[str] + ) -> int: + query = ( + update(CalendarEventModel) + .where( + sa.and_( + CalendarEventModel.room_id == room_id, + CalendarEventModel.ics_uid.notin_(current_ics_uids) + if current_ics_uids + else True, + CalendarEventModel.end_time > datetime.now(timezone.utc), + ) + ) + .values(is_deleted=True) + ) + result = await session.execute(query) + await session.commit() return result.rowcount diff --git a/server/reflector/db/meetings.py b/server/reflector/db/meetings.py index 12a0c187..02a9ecd1 100644 --- a/server/reflector/db/meetings.py +++ b/server/reflector/db/meetings.py @@ -3,77 +3,13 @@ from typing import Any, Literal import sqlalchemy as sa from pydantic import BaseModel, Field -from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy import select, update +from sqlalchemy.ext.asyncio import AsyncSession -from reflector.db import get_database, metadata +from reflector.db.base import MeetingConsentModel, MeetingModel from reflector.db.rooms import Room from reflector.utils import generate_uuid4 -meetings = sa.Table( - "meeting", - metadata, - sa.Column("id", sa.String, primary_key=True), - sa.Column("room_name", sa.String), - sa.Column("room_url", sa.String), - sa.Column("host_room_url", sa.String), - sa.Column("start_date", sa.DateTime(timezone=True)), - sa.Column("end_date", sa.DateTime(timezone=True)), - sa.Column( - "room_id", - sa.String, - sa.ForeignKey("room.id", ondelete="CASCADE"), - nullable=True, - ), - sa.Column("is_locked", sa.Boolean, nullable=False, server_default=sa.false()), - 
sa.Column("room_mode", sa.String, nullable=False, server_default="normal"), - sa.Column("recording_type", sa.String, nullable=False, server_default="cloud"), - sa.Column( - "recording_trigger", - sa.String, - nullable=False, - server_default="automatic-2nd-participant", - ), - sa.Column( - "num_clients", - sa.Integer, - nullable=False, - server_default=sa.text("0"), - ), - sa.Column( - "is_active", - sa.Boolean, - nullable=False, - server_default=sa.true(), - ), - sa.Column( - "calendar_event_id", - sa.String, - sa.ForeignKey( - "calendar_event.id", - ondelete="SET NULL", - name="fk_meeting_calendar_event_id", - ), - ), - sa.Column("calendar_metadata", JSONB), - sa.Index("idx_meeting_room_id", "room_id"), - sa.Index("idx_meeting_calendar_event", "calendar_event_id"), -) - -meeting_consent = sa.Table( - "meeting_consent", - metadata, - sa.Column("id", sa.String, primary_key=True), - sa.Column( - "meeting_id", - sa.String, - sa.ForeignKey("meeting.id", ondelete="CASCADE"), - nullable=False, - ), - sa.Column("user_id", sa.String), - sa.Column("consent_given", sa.Boolean, nullable=False), - sa.Column("consent_timestamp", sa.DateTime(timezone=True), nullable=False), -) - class MeetingConsent(BaseModel): id: str = Field(default_factory=generate_uuid4) @@ -106,6 +42,7 @@ class Meeting(BaseModel): class MeetingController: async def create( self, + session: AsyncSession, id: str, room_name: str, room_url: str, @@ -131,170 +68,198 @@ class MeetingController: calendar_event_id=calendar_event_id, calendar_metadata=calendar_metadata, ) - query = meetings.insert().values(**meeting.model_dump()) - await get_database().execute(query) + new_meeting = MeetingModel(**meeting.model_dump()) + session.add(new_meeting) + await session.commit() return meeting - async def get_all_active(self) -> list[Meeting]: - query = meetings.select().where(meetings.c.is_active) - return await get_database().fetch_all(query) + async def get_all_active(self, session: AsyncSession) -> list[Meeting]: + 
query = select(MeetingModel).where(MeetingModel.is_active) + result = await session.execute(query) + return [Meeting(**row.__dict__) for row in result.scalars().all()] async def get_by_room_name( self, + session: AsyncSession, room_name: str, ) -> Meeting | None: """ Get a meeting by room name. For backward compatibility, returns the most recent meeting. """ - end_date = getattr(meetings.c, "end_date") query = ( - meetings.select() - .where(meetings.c.room_name == room_name) - .order_by(end_date.desc()) + select(MeetingModel) + .where(MeetingModel.room_name == room_name) + .order_by(MeetingModel.end_date.desc()) ) - result = await get_database().fetch_one(query) - if not result: + result = await session.execute(query) + row = result.scalar_one_or_none() + if not row: return None + return Meeting(**row.__dict__) - return Meeting(**result) - - async def get_active(self, room: Room, current_time: datetime) -> Meeting | None: + async def get_active( + self, session: AsyncSession, room: Room, current_time: datetime + ) -> Meeting | None: """ Get latest active meeting for a room. For backward compatibility, returns the most recent active meeting. 
""" - end_date = getattr(meetings.c, "end_date") query = ( - meetings.select() + select(MeetingModel) .where( sa.and_( - meetings.c.room_id == room.id, - meetings.c.end_date > current_time, - meetings.c.is_active, + MeetingModel.room_id == room.id, + MeetingModel.end_date > current_time, + MeetingModel.is_active, ) ) - .order_by(end_date.desc()) + .order_by(MeetingModel.end_date.desc()) ) - result = await get_database().fetch_one(query) - if not result: + result = await session.execute(query) + row = result.scalar_one_or_none() + if not row: return None - - return Meeting(**result) + return Meeting(**row.__dict__) async def get_all_active_for_room( - self, room: Room, current_time: datetime + self, session: AsyncSession, room: Room, current_time: datetime ) -> list[Meeting]: - end_date = getattr(meetings.c, "end_date") query = ( - meetings.select() + select(MeetingModel) .where( sa.and_( - meetings.c.room_id == room.id, - meetings.c.end_date > current_time, - meetings.c.is_active, + MeetingModel.room_id == room.id, + MeetingModel.end_date > current_time, + MeetingModel.is_active, ) ) - .order_by(end_date.desc()) + .order_by(MeetingModel.end_date.desc()) ) - results = await get_database().fetch_all(query) - return [Meeting(**result) for result in results] + result = await session.execute(query) + return [Meeting(**row.__dict__) for row in result.scalars().all()] async def get_active_by_calendar_event( - self, room: Room, calendar_event_id: str, current_time: datetime + self, + session: AsyncSession, + room: Room, + calendar_event_id: str, + current_time: datetime, ) -> Meeting | None: """ Get active meeting for a specific calendar event. 
""" - query = meetings.select().where( + query = select(MeetingModel).where( sa.and_( - meetings.c.room_id == room.id, - meetings.c.calendar_event_id == calendar_event_id, - meetings.c.end_date > current_time, - meetings.c.is_active, + MeetingModel.room_id == room.id, + MeetingModel.calendar_event_id == calendar_event_id, + MeetingModel.end_date > current_time, + MeetingModel.is_active, ) ) - result = await get_database().fetch_one(query) - if not result: + result = await session.execute(query) + row = result.scalar_one_or_none() + if not row: return None - return Meeting(**result) + return Meeting(**row.__dict__) - async def get_by_id(self, meeting_id: str, **kwargs) -> Meeting | None: - query = meetings.select().where(meetings.c.id == meeting_id) - result = await get_database().fetch_one(query) - if not result: + async def get_by_id( + self, session: AsyncSession, meeting_id: str, **kwargs + ) -> Meeting | None: + query = select(MeetingModel).where(MeetingModel.id == meeting_id) + result = await session.execute(query) + row = result.scalar_one_or_none() + if not row: return None - return Meeting(**result) + return Meeting(**row.__dict__) - async def get_by_calendar_event(self, calendar_event_id: str) -> Meeting | None: - query = meetings.select().where( - meetings.c.calendar_event_id == calendar_event_id + async def get_by_calendar_event( + self, session: AsyncSession, calendar_event_id: str + ) -> Meeting | None: + query = select(MeetingModel).where( + MeetingModel.calendar_event_id == calendar_event_id ) - result = await get_database().fetch_one(query) - if not result: + result = await session.execute(query) + row = result.scalar_one_or_none() + if not row: return None - return Meeting(**result) + return Meeting(**row.__dict__) - async def update_meeting(self, meeting_id: str, **kwargs): - query = meetings.update().where(meetings.c.id == meeting_id).values(**kwargs) - await get_database().execute(query) + async def update_meeting(self, session: AsyncSession, 
meeting_id: str, **kwargs): + query = ( + update(MeetingModel).where(MeetingModel.id == meeting_id).values(**kwargs) + ) + await session.execute(query) + await session.commit() class MeetingConsentController: - async def get_by_meeting_id(self, meeting_id: str) -> list[MeetingConsent]: - query = meeting_consent.select().where( - meeting_consent.c.meeting_id == meeting_id + async def get_by_meeting_id( + self, session: AsyncSession, meeting_id: str + ) -> list[MeetingConsent]: + query = select(MeetingConsentModel).where( + MeetingConsentModel.meeting_id == meeting_id ) - results = await get_database().fetch_all(query) - return [MeetingConsent(**result) for result in results] + result = await session.execute(query) + return [MeetingConsent(**row.__dict__) for row in result.scalars().all()] async def get_by_meeting_and_user( - self, meeting_id: str, user_id: str + self, session: AsyncSession, meeting_id: str, user_id: str ) -> MeetingConsent | None: """Get existing consent for a specific user and meeting""" - query = meeting_consent.select().where( - meeting_consent.c.meeting_id == meeting_id, - meeting_consent.c.user_id == user_id, + query = select(MeetingConsentModel).where( + sa.and_( + MeetingConsentModel.meeting_id == meeting_id, + MeetingConsentModel.user_id == user_id, + ) ) - result = await get_database().fetch_one(query) - if result is None: + result = await session.execute(query) + row = result.scalar_one_or_none() + if row is None: return None - return MeetingConsent(**result) + return MeetingConsent(**row.__dict__) - async def upsert(self, consent: MeetingConsent) -> MeetingConsent: + async def upsert( + self, session: AsyncSession, consent: MeetingConsent + ) -> MeetingConsent: if consent.user_id: # For authenticated users, check if consent already exists # not transactional but we're ok with that; the consents ain't deleted anyways existing = await self.get_by_meeting_and_user( - consent.meeting_id, consent.user_id + session, consent.meeting_id, 
consent.user_id ) if existing: query = ( - meeting_consent.update() - .where(meeting_consent.c.id == existing.id) + update(MeetingConsentModel) + .where(MeetingConsentModel.id == existing.id) .values( consent_given=consent.consent_given, consent_timestamp=consent.consent_timestamp, ) ) - await get_database().execute(query) + await session.execute(query) + await session.commit() - existing.consent_given = consent.consent_given - existing.consent_timestamp = consent.consent_timestamp - return existing + existing.consent_given = consent.consent_given + existing.consent_timestamp = consent.consent_timestamp + return existing - query = meeting_consent.insert().values(**consent.model_dump()) - await get_database().execute(query) + new_consent = MeetingConsentModel(**consent.model_dump()) + session.add(new_consent) + await session.commit() return consent - async def has_any_denial(self, meeting_id: str) -> bool: + async def has_any_denial(self, session: AsyncSession, meeting_id: str) -> bool: """Check if any participant denied consent for this meeting""" - query = meeting_consent.select().where( - meeting_consent.c.meeting_id == meeting_id, - meeting_consent.c.consent_given.is_(False), + query = select(MeetingConsentModel).where( + sa.and_( + MeetingConsentModel.meeting_id == meeting_id, + MeetingConsentModel.consent_given.is_(False), + ) ) - result = await get_database().fetch_one(query) - return result is not None + result = await session.execute(query) + row = result.scalar_one_or_none() + return row is not None meetings_controller = MeetingController() diff --git a/server/reflector/db/recordings.py b/server/reflector/db/recordings.py index 0d05790d..ee7a7be1 100644 --- a/server/reflector/db/recordings.py +++ b/server/reflector/db/recordings.py @@ -1,61 +1,79 @@ from datetime import datetime -from typing import Literal -import sqlalchemy as sa from pydantic import BaseModel, Field +from sqlalchemy import delete, select +from sqlalchemy.ext.asyncio import AsyncSession 
-from reflector.db import get_database, metadata +from reflector.db.base import RecordingModel from reflector.utils import generate_uuid4 -recordings = sa.Table( - "recording", - metadata, - sa.Column("id", sa.String, primary_key=True), - sa.Column("bucket_name", sa.String, nullable=False), - sa.Column("object_key", sa.String, nullable=False), - sa.Column("recorded_at", sa.DateTime(timezone=True), nullable=False), - sa.Column( - "status", - sa.String, - nullable=False, - server_default="pending", - ), - sa.Column("meeting_id", sa.String), - sa.Index("idx_recording_meeting_id", "meeting_id"), -) - class Recording(BaseModel): id: str = Field(default_factory=generate_uuid4) - bucket_name: str + meeting_id: str + url: str object_key: str - recorded_at: datetime - status: Literal["pending", "processing", "completed", "failed"] = "pending" - meeting_id: str | None = None + duration: float | None = None + created_at: datetime class RecordingController: - async def create(self, recording: Recording): - query = recordings.insert().values(**recording.model_dump()) - await get_database().execute(query) + async def create( + self, + session: AsyncSession, + meeting_id: str, + url: str, + object_key: str, + duration: float | None = None, + created_at: datetime | None = None, + ): + if created_at is None: + from datetime import timezone + + created_at = datetime.now(timezone.utc) + + recording = Recording( + meeting_id=meeting_id, + url=url, + object_key=object_key, + duration=duration, + created_at=created_at, + ) + new_recording = RecordingModel(**recording.model_dump()) + session.add(new_recording) + await session.commit() return recording - async def get_by_id(self, id: str) -> Recording: - query = recordings.select().where(recordings.c.id == id) - result = await get_database().fetch_one(query) - return Recording(**result) if result else None + async def get_by_id( + self, session: AsyncSession, recording_id: str + ) -> Recording | None: + """ + Get a recording by id + """ + 
query = select(RecordingModel).where(RecordingModel.id == recording_id) + result = await session.execute(query) + row = result.scalar_one_or_none() + if not row: + return None + return Recording(**row.__dict__) - async def get_by_object_key(self, bucket_name: str, object_key: str) -> Recording: - query = recordings.select().where( - recordings.c.bucket_name == bucket_name, - recordings.c.object_key == object_key, - ) - result = await get_database().fetch_one(query) - return Recording(**result) if result else None + async def get_by_meeting_id( + self, session: AsyncSession, meeting_id: str + ) -> list[Recording]: + """ + Get all recordings for a meeting + """ + query = select(RecordingModel).where(RecordingModel.meeting_id == meeting_id) + result = await session.execute(query) + return [Recording(**row.__dict__) for row in result.scalars().all()] - async def remove_by_id(self, id: str) -> None: - query = recordings.delete().where(recordings.c.id == id) - await get_database().execute(query) + async def remove_by_id(self, session: AsyncSession, recording_id: str) -> None: + """ + Remove a recording by id + """ + query = delete(RecordingModel).where(RecordingModel.id == recording_id) + await session.execute(query) + await session.commit() recordings_controller = RecordingController() diff --git a/server/reflector/db/rooms.py b/server/reflector/db/rooms.py index 396c818a..05e6458d 100644 --- a/server/reflector/db/rooms.py +++ b/server/reflector/db/rooms.py @@ -3,57 +3,15 @@ from datetime import datetime, timezone from sqlite3 import IntegrityError from typing import Literal -import sqlalchemy from fastapi import HTTPException from pydantic import BaseModel, Field -from sqlalchemy.sql import false, or_ +from sqlalchemy import delete, select, update +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.sql import or_ -from reflector.db import get_database, metadata +from reflector.db.base import RoomModel from reflector.utils import generate_uuid4 -rooms = 
sqlalchemy.Table( - "room", - metadata, - sqlalchemy.Column("id", sqlalchemy.String, primary_key=True), - sqlalchemy.Column("name", sqlalchemy.String, nullable=False, unique=True), - sqlalchemy.Column("user_id", sqlalchemy.String, nullable=False), - sqlalchemy.Column("created_at", sqlalchemy.DateTime(timezone=True), nullable=False), - sqlalchemy.Column( - "zulip_auto_post", sqlalchemy.Boolean, nullable=False, server_default=false() - ), - sqlalchemy.Column("zulip_stream", sqlalchemy.String), - sqlalchemy.Column("zulip_topic", sqlalchemy.String), - sqlalchemy.Column( - "is_locked", sqlalchemy.Boolean, nullable=False, server_default=false() - ), - sqlalchemy.Column( - "room_mode", sqlalchemy.String, nullable=False, server_default="normal" - ), - sqlalchemy.Column( - "recording_type", sqlalchemy.String, nullable=False, server_default="cloud" - ), - sqlalchemy.Column( - "recording_trigger", - sqlalchemy.String, - nullable=False, - server_default="automatic-2nd-participant", - ), - sqlalchemy.Column( - "is_shared", sqlalchemy.Boolean, nullable=False, server_default=false() - ), - sqlalchemy.Column("webhook_url", sqlalchemy.String, nullable=True), - sqlalchemy.Column("webhook_secret", sqlalchemy.String, nullable=True), - sqlalchemy.Column("ics_url", sqlalchemy.Text), - sqlalchemy.Column("ics_fetch_interval", sqlalchemy.Integer, server_default="300"), - sqlalchemy.Column( - "ics_enabled", sqlalchemy.Boolean, nullable=False, server_default=false() - ), - sqlalchemy.Column("ics_last_sync", sqlalchemy.DateTime(timezone=True)), - sqlalchemy.Column("ics_last_etag", sqlalchemy.Text), - sqlalchemy.Index("idx_room_is_shared", "is_shared"), - sqlalchemy.Index("idx_room_ics_enabled", "ics_enabled"), -) - class Room(BaseModel): id: str = Field(default_factory=generate_uuid4) @@ -82,6 +40,7 @@ class Room(BaseModel): class RoomController: async def get_all( self, + session: AsyncSession, user_id: str | None = None, order_by: str | None = None, return_query: bool = False, @@ -97,9 
+56,9 @@ class RoomController: """ query = rooms.select() if user_id is not None: - query = query.where(or_(rooms.c.user_id == user_id, rooms.c.is_shared)) + query = query.where(or_(RoomModel.user_id == user_id, RoomModel.is_shared)) else: - query = query.where(rooms.c.is_shared) + query = query.where(RoomModel.is_shared) if order_by is not None: field = getattr(rooms.c, order_by[1:]) @@ -110,11 +69,12 @@ class RoomController: if return_query: return query - results = await get_database().fetch_all(query) - return results + result = await session.execute(query) + return [Room(**row) for row in result.mappings().all()] async def add( self, + session: AsyncSession, name: str, user_id: str, zulip_auto_post: bool, @@ -154,23 +114,27 @@ class RoomController: ics_fetch_interval=ics_fetch_interval, ics_enabled=ics_enabled, ) - query = rooms.insert().values(**room.model_dump()) + new_room = RoomModel(**room.model_dump()) + session.add(new_room) try: - await get_database().execute(query) + await session.commit() except IntegrityError: raise HTTPException(status_code=400, detail="Room name is not unique") return room - async def update(self, room: Room, values: dict, mutate=True): + async def update( + self, session: AsyncSession, room: Room, values: dict, mutate=True + ): """ Update a room fields with key/values in values """ if values.get("webhook_url") and not values.get("webhook_secret"): values["webhook_secret"] = secrets.token_urlsafe(32) - query = rooms.update().where(rooms.c.id == room.id).values(**values) + query = update(RoomModel).where(RoomModel.id == room.id).values(**values) try: - await get_database().execute(query) + await session.execute(query) + await session.commit() except IntegrityError: raise HTTPException(status_code=400, detail="Room name is not unique") @@ -178,67 +142,79 @@ class RoomController: for key, value in values.items(): setattr(room, key, value) - async def get_by_id(self, room_id: str, **kwargs) -> Room | None: + async def get_by_id( + self, 
session: AsyncSession, room_id: str, **kwargs + ) -> Room | None: """ Get a room by id """ - query = rooms.select().where(rooms.c.id == room_id) + query = select(RoomModel.__table__).where(RoomModel.id == room_id) if "user_id" in kwargs: - query = query.where(rooms.c.user_id == kwargs["user_id"]) - result = await get_database().fetch_one(query) - if not result: + query = query.where(RoomModel.user_id == kwargs["user_id"]) + result = await session.execute(query) + row = result.mappings().first() + if not row: return None - return Room(**result) + return Room(**row) - async def get_by_name(self, room_name: str, **kwargs) -> Room | None: + async def get_by_name( + self, session: AsyncSession, room_name: str, **kwargs + ) -> Room | None: """ Get a room by name """ - query = rooms.select().where(rooms.c.name == room_name) + query = select(RoomModel.__table__).where(RoomModel.name == room_name) if "user_id" in kwargs: - query = query.where(rooms.c.user_id == kwargs["user_id"]) - result = await get_database().fetch_one(query) - if not result: + query = query.where(RoomModel.user_id == kwargs["user_id"]) + result = await session.execute(query) + row = result.mappings().first() + if not row: return None - return Room(**result) + return Room(**row) - async def get_by_id_for_http(self, meeting_id: str, user_id: str | None) -> Room: + async def get_by_id_for_http( + self, session: AsyncSession, meeting_id: str, user_id: str | None + ) -> Room: """ Get a room by ID for HTTP request. If not found, it will raise a 404 error. 
""" - query = rooms.select().where(rooms.c.id == meeting_id) - result = await get_database().fetch_one(query) - if not result: + query = select(RoomModel.__table__).where(RoomModel.id == meeting_id) + result = await session.execute(query) + row = result.mappings().first() + if not row: raise HTTPException(status_code=404, detail="Room not found") - room = Room(**result) + room = Room(**row) return room - async def get_ics_enabled(self) -> list[Room]: - query = rooms.select().where( - rooms.c.ics_enabled == True, rooms.c.ics_url != None + async def get_ics_enabled(self, session: AsyncSession) -> list[Room]: + query = select(RoomModel.__table__).where( + RoomModel.ics_enabled == True, RoomModel.ics_url != None ) - results = await get_database().fetch_all(query) - return [Room(**result) for result in results] + result = await session.execute(query) + results = result.mappings().all() + return [Room(**r) for r in results] async def remove_by_id( self, + session: AsyncSession, room_id: str, user_id: str | None = None, ) -> None: """ Remove a room by id """ - room = await self.get_by_id(room_id, user_id=user_id) + room = await self.get_by_id(session, room_id, user_id=user_id) if not room: return if user_id is not None and room.user_id != user_id: return - query = rooms.delete().where(rooms.c.id == room_id) - await get_database().execute(query) + query = delete(RoomModel).where(RoomModel.id == room_id) + await session.execute(query) + await session.commit() rooms_controller = RoomController() diff --git a/server/reflector/db/search.py b/server/reflector/db/search.py index caa21c65..32f0513a 100644 --- a/server/reflector/db/search.py +++ b/server/reflector/db/search.py @@ -8,7 +8,6 @@ from typing import Annotated, Any, Dict, Iterator import sqlalchemy import webvtt -from databases.interfaces import Record as DbRecord from fastapi import HTTPException from pydantic import ( BaseModel, @@ -20,11 +19,10 @@ from pydantic import ( constr, field_serializer, ) +from sqlalchemy.ext.asyncio import AsyncSession 
-from reflector.db import get_database -from reflector.db.rooms import rooms -from reflector.db.transcripts import SourceKind, TranscriptStatus, transcripts -from reflector.db.utils import is_postgresql +from reflector.db.base import RoomModel, TranscriptModel +from reflector.db.transcripts import SourceKind, TranscriptStatus from reflector.logger import logger from reflector.utils.string import NonEmptyString, try_parse_non_empty_string @@ -331,36 +329,30 @@ class SearchController: @classmethod async def search_transcripts( - cls, params: SearchParameters + cls, session: AsyncSession, params: SearchParameters ) -> tuple[list[SearchResult], int]: """ Full-text search for transcripts using PostgreSQL tsvector. Returns (results, total_count). """ - if not is_postgresql(): - logger.warning( - "Full-text search requires PostgreSQL. Returning empty results." - ) - return [], 0 - base_columns = [ - transcripts.c.id, - transcripts.c.title, - transcripts.c.created_at, - transcripts.c.duration, - transcripts.c.status, - transcripts.c.user_id, - transcripts.c.room_id, - transcripts.c.source_kind, - transcripts.c.webvtt, - transcripts.c.long_summary, + TranscriptModel.id, + TranscriptModel.title, + TranscriptModel.created_at, + TranscriptModel.duration, + TranscriptModel.status, + TranscriptModel.user_id, + TranscriptModel.room_id, + TranscriptModel.source_kind, + TranscriptModel.webvtt, + TranscriptModel.long_summary, sqlalchemy.case( ( - transcripts.c.room_id.isnot(None) & rooms.c.id.is_(None), + TranscriptModel.room_id.isnot(None) & RoomModel.id.is_(None), "Deleted Room", ), - else_=rooms.c.name, + else_=RoomModel.name, ).label("room_name"), ] search_query = None @@ -369,7 +361,7 @@ class SearchController: "english", params.query_text ) rank_column = sqlalchemy.func.ts_rank( - transcripts.c.search_vector_en, + TranscriptModel.search_vector_en, search_query, 32, # normalization flag: rank/(rank+1) for 0-1 range ).label("rank") @@ -378,46 +370,52 @@ class SearchController: 
columns = base_columns + [rank_column] base_query = sqlalchemy.select(columns).select_from( - transcripts.join(rooms, transcripts.c.room_id == rooms.c.id, isouter=True) + TranscriptModel.__table__.join( + RoomModel.__table__, + TranscriptModel.room_id == RoomModel.id, + isouter=True, + ) ) if params.query_text is not None: # because already initialized based on params.query_text presence above assert search_query is not None base_query = base_query.where( - transcripts.c.search_vector_en.op("@@")(search_query) + TranscriptModel.search_vector_en.op("@@")(search_query) ) if params.user_id: base_query = base_query.where( sqlalchemy.or_( - transcripts.c.user_id == params.user_id, rooms.c.is_shared + TranscriptModel.user_id == params.user_id, RoomModel.is_shared ) ) else: - base_query = base_query.where(rooms.c.is_shared) + base_query = base_query.where(RoomModel.is_shared) if params.room_id: - base_query = base_query.where(transcripts.c.room_id == params.room_id) + base_query = base_query.where(TranscriptModel.room_id == params.room_id) if params.source_kind: base_query = base_query.where( - transcripts.c.source_kind == params.source_kind + TranscriptModel.source_kind == params.source_kind ) if params.query_text is not None: order_by = sqlalchemy.desc(sqlalchemy.text("rank")) else: - order_by = sqlalchemy.desc(transcripts.c.created_at) + order_by = sqlalchemy.desc(TranscriptModel.created_at) query = base_query.order_by(order_by).limit(params.limit).offset(params.offset) - rs = await get_database().fetch_all(query) + result = await session.execute(query) + rs = result.mappings().all() count_query = sqlalchemy.select([sqlalchemy.func.count()]).select_from( base_query.alias("search_results") ) - total = await get_database().fetch_val(count_query) + count_result = await session.execute(count_query) + total = count_result.scalar() - def _process_result(r: DbRecord) -> SearchResult: + def _process_result(r: dict) -> SearchResult: r_dict: Dict[str, Any] = dict(r) webvtt_raw: 
str | None = r_dict.pop("webvtt", None) diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py index 47148995..c4da4805 100644 --- a/server/reflector/db/transcripts.py +++ b/server/reflector/db/transcripts.py @@ -7,17 +7,14 @@ from datetime import datetime, timedelta, timezone from pathlib import Path from typing import Any, Literal -import sqlalchemy from fastapi import HTTPException from pydantic import BaseModel, ConfigDict, Field, field_serializer -from sqlalchemy import Enum -from sqlalchemy.dialects.postgresql import TSVECTOR -from sqlalchemy.sql import false, or_ +from sqlalchemy import delete, insert, select, update +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.sql import or_ -from reflector.db import get_database, metadata +from reflector.db.base import RoomModel, TranscriptModel from reflector.db.recordings import recordings_controller -from reflector.db.rooms import rooms -from reflector.db.utils import is_postgresql from reflector.logger import logger from reflector.processors.types import Word as ProcessorWord from reflector.settings import settings @@ -32,91 +29,6 @@ class SourceKind(enum.StrEnum): FILE = enum.auto() -transcripts = sqlalchemy.Table( - "transcript", - metadata, - sqlalchemy.Column("id", sqlalchemy.String, primary_key=True), - sqlalchemy.Column("name", sqlalchemy.String), - sqlalchemy.Column("status", sqlalchemy.String), - sqlalchemy.Column("locked", sqlalchemy.Boolean), - sqlalchemy.Column("duration", sqlalchemy.Float), - sqlalchemy.Column("created_at", sqlalchemy.DateTime(timezone=True)), - sqlalchemy.Column("title", sqlalchemy.String), - sqlalchemy.Column("short_summary", sqlalchemy.String), - sqlalchemy.Column("long_summary", sqlalchemy.String), - sqlalchemy.Column("topics", sqlalchemy.JSON), - sqlalchemy.Column("events", sqlalchemy.JSON), - sqlalchemy.Column("participants", sqlalchemy.JSON), - sqlalchemy.Column("source_language", sqlalchemy.String), - 
sqlalchemy.Column("target_language", sqlalchemy.String), - sqlalchemy.Column( - "reviewed", sqlalchemy.Boolean, nullable=False, server_default=false() - ), - sqlalchemy.Column( - "audio_location", - sqlalchemy.String, - nullable=False, - server_default="local", - ), - # with user attached, optional - sqlalchemy.Column("user_id", sqlalchemy.String), - sqlalchemy.Column( - "share_mode", - sqlalchemy.String, - nullable=False, - server_default="private", - ), - sqlalchemy.Column( - "meeting_id", - sqlalchemy.String, - ), - sqlalchemy.Column("recording_id", sqlalchemy.String), - sqlalchemy.Column("zulip_message_id", sqlalchemy.Integer), - sqlalchemy.Column( - "source_kind", - Enum(SourceKind, values_callable=lambda obj: [e.value for e in obj]), - nullable=False, - ), - # indicative field: whether associated audio is deleted - # the main "audio deleted" is the presence of the audio itself / consents not-given - # same field could've been in recording/meeting, and it's maybe even ok to dupe it at need - sqlalchemy.Column("audio_deleted", sqlalchemy.Boolean), - sqlalchemy.Column("room_id", sqlalchemy.String), - sqlalchemy.Column("webvtt", sqlalchemy.Text), - sqlalchemy.Index("idx_transcript_recording_id", "recording_id"), - sqlalchemy.Index("idx_transcript_user_id", "user_id"), - sqlalchemy.Index("idx_transcript_created_at", "created_at"), - sqlalchemy.Index("idx_transcript_user_id_recording_id", "user_id", "recording_id"), - sqlalchemy.Index("idx_transcript_room_id", "room_id"), - sqlalchemy.Index("idx_transcript_source_kind", "source_kind"), - sqlalchemy.Index("idx_transcript_room_id_created_at", "room_id", "created_at"), -) - -# Add PostgreSQL-specific full-text search column -# This matches the migration in migrations/versions/116b2f287eab_add_full_text_search.py -if is_postgresql(): - transcripts.append_column( - sqlalchemy.Column( - "search_vector_en", - TSVECTOR, - sqlalchemy.Computed( - "setweight(to_tsvector('english', coalesce(title, '')), 'A') || " - 
"setweight(to_tsvector('english', coalesce(long_summary, '')), 'B') || " - "setweight(to_tsvector('english', coalesce(webvtt, '')), 'C')", - persisted=True, - ), - ) - ) - # Add GIN index for the search vector - transcripts.append_constraint( - sqlalchemy.Index( - "idx_transcript_search_vector_en", - "search_vector_en", - postgresql_using="gin", - ) - ) - - def generate_transcript_name() -> str: now = datetime.now(timezone.utc) return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}" @@ -359,6 +271,7 @@ class Transcript(BaseModel): class TranscriptController: async def get_all( self, + session: AsyncSession, user_id: str | None = None, order_by: str | None = None, filter_empty: bool | None = False, @@ -383,102 +296,111 @@ class TranscriptController: - `search_term`: filter transcripts by search term """ - query = transcripts.select().join( - rooms, transcripts.c.room_id == rooms.c.id, isouter=True + query = select(TranscriptModel).join( + RoomModel, TranscriptModel.room_id == RoomModel.id, isouter=True ) if user_id: query = query.where( - or_(transcripts.c.user_id == user_id, rooms.c.is_shared) + or_(TranscriptModel.user_id == user_id, RoomModel.is_shared) ) else: - query = query.where(rooms.c.is_shared) + query = query.where(RoomModel.is_shared) if source_kind: - query = query.where(transcripts.c.source_kind == source_kind) + query = query.where(TranscriptModel.source_kind == source_kind) if room_id: - query = query.where(transcripts.c.room_id == room_id) + query = query.where(TranscriptModel.room_id == room_id) if search_term: - query = query.where(transcripts.c.title.ilike(f"%{search_term}%")) + query = query.where(TranscriptModel.title.ilike(f"%{search_term}%")) # Exclude heavy JSON columns from list queries transcript_columns = [ - col for col in transcripts.c if col.name not in exclude_columns + col + for col in TranscriptModel.__table__.c + if col.name not in exclude_columns ] query = query.with_only_columns( - transcript_columns - + [ - 
rooms.c.name.label("room_name"), - ] + *transcript_columns, + RoomModel.name.label("room_name"), ) if order_by is not None: - field = getattr(transcripts.c, order_by[1:]) + field = getattr(TranscriptModel, order_by[1:]) if order_by.startswith("-"): field = field.desc() query = query.order_by(field) if filter_empty: - query = query.filter(transcripts.c.status != "idle") + query = query.filter(TranscriptModel.status != "idle") if filter_recording: - query = query.filter(transcripts.c.status != "recording") + query = query.filter(TranscriptModel.status != "recording") # print(query.compile(compile_kwargs={"literal_binds": True})) if return_query: return query - results = await get_database().fetch_all(query) - return results + result = await session.execute(query) + return [dict(row) for row in result.mappings().all()] - async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None: + async def get_by_id( + self, session: AsyncSession, transcript_id: str, **kwargs + ) -> Transcript | None: """ Get a transcript by id """ - query = transcripts.select().where(transcripts.c.id == transcript_id) + query = select(TranscriptModel).where(TranscriptModel.id == transcript_id) if "user_id" in kwargs: - query = query.where(transcripts.c.user_id == kwargs["user_id"]) - result = await get_database().fetch_one(query) - if not result: + query = query.where(TranscriptModel.user_id == kwargs["user_id"]) + result = await session.execute(query) + row = result.scalar_one_or_none() + if not row: return None - return Transcript(**result) + return Transcript(**row.__dict__) async def get_by_recording_id( - self, recording_id: str, **kwargs + self, session: AsyncSession, recording_id: str, **kwargs ) -> Transcript | None: """ Get a transcript by recording_id """ - query = transcripts.select().where(transcripts.c.recording_id == recording_id) + query = select(TranscriptModel).where( + TranscriptModel.recording_id == recording_id + ) if "user_id" in kwargs: - query = 
query.where(transcripts.c.user_id == kwargs["user_id"]) - result = await get_database().fetch_one(query) - if not result: + query = query.where(TranscriptModel.user_id == kwargs["user_id"]) + result = await session.execute(query) + row = result.scalar_one_or_none() + if not row: return None - return Transcript(**result) + return Transcript(**row.__dict__) - async def get_by_room_id(self, room_id: str, **kwargs) -> list[Transcript]: + async def get_by_room_id( + self, session: AsyncSession, room_id: str, **kwargs + ) -> list[Transcript]: """ Get transcripts by room_id (direct access without joins) """ - query = transcripts.select().where(transcripts.c.room_id == room_id) + query = select(TranscriptModel).where(TranscriptModel.room_id == room_id) if "user_id" in kwargs: - query = query.where(transcripts.c.user_id == kwargs["user_id"]) + query = query.where(TranscriptModel.user_id == kwargs["user_id"]) if "order_by" in kwargs: order_by = kwargs["order_by"] - field = getattr(transcripts.c, order_by[1:]) + field = getattr(TranscriptModel, order_by[1:]) if order_by.startswith("-"): field = field.desc() query = query.order_by(field) - results = await get_database().fetch_all(query) - return [Transcript(**result) for result in results] + results = await session.execute(query) + return [Transcript(**row.__dict__) for row in results.scalars().all()] async def get_by_id_for_http( self, + session: AsyncSession, transcript_id: str, user_id: str | None, ) -> Transcript: 
""" - query = transcripts.select().where(transcripts.c.id == transcript_id) - result = await get_database().fetch_one(query) - if not result: + query = select(TranscriptModel).where(TranscriptModel.id == transcript_id) + result = await session.execute(query) + row = result.scalar_one_or_none() + if not row: raise HTTPException(status_code=404, detail="Transcript not found") # if the transcript is anonymous, share mode is not checked - transcript = Transcript(**result) + transcript = Transcript(**row.__dict__) if transcript.user_id is None: return transcript @@ -520,6 +443,7 @@ class TranscriptController: async def add( self, + session: AsyncSession, name: str, source_kind: SourceKind, source_language: str = "en", @@ -544,14 +468,15 @@ class TranscriptController: meeting_id=meeting_id, room_id=room_id, ) - query = transcripts.insert().values(**transcript.model_dump()) - await get_database().execute(query) + query = insert(TranscriptModel).values(**transcript.model_dump()) + await session.execute(query) + await session.commit() return transcript # TODO investigate why mutate= is used. it's used in one place currently, maybe because of ORM field updates. # using mutate=True is discouraged async def update( - self, transcript: Transcript, values: dict, mutate=False + self, session: AsyncSession, transcript: Transcript, values: dict, mutate=False ) -> Transcript: """ Update a transcript fields with key/values in values. 
@@ -560,11 +485,12 @@ class TranscriptController: values = TranscriptController._handle_topics_update(values) query = ( - transcripts.update() - .where(transcripts.c.id == transcript.id) + update(TranscriptModel) + .where(TranscriptModel.id == transcript.id) .values(**values) ) - await get_database().execute(query) + await session.execute(query) + await session.commit() if mutate: for key, value in values.items(): setattr(transcript, key, value) @@ -593,13 +519,14 @@ class TranscriptController: async def remove_by_id( self, + session: AsyncSession, transcript_id: str, user_id: str | None = None, ) -> None: """ Remove a transcript by id """ - transcript = await self.get_by_id(transcript_id) + transcript = await self.get_by_id(session, transcript_id) if not transcript: return if user_id is not None and transcript.user_id != user_id: @@ -619,7 +546,7 @@ class TranscriptController: if transcript.recording_id: try: recording = await recordings_controller.get_by_id( - transcript.recording_id + session, transcript.recording_id ) if recording: try: @@ -630,33 +557,40 @@ class TranscriptController: exc_info=e, recording_id=transcript.recording_id, ) - await recordings_controller.remove_by_id(transcript.recording_id) + await recordings_controller.remove_by_id( + session, transcript.recording_id + ) except Exception as e: logger.warning( "Failed to delete recording row", exc_info=e, recording_id=transcript.recording_id, ) - query = transcripts.delete().where(transcripts.c.id == transcript_id) - await get_database().execute(query) + query = delete(TranscriptModel).where(TranscriptModel.id == transcript_id) + await session.execute(query) + await session.commit() - async def remove_by_recording_id(self, recording_id: str): + async def remove_by_recording_id(self, session: AsyncSession, recording_id: str): """ Remove a transcript by recording_id """ - query = transcripts.delete().where(transcripts.c.recording_id == recording_id) - await get_database().execute(query) + query = 
delete(TranscriptModel).where( + TranscriptModel.recording_id == recording_id + ) + await session.execute(query) + await session.commit() @asynccontextmanager - async def transaction(self): + async def transaction(self, session: AsyncSession): """ A context manager for database transaction """ - async with get_database().transaction(isolation="serializable"): + async with session.begin(): yield async def append_event( self, + session: AsyncSession, transcript: Transcript, event: str, data: Any, @@ -665,11 +599,12 @@ class TranscriptController: Append an event to a transcript """ resp = transcript.add_event(event=event, data=data) - await self.update(transcript, {"events": transcript.events_dump()}) + await self.update(session, transcript, {"events": transcript.events_dump()}) return resp async def upsert_topic( self, + session: AsyncSession, transcript: Transcript, topic: TranscriptTopic, ) -> TranscriptEvent: @@ -677,9 +612,9 @@ class TranscriptController: Upsert topics to a transcript """ transcript.upsert_topic(topic) - await self.update(transcript, {"topics": transcript.topics_dump()}) + await self.update(session, transcript, {"topics": transcript.topics_dump()}) - async def move_mp3_to_storage(self, transcript: Transcript): + async def move_mp3_to_storage(self, session: AsyncSession, transcript: Transcript): """ Move mp3 file to storage """ @@ -703,12 +638,16 @@ class TranscriptController: # indicate on the transcript that the audio is now on storage # mutates transcript argument - await self.update(transcript, {"audio_location": "storage"}, mutate=True) + await self.update( + session, transcript, {"audio_location": "storage"}, mutate=True + ) # unlink the local file transcript.audio_mp3_filename.unlink(missing_ok=True) - async def download_mp3_from_storage(self, transcript: Transcript): + async def download_mp3_from_storage( + self, session: AsyncSession, transcript: Transcript + ): """ Download audio from storage """ @@ -720,6 +659,7 @@ class 
TranscriptController: async def upsert_participant( self, + session: AsyncSession, transcript: Transcript, participant: TranscriptParticipant, ) -> TranscriptParticipant: @@ -727,11 +667,14 @@ class TranscriptController: Add/update a participant to a transcript """ result = transcript.upsert_participant(participant) - await self.update(transcript, {"participants": transcript.participants_dump()}) + await self.update( + session, transcript, {"participants": transcript.participants_dump()} + ) return result async def delete_participant( self, + session: AsyncSession, transcript: Transcript, participant_id: str, ): @@ -739,28 +682,31 @@ class TranscriptController: Delete a participant from a transcript """ transcript.delete_participant(participant_id) - await self.update(transcript, {"participants": transcript.participants_dump()}) + await self.update( + session, transcript, {"participants": transcript.participants_dump()} + ) async def set_status( - self, transcript_id: str, status: TranscriptStatus + self, session: AsyncSession, transcript_id: str, status: TranscriptStatus ) -> TranscriptEvent | None: """ Update the status of a transcript Will add an event STATUS + update the status field of transcript """ - async with self.transaction(): - transcript = await self.get_by_id(transcript_id) + async with self.transaction(session): + transcript = await self.get_by_id(session, transcript_id) if not transcript: raise Exception(f"Transcript {transcript_id} not found") if transcript.status == status: return resp = await self.append_event( + session, transcript=transcript, event="STATUS", data=StrValue(value=status), ) - await self.update(transcript, {"status": status}) + await self.update(session, transcript, {"status": status}) return resp diff --git a/server/reflector/db/utils.py b/server/reflector/db/utils.py deleted file mode 100644 index 5cc66e25..00000000 --- a/server/reflector/db/utils.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Database utility functions.""" - -from 
reflector.db import get_database - - -def is_postgresql() -> bool: - return get_database().url.scheme and get_database().url.scheme.startswith( - "postgresql" - ) diff --git a/server/reflector/views/rooms.py b/server/reflector/views/rooms.py index b849ae3d..e470ab8b 100644 --- a/server/reflector/views/rooms.py +++ b/server/reflector/views/rooms.py @@ -5,12 +5,12 @@ from typing import Annotated, Any, Literal, Optional from fastapi import APIRouter, Depends, HTTPException from fastapi_pagination import Page -from fastapi_pagination.ext.databases import apaginate +from fastapi_pagination.ext.sqlalchemy import paginate from pydantic import BaseModel from redis.exceptions import LockError import reflector.auth as auth -from reflector.db import get_database +from reflector.db import get_session_factory from reflector.db.calendar_events import calendar_events_controller from reflector.db.meetings import meetings_controller from reflector.db.rooms import rooms_controller @@ -182,12 +182,12 @@ async def rooms_list( user_id = user["sub"] if user else None - return await apaginate( - get_database(), - await rooms_controller.get_all( + session_factory = get_session_factory() + async with session_factory() as session: + query = await rooms_controller.get_all( user_id=user_id, order_by="-created_at", return_query=True - ), - ) + ) + return await paginate(session, query) @router.get("/rooms/{room_id}", response_model=RoomDetails) diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index ed2445ae..04f647d6 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -3,12 +3,13 @@ from typing import Annotated, Literal, Optional from fastapi import APIRouter, Depends, HTTPException, Query from fastapi_pagination import Page -from fastapi_pagination.ext.databases import apaginate +from fastapi_pagination.ext.sqlalchemy import paginate from jose import jwt from pydantic import BaseModel, Field, constr, 
field_serializer +from sqlalchemy.ext.asyncio import AsyncSession import reflector.auth as auth -from reflector.db import get_database +from reflector.db import get_session from reflector.db.meetings import meetings_controller from reflector.db.rooms import rooms_controller from reflector.db.search import ( @@ -149,24 +150,25 @@ async def transcripts_list( source_kind: SourceKind | None = None, room_id: str | None = None, search_term: str | None = None, + session: AsyncSession = Depends(get_session), ): if not user and not settings.PUBLIC_MODE: raise HTTPException(status_code=401, detail="Not authenticated") user_id = user["sub"] if user else None - return await apaginate( - get_database(), - await transcripts_controller.get_all( - user_id=user_id, - source_kind=SourceKind(source_kind) if source_kind else None, - room_id=room_id, - search_term=search_term, - order_by="-created_at", - return_query=True, - ), + query = await transcripts_controller.get_all( + session, + user_id=user_id, + source_kind=SourceKind(source_kind) if source_kind else None, + room_id=room_id, + search_term=search_term, + order_by="-created_at", + return_query=True, ) + return await paginate(session, query) + @router.get("/transcripts/search", response_model=SearchResponse) async def transcripts_search( @@ -178,6 +180,7 @@ async def transcripts_search( user: Annotated[ Optional[auth.UserInfo], Depends(auth.current_user_optional) ] = None, + session: AsyncSession = Depends(get_session), ): """ Full-text search across transcript titles and content. 
@@ -196,7 +199,7 @@ async def transcripts_search( source_kind=source_kind, ) - results, total = await search_controller.search_transcripts(search_params) + results, total = await search_controller.search_transcripts(session, search_params) return SearchResponse( results=results, @@ -211,9 +214,11 @@ async def transcripts_search( async def transcripts_create( info: CreateTranscript, user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], + session: AsyncSession = Depends(get_session), ): user_id = user["sub"] if user else None return await transcripts_controller.add( + session, info.name, source_kind=info.source_kind or SourceKind.LIVE, source_language=info.source_language, @@ -333,10 +338,11 @@ class GetTranscriptTopicWithWordsPerSpeaker(GetTranscriptTopic): async def transcript_get( transcript_id: str, user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], + session: AsyncSession = Depends(get_session), ): user_id = user["sub"] if user else None return await transcripts_controller.get_by_id_for_http( - transcript_id, user_id=user_id + session, transcript_id, user_id=user_id ) @@ -345,13 +351,16 @@ async def transcript_update( transcript_id: str, info: UpdateTranscript, user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], + session: AsyncSession = Depends(get_session), ): user_id = user["sub"] if user else None transcript = await transcripts_controller.get_by_id_for_http( - transcript_id, user_id=user_id + session, transcript_id, user_id=user_id ) values = info.dict(exclude_unset=True) - updated_transcript = await transcripts_controller.update(transcript, values) + updated_transcript = await transcripts_controller.update( + session, transcript, values + ) return updated_transcript @@ -359,19 +368,20 @@ async def transcript_update( async def transcript_delete( transcript_id: str, user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], + session: AsyncSession = 
Depends(get_session), ): user_id = user["sub"] if user else None - transcript = await transcripts_controller.get_by_id(transcript_id) + transcript = await transcripts_controller.get_by_id(session, transcript_id) if not transcript: raise HTTPException(status_code=404, detail="Transcript not found") if transcript.meeting_id: - meeting = await meetings_controller.get_by_id(transcript.meeting_id) - room = await rooms_controller.get_by_id(meeting.room_id) + meeting = await meetings_controller.get_by_id(session, transcript.meeting_id) + room = await rooms_controller.get_by_id(session, meeting.room_id) if room.is_shared: user_id = None - await transcripts_controller.remove_by_id(transcript.id, user_id=user_id) + await transcripts_controller.remove_by_id(session, transcript.id, user_id=user_id) return DeletionStatus(status="ok") @@ -382,10 +392,11 @@ async def transcript_delete( async def transcript_get_topics( transcript_id: str, user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], + session: AsyncSession = Depends(get_session), ): user_id = user["sub"] if user else None transcript = await transcripts_controller.get_by_id_for_http( - transcript_id, user_id=user_id + session, transcript_id, user_id=user_id ) # convert to GetTranscriptTopic @@ -401,10 +412,11 @@ async def transcript_get_topics( async def transcript_get_topics_with_words( transcript_id: str, user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], + session: AsyncSession = Depends(get_session), ): user_id = user["sub"] if user else None transcript = await transcripts_controller.get_by_id_for_http( - transcript_id, user_id=user_id + session, transcript_id, user_id=user_id ) # convert to GetTranscriptTopicWithWords @@ -422,10 +434,11 @@ async def transcript_get_topics_with_words_per_speaker( transcript_id: str, topic_id: str, user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], + session: AsyncSession = Depends(get_session), ): user_id = 
user["sub"] if user else None transcript = await transcripts_controller.get_by_id_for_http( - transcript_id, user_id=user_id + session, transcript_id, user_id=user_id ) # get the topic from the transcript @@ -444,10 +457,11 @@ async def transcript_post_to_zulip( topic: str, include_topics: bool, user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], + session: AsyncSession = Depends(get_session), ): user_id = user["sub"] if user else None transcript = await transcripts_controller.get_by_id_for_http( - transcript_id, user_id=user_id + session, transcript_id, user_id=user_id ) if not transcript: raise HTTPException(status_code=404, detail="Transcript not found") @@ -467,5 +481,5 @@ async def transcript_post_to_zulip( if not message_updated: response = await send_message_to_zulip(stream, topic, content) await transcripts_controller.update( - transcript, {"zulip_message_id": response["id"]} + session, transcript, {"zulip_message_id": response["id"]} ) diff --git a/server/reflector/worker/cleanup.py b/server/reflector/worker/cleanup.py index e634994d..a34f7c75 100644 --- a/server/reflector/worker/cleanup.py +++ b/server/reflector/worker/cleanup.py @@ -11,14 +11,13 @@ from typing import TypedDict import structlog from celery import shared_task -from databases import Database from pydantic.types import PositiveInt +from sqlalchemy import delete, select from reflector.asynctask import asynctask -from reflector.db import get_database -from reflector.db.meetings import meetings -from reflector.db.recordings import recordings -from reflector.db.transcripts import transcripts, transcripts_controller +from reflector.db import get_session_factory +from reflector.db.base import MeetingModel, RecordingModel, TranscriptModel +from reflector.db.transcripts import transcripts_controller from reflector.settings import settings from reflector.storage import get_recordings_storage @@ -35,43 +34,49 @@ class CleanupStats(TypedDict): async def 
delete_single_transcript(
-    db: Database, transcript_data: dict, stats: CleanupStats
+    session_factory, transcript_data: dict, stats: CleanupStats
 ):
     transcript_id = transcript_data["id"]
     meeting_id = transcript_data["meeting_id"]
     recording_id = transcript_data["recording_id"]
 
     try:
-        async with db.transaction(isolation="serializable"):
-            if meeting_id:
-                await db.execute(meetings.delete().where(meetings.c.id == meeting_id))
-                stats["meetings_deleted"] += 1
-                logger.info("Deleted associated meeting", meeting_id=meeting_id)
-
-            if recording_id:
-                recording = await db.fetch_one(
-                    recordings.select().where(recordings.c.id == recording_id)
-                )
-                if recording:
-                    try:
-                        await get_recordings_storage().delete_file(
-                            recording["object_key"]
-                        )
-                    except Exception as storage_error:
-                        logger.warning(
-                            "Failed to delete recording from storage",
-                            recording_id=recording_id,
-                            object_key=recording["object_key"],
-                            error=str(storage_error),
-                        )
-
-                await db.execute(
-                    recordings.delete().where(recordings.c.id == recording_id
+        async with session_factory() as session:
+            async with session.begin():
+                if meeting_id:
+                    await session.execute(
+                        delete(MeetingModel).where(MeetingModel.id == meeting_id)
                     )
-                stats["recordings_deleted"] += 1
-                logger.info(
-                    "Deleted associated recording", recording_id=recording_id
+                    stats["meetings_deleted"] += 1
+                    logger.info("Deleted associated meeting", meeting_id=meeting_id)
+
+                if recording_id:
+                    result = await session.execute(
+                        select(RecordingModel.object_key).where(RecordingModel.id == recording_id)
                     )
+                    recording = result.mappings().first()
+                    if recording:
+                        try:
+                            await get_recordings_storage().delete_file(
+                                recording["object_key"]
+                            )
+                        except Exception as storage_error:
+                            logger.warning(
+                                "Failed to delete recording from storage",
+                                recording_id=recording_id,
+                                object_key=recording["object_key"],
+                                error=str(storage_error),
+                            )
+
+                    await session.execute(
+                        delete(RecordingModel).where(
+                            RecordingModel.id == recording_id
+                        )
+                    )
+                    stats["recordings_deleted"] += 1
+                    logger.info(
+                        "Deleted associated recording", recording_id=recording_id
+                    )
 
-            await transcripts_controller.remove_by_id(transcript_id)
+            await transcripts_controller.remove_by_id(session, transcript_id)
         stats["transcripts_deleted"] += 1
@@ -87,18 +92,21 @@ async def delete_single_transcript(
 
 
 async def cleanup_old_transcripts(
-    db: Database, cutoff_date: datetime, stats: CleanupStats
+    session_factory, cutoff_date: datetime, stats: CleanupStats
 ):
     """Delete old anonymous transcripts and their associated recordings/meetings."""
-    query = transcripts.select().where(
-        (transcripts.c.created_at < cutoff_date) & (transcripts.c.user_id.is_(None))
+    query = select(TranscriptModel.id, TranscriptModel.meeting_id, TranscriptModel.recording_id).where(
+        (TranscriptModel.created_at < cutoff_date) & (TranscriptModel.user_id.is_(None))
+    )
-    old_transcripts = await db.fetch_all(query)
+
+    async with session_factory() as session:
+        result = await session.execute(query)
+        old_transcripts = result.mappings().all()
 
     logger.info(f"Found {len(old_transcripts)} old transcripts to delete")
 
     for transcript_data in old_transcripts:
-        await delete_single_transcript(db, transcript_data, stats)
+        await delete_single_transcript(session_factory, transcript_data, stats)
 
 
 def log_cleanup_results(stats: CleanupStats):
@@ -140,8 +148,8 @@ async def cleanup_old_public_data(
         "errors": [],
     }
 
-    db = get_database()
-    await cleanup_old_transcripts(db, cutoff_date, stats)
+    session_factory = get_session_factory()
+    await cleanup_old_transcripts(session_factory, cutoff_date, stats)
 
     log_cleanup_results(stats)
     return stats
diff --git a/server/tests/conftest.py b/server/tests/conftest.py
index 22fe4193..1f8c1ff4 100644
--- a/server/tests/conftest.py
+++ b/server/tests/conftest.py
@@ -69,17 +69,19 @@ def postgres_service(docker_ip, docker_services):
 @pytest.fixture(scope="function", autouse=True)
 @pytest.mark.asyncio
 async def setup_database(postgres_service):
-    from reflector.db import engine, metadata, get_database  # noqa
+    from reflector.db import get_engine
+    from reflector.db.base import metadata
-    
metadata.drop_all(bind=engine) - metadata.create_all(bind=engine) - database = get_database() + async_engine = get_engine() + + async with async_engine.begin() as conn: + await conn.run_sync(metadata.drop_all) + await conn.run_sync(metadata.create_all) try: - await database.connect() yield finally: - await database.disconnect() + await async_engine.dispose() @pytest.fixture diff --git a/server/tests/test_room_ics.py b/server/tests/test_room_ics.py index 7a3c4d74..8198ece7 100644 --- a/server/tests/test_room_ics.py +++ b/server/tests/test_room_ics.py @@ -196,9 +196,9 @@ async def test_room_list_with_ics_enabled_filter(): assert len(all_rooms) == 3 # Filter for ICS-enabled rooms (would need to implement this in controller) - ics_rooms = [r for r in all_rooms if r["ics_enabled"]] + ics_rooms = [r for r in all_rooms if r.ics_enabled] assert len(ics_rooms) == 2 - assert all(r["ics_enabled"] for r in ics_rooms) + assert all(r.ics_enabled for r in ics_rooms) @pytest.mark.asyncio diff --git a/server/uv.lock b/server/uv.lock index 2c28f61b..c7c3f08f 100644 --- a/server/uv.lock +++ b/server/uv.lock @@ -763,26 +763,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl", hash = "sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30", size = 8321, upload-time = "2023-10-07T05:32:16.783Z" }, ] -[[package]] -name = "databases" -version = "0.8.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "sqlalchemy" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/7e/73/a8e49fa9ae156249e86474a4dc461a95e6e389dc0f139ff4c798a5130e8d/databases-0.8.0.tar.gz", hash = "sha256:6544d82e9926f233d694ec29cd018403444c7fb6e863af881a8304d1ff5cfb90", size = 27569, upload-time = "2023-08-28T14:51:43.533Z" } -wheels = [ - { url = 
"https://files.pythonhosted.org/packages/09/18/af04845ee14bf9183a7403e109ce7ac9468403bf0b8f8ff4802cc8bf0d79/databases-0.8.0-py3-none-any.whl", hash = "sha256:0ceb7fd5c740d846e1f4f58c0256d780a6786841ec8e624a21f1eb1b51a9093d", size = 22560, upload-time = "2023-08-28T14:51:41.734Z" }, -] - -[package.optional-dependencies] -aiosqlite = [ - { name = "aiosqlite" }, -] -asyncpg = [ - { name = "asyncpg" }, -] - [[package]] name = "dataclasses-json" version = "0.6.7" @@ -3100,9 +3080,9 @@ dependencies = [ { name = "aiohttp-cors" }, { name = "aiortc" }, { name = "alembic" }, + { name = "asyncpg" }, { name = "av" }, { name = "celery" }, - { name = "databases", extra = ["aiosqlite", "asyncpg"] }, { name = "fastapi", extra = ["standard"] }, { name = "fastapi-pagination" }, { name = "httpx" }, @@ -3176,9 +3156,9 @@ requires-dist = [ { name = "aiohttp-cors", specifier = ">=0.7.0" }, { name = "aiortc", specifier = ">=1.5.0" }, { name = "alembic", specifier = ">=1.11.3" }, + { name = "asyncpg", specifier = ">=0.29.0" }, { name = "av", specifier = ">=10.0.0" }, { name = "celery", specifier = ">=5.3.4" }, - { name = "databases", extras = ["aiosqlite", "asyncpg"], specifier = ">=0.7.0" }, { name = "fastapi", extras = ["standard"], specifier = ">=0.100.1" }, { name = "fastapi-pagination", specifier = ">=0.12.6" }, { name = "httpx", specifier = ">=0.24.1" }, @@ -3200,7 +3180,7 @@ requires-dist = [ { name = "sentencepiece", specifier = ">=0.1.99" }, { name = "sentry-sdk", extras = ["fastapi"], specifier = ">=1.29.2" }, { name = "sortedcontainers", specifier = ">=2.4.0" }, - { name = "sqlalchemy", specifier = "<1.5" }, + { name = "sqlalchemy", specifier = ">=2.0.0" }, { name = "structlog", specifier = ">=23.1.0" }, { name = "transformers", specifier = ">=4.36.2" }, { name = "uvicorn", extras = ["standard"], specifier = ">=0.23.1" }, @@ -3717,23 +3697,31 @@ wheels = [ [[package]] name = "sqlalchemy" -version = "1.4.54" +version = "2.0.43" source = { registry = "https://pypi.org/simple" } 
dependencies = [ { name = "greenlet", marker = "platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or platform_machine == 'x86_64'" }, + { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ce/af/20290b55d469e873cba9d41c0206ab5461ff49d759989b3fe65010f9d265/sqlalchemy-1.4.54.tar.gz", hash = "sha256:4470fbed088c35dc20b78a39aaf4ae54fe81790c783b3264872a0224f437c31a", size = 8470350, upload-time = "2024-09-05T15:54:10.398Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d7/bc/d59b5d97d27229b0e009bd9098cd81af71c2fa5549c580a0a67b9bed0496/sqlalchemy-2.0.43.tar.gz", hash = "sha256:788bfcef6787a7764169cfe9859fe425bf44559619e1d9f56f5bddf2ebf6f417", size = 9762949, upload-time = "2025-08-11T14:24:58.438Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/da/49/fb98983b5568e93696a25fd5bec1b789095b79a72d5f57c6effddaa81d0a/SQLAlchemy-1.4.54-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:b05e0626ec1c391432eabb47a8abd3bf199fb74bfde7cc44a26d2b1b352c2c6e", size = 1589301, upload-time = "2024-09-05T19:22:42.197Z" }, - { url = "https://files.pythonhosted.org/packages/03/98/5a81430bbd646991346cb088a2bdc84d1bcd3dbe6b0cfc1aaa898370e5c7/SQLAlchemy-1.4.54-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:13e91d6892b5fcb94a36ba061fb7a1f03d0185ed9d8a77c84ba389e5bb05e936", size = 1629553, upload-time = "2024-09-05T17:49:18.846Z" }, - { url = "https://files.pythonhosted.org/packages/f1/17/14e35db2b0d6deaa27691d014addbb0dd6f7e044f7ee465446a3c0c71404/SQLAlchemy-1.4.54-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fb59a11689ff3c58e7652260127f9e34f7f45478a2f3ef831ab6db7bcd72108f", size = 1627640, upload-time = "2024-09-05T17:48:01.558Z" }, - { url = 
"https://files.pythonhosted.org/packages/98/62/335006a8f2c98f704f391e1a0cc01446d1b1b9c198f579f03599f55bd860/SQLAlchemy-1.4.54-cp311-cp311-win32.whl", hash = "sha256:1390ca2d301a2708fd4425c6d75528d22f26b8f5cbc9faba1ddca136671432bc", size = 1591723, upload-time = "2024-09-05T17:53:17.486Z" }, - { url = "https://files.pythonhosted.org/packages/e2/a1/6b4b8c07082920f5445ec65c221fa33baab102aced5dcc2d87a15d3f8db4/SQLAlchemy-1.4.54-cp311-cp311-win_amd64.whl", hash = "sha256:2b37931eac4b837c45e2522066bda221ac6d80e78922fb77c75eb12e4dbcdee5", size = 1593511, upload-time = "2024-09-05T17:51:50.947Z" }, - { url = "https://files.pythonhosted.org/packages/a5/1b/aa9b99be95d1615f058b5827447c18505b7b3f1dfcbd6ce1b331c2107152/SQLAlchemy-1.4.54-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:3f01c2629a7d6b30d8afe0326b8c649b74825a0e1ebdcb01e8ffd1c920deb07d", size = 1589983, upload-time = "2024-09-05T17:39:02.132Z" }, - { url = "https://files.pythonhosted.org/packages/59/47/cb0fc64e5344f0a3d02216796c342525ab283f8f052d1c31a1d487d08aa0/SQLAlchemy-1.4.54-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c24dd161c06992ed16c5e528a75878edbaeced5660c3db88c820f1f0d3fe1f4", size = 1630158, upload-time = "2024-09-05T17:50:13.255Z" }, - { url = "https://files.pythonhosted.org/packages/c0/8b/f45dd378f6c97e8ff9332ff3d03ecb0b8c491be5bb7a698783b5a2f358ec/SQLAlchemy-1.4.54-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b5e0d47d619c739bdc636bbe007da4519fc953393304a5943e0b5aec96c9877c", size = 1629232, upload-time = "2024-09-05T17:48:15.514Z" }, - { url = "https://files.pythonhosted.org/packages/0d/3c/884fe389f5bec86a310b81e79abaa1e26e5d78dc10a84d544a6822833e47/SQLAlchemy-1.4.54-cp312-cp312-win32.whl", hash = "sha256:12bc0141b245918b80d9d17eca94663dbd3f5266ac77a0be60750f36102bbb0f", size = 1592027, upload-time = "2024-09-05T17:54:02.253Z" }, - { url = 
"https://files.pythonhosted.org/packages/01/c3/c690d037be57efd3a69cde16a2ef1bd2a905dafe869434d33836de0983d0/SQLAlchemy-1.4.54-cp312-cp312-win_amd64.whl", hash = "sha256:f941aaf15f47f316123e1933f9ea91a6efda73a161a6ab6046d1cde37be62c88", size = 1593827, upload-time = "2024-09-05T17:52:07.454Z" }, + { url = "https://files.pythonhosted.org/packages/9d/77/fa7189fe44114658002566c6fe443d3ed0ec1fa782feb72af6ef7fbe98e7/sqlalchemy-2.0.43-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:52d9b73b8fb3e9da34c2b31e6d99d60f5f99fd8c1225c9dad24aeb74a91e1d29", size = 2136472, upload-time = "2025-08-11T15:52:21.789Z" }, + { url = "https://files.pythonhosted.org/packages/99/ea/92ac27f2fbc2e6c1766bb807084ca455265707e041ba027c09c17d697867/sqlalchemy-2.0.43-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f42f23e152e4545157fa367b2435a1ace7571cab016ca26038867eb7df2c3631", size = 2126535, upload-time = "2025-08-11T15:52:23.109Z" }, + { url = "https://files.pythonhosted.org/packages/94/12/536ede80163e295dc57fff69724caf68f91bb40578b6ac6583a293534849/sqlalchemy-2.0.43-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4fb1a8c5438e0c5ea51afe9c6564f951525795cf432bed0c028c1cb081276685", size = 3297521, upload-time = "2025-08-11T15:50:33.536Z" }, + { url = "https://files.pythonhosted.org/packages/03/b5/cacf432e6f1fc9d156eca0560ac61d4355d2181e751ba8c0cd9cb232c8c1/sqlalchemy-2.0.43-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db691fa174e8f7036afefe3061bc40ac2b770718be2862bfb03aabae09051aca", size = 3297343, upload-time = "2025-08-11T15:57:51.186Z" }, + { url = "https://files.pythonhosted.org/packages/ca/ba/d4c9b526f18457667de4c024ffbc3a0920c34237b9e9dd298e44c7c00ee5/sqlalchemy-2.0.43-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:fe2b3b4927d0bc03d02ad883f402d5de201dbc8894ac87d2e981e7d87430e60d", size = 3232113, upload-time = "2025-08-11T15:50:34.949Z" }, + { url = 
"https://files.pythonhosted.org/packages/aa/79/c0121b12b1b114e2c8a10ea297a8a6d5367bc59081b2be896815154b1163/sqlalchemy-2.0.43-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4d3d9b904ad4a6b175a2de0738248822f5ac410f52c2fd389ada0b5262d6a1e3", size = 3258240, upload-time = "2025-08-11T15:57:52.983Z" }, + { url = "https://files.pythonhosted.org/packages/79/99/a2f9be96fb382f3ba027ad42f00dbe30fdb6ba28cda5f11412eee346bec5/sqlalchemy-2.0.43-cp311-cp311-win32.whl", hash = "sha256:5cda6b51faff2639296e276591808c1726c4a77929cfaa0f514f30a5f6156921", size = 2101248, upload-time = "2025-08-11T15:55:01.855Z" }, + { url = "https://files.pythonhosted.org/packages/ee/13/744a32ebe3b4a7a9c7ea4e57babae7aa22070d47acf330d8e5a1359607f1/sqlalchemy-2.0.43-cp311-cp311-win_amd64.whl", hash = "sha256:c5d1730b25d9a07727d20ad74bc1039bbbb0a6ca24e6769861c1aa5bf2c4c4a8", size = 2126109, upload-time = "2025-08-11T15:55:04.092Z" }, + { url = "https://files.pythonhosted.org/packages/61/db/20c78f1081446095450bdc6ee6cc10045fce67a8e003a5876b6eaafc5cc4/sqlalchemy-2.0.43-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:20d81fc2736509d7a2bd33292e489b056cbae543661bb7de7ce9f1c0cd6e7f24", size = 2134891, upload-time = "2025-08-11T15:51:13.019Z" }, + { url = "https://files.pythonhosted.org/packages/45/0a/3d89034ae62b200b4396f0f95319f7d86e9945ee64d2343dcad857150fa2/sqlalchemy-2.0.43-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:25b9fc27650ff5a2c9d490c13c14906b918b0de1f8fcbb4c992712d8caf40e83", size = 2123061, upload-time = "2025-08-11T15:51:14.319Z" }, + { url = "https://files.pythonhosted.org/packages/cb/10/2711f7ff1805919221ad5bee205971254845c069ee2e7036847103ca1e4c/sqlalchemy-2.0.43-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6772e3ca8a43a65a37c88e2f3e2adfd511b0b1da37ef11ed78dea16aeae85bd9", size = 3320384, upload-time = "2025-08-11T15:52:35.088Z" }, + { url = 
"https://files.pythonhosted.org/packages/6e/0e/3d155e264d2ed2778484006ef04647bc63f55b3e2d12e6a4f787747b5900/sqlalchemy-2.0.43-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1a113da919c25f7f641ffbd07fbc9077abd4b3b75097c888ab818f962707eb48", size = 3329648, upload-time = "2025-08-11T15:56:34.153Z" }, + { url = "https://files.pythonhosted.org/packages/5b/81/635100fb19725c931622c673900da5efb1595c96ff5b441e07e3dd61f2be/sqlalchemy-2.0.43-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4286a1139f14b7d70141c67a8ae1582fc2b69105f1b09d9573494eb4bb4b2687", size = 3258030, upload-time = "2025-08-11T15:52:36.933Z" }, + { url = "https://files.pythonhosted.org/packages/0c/ed/a99302716d62b4965fded12520c1cbb189f99b17a6d8cf77611d21442e47/sqlalchemy-2.0.43-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:529064085be2f4d8a6e5fab12d36ad44f1909a18848fcfbdb59cc6d4bbe48efe", size = 3294469, upload-time = "2025-08-11T15:56:35.553Z" }, + { url = "https://files.pythonhosted.org/packages/5d/a2/3a11b06715149bf3310b55a98b5c1e84a42cfb949a7b800bc75cb4e33abc/sqlalchemy-2.0.43-cp312-cp312-win32.whl", hash = "sha256:b535d35dea8bbb8195e7e2b40059e2253acb2b7579b73c1b432a35363694641d", size = 2098906, upload-time = "2025-08-11T15:55:00.645Z" }, + { url = "https://files.pythonhosted.org/packages/bc/09/405c915a974814b90aa591280623adc6ad6b322f61fd5cff80aeaef216c9/sqlalchemy-2.0.43-cp312-cp312-win_amd64.whl", hash = "sha256:1c6d85327ca688dbae7e2b06d7d84cfe4f3fffa5b5f9e21bb6ce9d0e1a0e0e0a", size = 2126260, upload-time = "2025-08-11T15:55:02.965Z" }, + { url = "https://files.pythonhosted.org/packages/b8/d9/13bdde6521f322861fab67473cec4b1cc8999f3871953531cf61945fad92/sqlalchemy-2.0.43-py3-none-any.whl", hash = "sha256:1681c21dd2ccee222c2fe0bef671d1aef7c504087c9c4e800371cfcc8ac966fc", size = 1924759, upload-time = "2025-08-11T15:39:53.024Z" }, ] [package.optional-dependencies]