mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-21 12:49:06 +00:00
feat: postgresql migration and removal of sqlite in pytest (#546)
* feat: remove support of sqlite, 100% postgres * fix: more migration and make datetime timezone aware in postgres * fix: change how database is get, and use contextvar to have difference instance between different loops * test: properly use client fixture that handle lifetime/database connection * fix: add missing client fixture parameters to test functions This commit fixes NameError issues where test functions were trying to use the 'client' fixture but didn't have it as a parameter. The changes include: 1. Added 'client' parameter to test functions in: - test_transcripts_audio_download.py (6 functions including fixture) - test_transcripts_speaker.py (3 functions) - test_transcripts_upload.py (1 function) - test_transcripts_rtc_ws.py (2 functions + appserver fixture) 2. Resolved naming conflicts in test_transcripts_rtc_ws.py where both HTTP client and StreamClient were using variable name 'client'. StreamClient instances are now named 'stream_client' to avoid conflicts. 3. Added missing 'from reflector.app import app' import in rtc_ws tests. Background: Previously implemented contextvars solution with get_database() function resolves asyncio event loop conflicts in Celery tasks. The global client fixture was also created to replace manual AsyncClient instances, ensuring proper FastAPI application lifecycle management and database connections during tests. All tests now pass except for 2 pre-existing RTC WebSocket test failures related to asyncpg connection issues unrelated to these fixes. * fix: ensure task are correctly closed * fix: make separate event loop for the live server * fix: make default settings pointing at postgres * build: remove pytest-docker deps out of dev, just tests group
This commit is contained in:
@@ -1,12 +1,28 @@
|
||||
import contextvars
|
||||
from typing import Optional
|
||||
|
||||
import databases
|
||||
import sqlalchemy
|
||||
|
||||
from reflector.events import subscribers_shutdown, subscribers_startup
|
||||
from reflector.settings import settings
|
||||
|
||||
database = databases.Database(settings.DATABASE_URL)
|
||||
metadata = sqlalchemy.MetaData()
|
||||
|
||||
_database_context: contextvars.ContextVar[Optional[databases.Database]] = (
|
||||
contextvars.ContextVar("database", default=None)
|
||||
)
|
||||
|
||||
|
||||
def get_database() -> databases.Database:
|
||||
"""Get database instance for current asyncio context"""
|
||||
db = _database_context.get()
|
||||
if db is None:
|
||||
db = databases.Database(settings.DATABASE_URL)
|
||||
_database_context.set(db)
|
||||
return db
|
||||
|
||||
|
||||
# import models
|
||||
import reflector.db.meetings # noqa
|
||||
import reflector.db.recordings # noqa
|
||||
@@ -14,16 +30,18 @@ import reflector.db.rooms # noqa
|
||||
import reflector.db.transcripts # noqa
|
||||
|
||||
kwargs = {}
|
||||
if "sqlite" in settings.DATABASE_URL:
|
||||
kwargs["connect_args"] = {"check_same_thread": False}
|
||||
if "postgres" not in settings.DATABASE_URL:
|
||||
raise Exception("Only postgres database is supported in reflector")
|
||||
engine = sqlalchemy.create_engine(settings.DATABASE_URL, **kwargs)
|
||||
|
||||
|
||||
@subscribers_startup.append
|
||||
async def database_connect(_):
|
||||
database = get_database()
|
||||
await database.connect()
|
||||
|
||||
|
||||
@subscribers_shutdown.append
|
||||
async def database_disconnect(_):
|
||||
database = get_database()
|
||||
await database.disconnect()
|
||||
|
||||
@@ -5,7 +5,7 @@ import sqlalchemy as sa
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from reflector.db import database, metadata
|
||||
from reflector.db import get_database, metadata
|
||||
from reflector.db.rooms import Room
|
||||
from reflector.utils import generate_uuid4
|
||||
|
||||
@@ -16,8 +16,8 @@ meetings = sa.Table(
|
||||
sa.Column("room_name", sa.String),
|
||||
sa.Column("room_url", sa.String),
|
||||
sa.Column("host_room_url", sa.String),
|
||||
sa.Column("start_date", sa.DateTime),
|
||||
sa.Column("end_date", sa.DateTime),
|
||||
sa.Column("start_date", sa.DateTime(timezone=True)),
|
||||
sa.Column("end_date", sa.DateTime(timezone=True)),
|
||||
sa.Column("user_id", sa.String),
|
||||
sa.Column("room_id", sa.String),
|
||||
sa.Column("is_locked", sa.Boolean, nullable=False, server_default=sa.false()),
|
||||
@@ -42,6 +42,12 @@ meetings = sa.Table(
|
||||
server_default=sa.true(),
|
||||
),
|
||||
sa.Index("idx_meeting_room_id", "room_id"),
|
||||
sa.Index(
|
||||
"idx_one_active_meeting_per_room",
|
||||
"room_id",
|
||||
unique=True,
|
||||
postgresql_where=sa.text("is_active = true"),
|
||||
),
|
||||
)
|
||||
|
||||
meeting_consent = sa.Table(
|
||||
@@ -51,7 +57,7 @@ meeting_consent = sa.Table(
|
||||
sa.Column("meeting_id", sa.String, sa.ForeignKey("meeting.id"), nullable=False),
|
||||
sa.Column("user_id", sa.String),
|
||||
sa.Column("consent_given", sa.Boolean, nullable=False),
|
||||
sa.Column("consent_timestamp", sa.DateTime, nullable=False),
|
||||
sa.Column("consent_timestamp", sa.DateTime(timezone=True), nullable=False),
|
||||
)
|
||||
|
||||
|
||||
@@ -111,7 +117,7 @@ class MeetingController:
|
||||
recording_trigger=room.recording_trigger,
|
||||
)
|
||||
query = meetings.insert().values(**meeting.model_dump())
|
||||
await database.execute(query)
|
||||
await get_database().execute(query)
|
||||
return meeting
|
||||
|
||||
async def get_all_active(self) -> list[Meeting]:
|
||||
@@ -119,7 +125,7 @@ class MeetingController:
|
||||
Get active meetings.
|
||||
"""
|
||||
query = meetings.select().where(meetings.c.is_active)
|
||||
return await database.fetch_all(query)
|
||||
return await get_database().fetch_all(query)
|
||||
|
||||
async def get_by_room_name(
|
||||
self,
|
||||
@@ -129,7 +135,7 @@ class MeetingController:
|
||||
Get a meeting by room name.
|
||||
"""
|
||||
query = meetings.select().where(meetings.c.room_name == room_name)
|
||||
result = await database.fetch_one(query)
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
return None
|
||||
|
||||
@@ -151,7 +157,7 @@ class MeetingController:
|
||||
)
|
||||
.order_by(end_date.desc())
|
||||
)
|
||||
result = await database.fetch_one(query)
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
return None
|
||||
|
||||
@@ -162,7 +168,7 @@ class MeetingController:
|
||||
Get a meeting by id
|
||||
"""
|
||||
query = meetings.select().where(meetings.c.id == meeting_id)
|
||||
result = await database.fetch_one(query)
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
return None
|
||||
return Meeting(**result)
|
||||
@@ -174,7 +180,7 @@ class MeetingController:
|
||||
If not found, it will raise a 404 error.
|
||||
"""
|
||||
query = meetings.select().where(meetings.c.id == meeting_id)
|
||||
result = await database.fetch_one(query)
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="Meeting not found")
|
||||
|
||||
@@ -186,7 +192,7 @@ class MeetingController:
|
||||
|
||||
async def update_meeting(self, meeting_id: str, **kwargs):
|
||||
query = meetings.update().where(meetings.c.id == meeting_id).values(**kwargs)
|
||||
await database.execute(query)
|
||||
await get_database().execute(query)
|
||||
|
||||
|
||||
class MeetingConsentController:
|
||||
@@ -194,7 +200,7 @@ class MeetingConsentController:
|
||||
query = meeting_consent.select().where(
|
||||
meeting_consent.c.meeting_id == meeting_id
|
||||
)
|
||||
results = await database.fetch_all(query)
|
||||
results = await get_database().fetch_all(query)
|
||||
return [MeetingConsent(**result) for result in results]
|
||||
|
||||
async def get_by_meeting_and_user(
|
||||
@@ -205,7 +211,7 @@ class MeetingConsentController:
|
||||
meeting_consent.c.meeting_id == meeting_id,
|
||||
meeting_consent.c.user_id == user_id,
|
||||
)
|
||||
result = await database.fetch_one(query)
|
||||
result = await get_database().fetch_one(query)
|
||||
if result is None:
|
||||
return None
|
||||
return MeetingConsent(**result) if result else None
|
||||
@@ -227,14 +233,14 @@ class MeetingConsentController:
|
||||
consent_timestamp=consent.consent_timestamp,
|
||||
)
|
||||
)
|
||||
await database.execute(query)
|
||||
await get_database().execute(query)
|
||||
|
||||
existing.consent_given = consent.consent_given
|
||||
existing.consent_timestamp = consent.consent_timestamp
|
||||
return existing
|
||||
|
||||
query = meeting_consent.insert().values(**consent.model_dump())
|
||||
await database.execute(query)
|
||||
await get_database().execute(query)
|
||||
return consent
|
||||
|
||||
async def has_any_denial(self, meeting_id: str) -> bool:
|
||||
@@ -243,7 +249,7 @@ class MeetingConsentController:
|
||||
meeting_consent.c.meeting_id == meeting_id,
|
||||
meeting_consent.c.consent_given.is_(False),
|
||||
)
|
||||
result = await database.fetch_one(query)
|
||||
result = await get_database().fetch_one(query)
|
||||
return result is not None
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Literal
|
||||
import sqlalchemy as sa
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from reflector.db import database, metadata
|
||||
from reflector.db import get_database, metadata
|
||||
from reflector.utils import generate_uuid4
|
||||
|
||||
recordings = sa.Table(
|
||||
@@ -13,7 +13,7 @@ recordings = sa.Table(
|
||||
sa.Column("id", sa.String, primary_key=True),
|
||||
sa.Column("bucket_name", sa.String, nullable=False),
|
||||
sa.Column("object_key", sa.String, nullable=False),
|
||||
sa.Column("recorded_at", sa.DateTime, nullable=False),
|
||||
sa.Column("recorded_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.String,
|
||||
@@ -37,12 +37,12 @@ class Recording(BaseModel):
|
||||
class RecordingController:
|
||||
async def create(self, recording: Recording):
|
||||
query = recordings.insert().values(**recording.model_dump())
|
||||
await database.execute(query)
|
||||
await get_database().execute(query)
|
||||
return recording
|
||||
|
||||
async def get_by_id(self, id: str) -> Recording:
|
||||
query = recordings.select().where(recordings.c.id == id)
|
||||
result = await database.fetch_one(query)
|
||||
result = await get_database().fetch_one(query)
|
||||
return Recording(**result) if result else None
|
||||
|
||||
async def get_by_object_key(self, bucket_name: str, object_key: str) -> Recording:
|
||||
@@ -50,7 +50,7 @@ class RecordingController:
|
||||
recordings.c.bucket_name == bucket_name,
|
||||
recordings.c.object_key == object_key,
|
||||
)
|
||||
result = await database.fetch_one(query)
|
||||
result = await get_database().fetch_one(query)
|
||||
return Recording(**result) if result else None
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from sqlite3 import IntegrityError
|
||||
from typing import Literal
|
||||
|
||||
@@ -7,7 +7,7 @@ from fastapi import HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.sql import false, or_
|
||||
|
||||
from reflector.db import database, metadata
|
||||
from reflector.db import get_database, metadata
|
||||
from reflector.utils import generate_uuid4
|
||||
|
||||
rooms = sqlalchemy.Table(
|
||||
@@ -16,7 +16,7 @@ rooms = sqlalchemy.Table(
|
||||
sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
|
||||
sqlalchemy.Column("name", sqlalchemy.String, nullable=False, unique=True),
|
||||
sqlalchemy.Column("user_id", sqlalchemy.String, nullable=False),
|
||||
sqlalchemy.Column("created_at", sqlalchemy.DateTime, nullable=False),
|
||||
sqlalchemy.Column("created_at", sqlalchemy.DateTime(timezone=True), nullable=False),
|
||||
sqlalchemy.Column(
|
||||
"zulip_auto_post", sqlalchemy.Boolean, nullable=False, server_default=false()
|
||||
),
|
||||
@@ -48,7 +48,7 @@ class Room(BaseModel):
|
||||
id: str = Field(default_factory=generate_uuid4)
|
||||
name: str
|
||||
user_id: str
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
zulip_auto_post: bool = False
|
||||
zulip_stream: str = ""
|
||||
zulip_topic: str = ""
|
||||
@@ -92,7 +92,7 @@ class RoomController:
|
||||
if return_query:
|
||||
return query
|
||||
|
||||
results = await database.fetch_all(query)
|
||||
results = await get_database().fetch_all(query)
|
||||
return results
|
||||
|
||||
async def add(
|
||||
@@ -125,7 +125,7 @@ class RoomController:
|
||||
)
|
||||
query = rooms.insert().values(**room.model_dump())
|
||||
try:
|
||||
await database.execute(query)
|
||||
await get_database().execute(query)
|
||||
except IntegrityError:
|
||||
raise HTTPException(status_code=400, detail="Room name is not unique")
|
||||
return room
|
||||
@@ -136,7 +136,7 @@ class RoomController:
|
||||
"""
|
||||
query = rooms.update().where(rooms.c.id == room.id).values(**values)
|
||||
try:
|
||||
await database.execute(query)
|
||||
await get_database().execute(query)
|
||||
except IntegrityError:
|
||||
raise HTTPException(status_code=400, detail="Room name is not unique")
|
||||
|
||||
@@ -151,7 +151,7 @@ class RoomController:
|
||||
query = rooms.select().where(rooms.c.id == room_id)
|
||||
if "user_id" in kwargs:
|
||||
query = query.where(rooms.c.user_id == kwargs["user_id"])
|
||||
result = await database.fetch_one(query)
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
return None
|
||||
return Room(**result)
|
||||
@@ -163,7 +163,7 @@ class RoomController:
|
||||
query = rooms.select().where(rooms.c.name == room_name)
|
||||
if "user_id" in kwargs:
|
||||
query = query.where(rooms.c.user_id == kwargs["user_id"])
|
||||
result = await database.fetch_one(query)
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
return None
|
||||
return Room(**result)
|
||||
@@ -175,7 +175,7 @@ class RoomController:
|
||||
If not found, it will raise a 404 error.
|
||||
"""
|
||||
query = rooms.select().where(rooms.c.id == meeting_id)
|
||||
result = await database.fetch_one(query)
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
@@ -197,7 +197,7 @@ class RoomController:
|
||||
if user_id is not None and room.user_id != user_id:
|
||||
return
|
||||
query = rooms.delete().where(rooms.c.id == room_id)
|
||||
await database.execute(query)
|
||||
await get_database().execute(query)
|
||||
|
||||
|
||||
rooms_controller = RoomController()
|
||||
|
||||
@@ -9,7 +9,7 @@ import sqlalchemy
|
||||
import webvtt
|
||||
from pydantic import BaseModel, Field, constr, field_serializer
|
||||
|
||||
from reflector.db import database
|
||||
from reflector.db import get_database
|
||||
from reflector.db.transcripts import SourceKind, transcripts
|
||||
from reflector.db.utils import is_postgresql
|
||||
|
||||
@@ -207,12 +207,12 @@ class SearchController:
|
||||
.limit(params.limit)
|
||||
.offset(params.offset)
|
||||
)
|
||||
rs = await database.fetch_all(query)
|
||||
rs = await get_database().fetch_all(query)
|
||||
|
||||
count_query = sqlalchemy.select([sqlalchemy.func.count()]).select_from(
|
||||
base_query.alias("search_results")
|
||||
)
|
||||
total = await database.fetch_val(count_query)
|
||||
total = await get_database().fetch_val(count_query)
|
||||
|
||||
def _process_result(r) -> SearchResult:
|
||||
r_dict: Dict[str, Any] = dict(r)
|
||||
|
||||
@@ -15,7 +15,7 @@ from sqlalchemy import Enum
|
||||
from sqlalchemy.dialects.postgresql import TSVECTOR
|
||||
from sqlalchemy.sql import false, or_
|
||||
|
||||
from reflector.db import database, metadata
|
||||
from reflector.db import get_database, metadata
|
||||
from reflector.db.rooms import rooms
|
||||
from reflector.db.utils import is_postgresql
|
||||
from reflector.processors.types import Word as ProcessorWord
|
||||
@@ -41,7 +41,7 @@ transcripts = sqlalchemy.Table(
|
||||
sqlalchemy.Column("status", sqlalchemy.String),
|
||||
sqlalchemy.Column("locked", sqlalchemy.Boolean),
|
||||
sqlalchemy.Column("duration", sqlalchemy.Float),
|
||||
sqlalchemy.Column("created_at", sqlalchemy.DateTime),
|
||||
sqlalchemy.Column("created_at", sqlalchemy.DateTime(timezone=True)),
|
||||
sqlalchemy.Column("title", sqlalchemy.String),
|
||||
sqlalchemy.Column("short_summary", sqlalchemy.String),
|
||||
sqlalchemy.Column("long_summary", sqlalchemy.String),
|
||||
@@ -421,7 +421,7 @@ class TranscriptController:
|
||||
if return_query:
|
||||
return query
|
||||
|
||||
results = await database.fetch_all(query)
|
||||
results = await get_database().fetch_all(query)
|
||||
return results
|
||||
|
||||
async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
|
||||
@@ -431,7 +431,7 @@ class TranscriptController:
|
||||
query = transcripts.select().where(transcripts.c.id == transcript_id)
|
||||
if "user_id" in kwargs:
|
||||
query = query.where(transcripts.c.user_id == kwargs["user_id"])
|
||||
result = await database.fetch_one(query)
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
return None
|
||||
return Transcript(**result)
|
||||
@@ -445,7 +445,7 @@ class TranscriptController:
|
||||
query = transcripts.select().where(transcripts.c.recording_id == recording_id)
|
||||
if "user_id" in kwargs:
|
||||
query = query.where(transcripts.c.user_id == kwargs["user_id"])
|
||||
result = await database.fetch_one(query)
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
return None
|
||||
return Transcript(**result)
|
||||
@@ -463,7 +463,7 @@ class TranscriptController:
|
||||
if order_by.startswith("-"):
|
||||
field = field.desc()
|
||||
query = query.order_by(field)
|
||||
results = await database.fetch_all(query)
|
||||
results = await get_database().fetch_all(query)
|
||||
return [Transcript(**result) for result in results]
|
||||
|
||||
async def get_by_id_for_http(
|
||||
@@ -481,7 +481,7 @@ class TranscriptController:
|
||||
to determine if the user can access the transcript.
|
||||
"""
|
||||
query = transcripts.select().where(transcripts.c.id == transcript_id)
|
||||
result = await database.fetch_one(query)
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
|
||||
@@ -534,7 +534,7 @@ class TranscriptController:
|
||||
room_id=room_id,
|
||||
)
|
||||
query = transcripts.insert().values(**transcript.model_dump())
|
||||
await database.execute(query)
|
||||
await get_database().execute(query)
|
||||
return transcript
|
||||
|
||||
# TODO investigate why mutate= is used. it's used in one place currently, maybe because of ORM field updates.
|
||||
@@ -553,7 +553,7 @@ class TranscriptController:
|
||||
.where(transcripts.c.id == transcript.id)
|
||||
.values(**values)
|
||||
)
|
||||
await database.execute(query)
|
||||
await get_database().execute(query)
|
||||
if mutate:
|
||||
for key, value in values.items():
|
||||
setattr(transcript, key, value)
|
||||
@@ -595,21 +595,21 @@ class TranscriptController:
|
||||
return
|
||||
transcript.unlink()
|
||||
query = transcripts.delete().where(transcripts.c.id == transcript_id)
|
||||
await database.execute(query)
|
||||
await get_database().execute(query)
|
||||
|
||||
async def remove_by_recording_id(self, recording_id: str):
|
||||
"""
|
||||
Remove a transcript by recording_id
|
||||
"""
|
||||
query = transcripts.delete().where(transcripts.c.recording_id == recording_id)
|
||||
await database.execute(query)
|
||||
await get_database().execute(query)
|
||||
|
||||
@asynccontextmanager
|
||||
async def transaction(self):
|
||||
"""
|
||||
A context manager for database transaction
|
||||
"""
|
||||
async with database.transaction(isolation="serializable"):
|
||||
async with get_database().transaction(isolation="serializable"):
|
||||
yield
|
||||
|
||||
async def append_event(
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
"""Database utility functions."""
|
||||
|
||||
from reflector.db import database
|
||||
from reflector.db import get_database
|
||||
|
||||
|
||||
def is_postgresql() -> bool:
|
||||
return database.url.scheme and database.url.scheme.startswith('postgresql')
|
||||
return get_database().url.scheme and get_database().url.scheme.startswith(
|
||||
"postgresql"
|
||||
)
|
||||
|
||||
@@ -22,7 +22,7 @@ from celery import chord, current_task, group, shared_task
|
||||
from pydantic import BaseModel
|
||||
from structlog import BoundLogger as Logger
|
||||
|
||||
from reflector.db import database
|
||||
from reflector.db import get_database
|
||||
from reflector.db.meetings import meeting_consent_controller, meetings_controller
|
||||
from reflector.db.recordings import recordings_controller
|
||||
from reflector.db.rooms import rooms_controller
|
||||
@@ -72,6 +72,7 @@ def asynctask(f):
|
||||
@functools.wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
async def run_with_db():
|
||||
database = get_database()
|
||||
await database.connect()
|
||||
try:
|
||||
return await f(*args, **kwargs)
|
||||
|
||||
@@ -6,7 +6,7 @@ This script is used to generate a summary of a meeting notes transcript.
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from textwrap import dedent
|
||||
from typing import Type, TypeVar
|
||||
@@ -474,7 +474,7 @@ if __name__ == "__main__":
|
||||
|
||||
if args.save:
|
||||
# write the summary to a file, on the format summary-<iso date>.md
|
||||
filename = f"summary-{datetime.now().isoformat()}.md"
|
||||
filename = f"summary-{datetime.now(timezone.utc).isoformat()}.md"
|
||||
with open(filename, "w", encoding="utf-8") as f:
|
||||
f.write(sm.as_markdown())
|
||||
|
||||
|
||||
@@ -14,7 +14,9 @@ class Settings(BaseSettings):
|
||||
CORS_ALLOW_CREDENTIALS: bool = False
|
||||
|
||||
# Database
|
||||
DATABASE_URL: str = "sqlite:///./reflector.sqlite3"
|
||||
DATABASE_URL: str = (
|
||||
"postgresql+asyncpg://reflector:reflector@localhost:5432/reflector"
|
||||
)
|
||||
|
||||
# local data directory
|
||||
DATA_DIR: str = "./data"
|
||||
|
||||
@@ -9,8 +9,9 @@ async def export_db(filename: str) -> None:
|
||||
filename = pathlib.Path(filename).resolve()
|
||||
settings.DATABASE_URL = f"sqlite:///{filename}"
|
||||
|
||||
from reflector.db import database, transcripts
|
||||
from reflector.db import get_database, transcripts
|
||||
|
||||
database = get_database()
|
||||
await database.connect()
|
||||
transcripts = await database.fetch_all(transcripts.select())
|
||||
await database.disconnect()
|
||||
|
||||
@@ -8,8 +8,9 @@ async def export_db(filename: str) -> None:
|
||||
filename = pathlib.Path(filename).resolve()
|
||||
settings.DATABASE_URL = f"sqlite:///{filename}"
|
||||
|
||||
from reflector.db import database, transcripts
|
||||
from reflector.db import get_database, transcripts
|
||||
|
||||
database = get_database()
|
||||
await database.connect()
|
||||
transcripts = await database.fetch_all(transcripts.select())
|
||||
await database.disconnect()
|
||||
|
||||
@@ -155,7 +155,7 @@ async def process_audio_file_with_diarization(
|
||||
|
||||
# For Modal backend, we need to upload the file to S3 first
|
||||
if diarization_backend == "modal":
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from reflector.storage import get_transcripts_storage
|
||||
from reflector.utils.s3_temp_file import S3TemporaryFile
|
||||
@@ -163,7 +163,7 @@ async def process_audio_file_with_diarization(
|
||||
storage = get_transcripts_storage()
|
||||
|
||||
# Generate a unique filename in evaluation folder
|
||||
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
|
||||
audio_filename = f"evaluation/diarization_temp/{timestamp}_{uuid.uuid4().hex}.wav"
|
||||
|
||||
# Use context manager for automatic cleanup
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
@@ -35,7 +35,7 @@ async def meeting_audio_consent(
|
||||
meeting_id=meeting_id,
|
||||
user_id=user_id,
|
||||
consent_given=request.consent_given,
|
||||
consent_timestamp=datetime.utcnow(),
|
||||
consent_timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
updated_consent = await meeting_consent_controller.upsert(consent)
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
import logging
|
||||
import sqlite3
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Annotated, Literal, Optional
|
||||
|
||||
import asyncpg.exceptions
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi_pagination import Page
|
||||
from fastapi_pagination.ext.databases import paginate
|
||||
from fastapi_pagination.ext.databases import apaginate
|
||||
from pydantic import BaseModel
|
||||
|
||||
import reflector.auth as auth
|
||||
from reflector.db import database
|
||||
from reflector.db import get_database
|
||||
from reflector.db.meetings import meetings_controller
|
||||
from reflector.db.rooms import rooms_controller
|
||||
from reflector.settings import settings
|
||||
@@ -21,6 +21,14 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def parse_datetime_with_timezone(iso_string: str) -> datetime:
|
||||
"""Parse ISO datetime string and ensure timezone awareness (defaults to UTC if naive)."""
|
||||
dt = datetime.fromisoformat(iso_string)
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=timezone.utc)
|
||||
return dt
|
||||
|
||||
|
||||
class Room(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
@@ -83,8 +91,8 @@ async def rooms_list(
|
||||
|
||||
user_id = user["sub"] if user else None
|
||||
|
||||
return await paginate(
|
||||
database,
|
||||
return await apaginate(
|
||||
get_database(),
|
||||
await rooms_controller.get_all(
|
||||
user_id=user_id, order_by="-created_at", return_query=True
|
||||
),
|
||||
@@ -150,7 +158,7 @@ async def rooms_create_meeting(
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
current_time = datetime.utcnow()
|
||||
current_time = datetime.now(timezone.utc)
|
||||
meeting = await meetings_controller.get_active(room=room, current_time=current_time)
|
||||
|
||||
if meeting is None:
|
||||
@@ -166,8 +174,8 @@ async def rooms_create_meeting(
|
||||
room_name=whereby_meeting["roomName"],
|
||||
room_url=whereby_meeting["roomUrl"],
|
||||
host_room_url=whereby_meeting["hostRoomUrl"],
|
||||
start_date=datetime.fromisoformat(whereby_meeting["startDate"]),
|
||||
end_date=datetime.fromisoformat(whereby_meeting["endDate"]),
|
||||
start_date=parse_datetime_with_timezone(whereby_meeting["startDate"]),
|
||||
end_date=parse_datetime_with_timezone(whereby_meeting["endDate"]),
|
||||
user_id=user_id,
|
||||
room=room,
|
||||
)
|
||||
|
||||
@@ -3,12 +3,12 @@ from typing import Annotated, Literal, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi_pagination import Page
|
||||
from fastapi_pagination.ext.databases import paginate
|
||||
from fastapi_pagination.ext.databases import apaginate
|
||||
from jose import jwt
|
||||
from pydantic import BaseModel, Field, field_serializer
|
||||
|
||||
import reflector.auth as auth
|
||||
from reflector.db import database
|
||||
from reflector.db import get_database
|
||||
from reflector.db.meetings import meetings_controller
|
||||
from reflector.db.rooms import rooms_controller
|
||||
from reflector.db.search import (
|
||||
@@ -48,7 +48,7 @@ DOWNLOAD_EXPIRE_MINUTES = 60
|
||||
|
||||
def create_access_token(data: dict, expires_delta: timedelta):
|
||||
to_encode = data.copy()
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
@@ -141,8 +141,8 @@ async def transcripts_list(
|
||||
|
||||
user_id = user["sub"] if user else None
|
||||
|
||||
return await paginate(
|
||||
database,
|
||||
return await apaginate(
|
||||
get_database(),
|
||||
await transcripts_controller.get_all(
|
||||
user_id=user_id,
|
||||
source_kind=SourceKind(source_kind) if source_kind else None,
|
||||
|
||||
@@ -21,6 +21,14 @@ from reflector.whereby import get_room_sessions
|
||||
logger = structlog.wrap_logger(get_task_logger(__name__))
|
||||
|
||||
|
||||
def parse_datetime_with_timezone(iso_string: str) -> datetime:
|
||||
"""Parse ISO datetime string and ensure timezone awareness (defaults to UTC if naive)."""
|
||||
dt = datetime.fromisoformat(iso_string)
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=timezone.utc)
|
||||
return dt
|
||||
|
||||
|
||||
@shared_task
|
||||
def process_messages():
|
||||
queue_url = settings.AWS_PROCESS_RECORDING_QUEUE_URL
|
||||
@@ -69,7 +77,7 @@ async def process_recording(bucket_name: str, object_key: str):
|
||||
|
||||
# extract a guid and a datetime from the object key
|
||||
room_name = f"/{object_key[:36]}"
|
||||
recorded_at = datetime.fromisoformat(object_key[37:57])
|
||||
recorded_at = parse_datetime_with_timezone(object_key[37:57])
|
||||
|
||||
meeting = await meetings_controller.get_by_room_name(room_name)
|
||||
room = await rooms_controller.get_by_id(meeting.room_id)
|
||||
|
||||
@@ -62,6 +62,7 @@ class RedisPubSubManager:
|
||||
class WebsocketManager:
|
||||
def __init__(self, pubsub_client: RedisPubSubManager = None):
|
||||
self.rooms: dict = {}
|
||||
self.tasks: dict = {}
|
||||
self.pubsub_client = pubsub_client
|
||||
|
||||
async def add_user_to_room(self, room_id: str, websocket: WebSocket) -> None:
|
||||
@@ -74,13 +75,17 @@ class WebsocketManager:
|
||||
|
||||
await self.pubsub_client.connect()
|
||||
pubsub_subscriber = await self.pubsub_client.subscribe(room_id)
|
||||
asyncio.create_task(self._pubsub_data_reader(pubsub_subscriber))
|
||||
task = asyncio.create_task(self._pubsub_data_reader(pubsub_subscriber))
|
||||
self.tasks[id(websocket)] = task
|
||||
|
||||
async def send_json(self, room_id: str, message: dict) -> None:
|
||||
await self.pubsub_client.send_json(room_id, message)
|
||||
|
||||
async def remove_user_from_room(self, room_id: str, websocket: WebSocket) -> None:
|
||||
self.rooms[room_id].remove(websocket)
|
||||
task = self.tasks.pop(id(websocket), None)
|
||||
if task:
|
||||
task.cancel()
|
||||
|
||||
if len(self.rooms[room_id]) == 0:
|
||||
del self.rooms[room_id]
|
||||
|
||||
Reference in New Issue
Block a user