mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-21 04:39:06 +00:00
feat: postgresql migration and removal of sqlite in pytest (#546)
* feat: remove support of sqlite, 100% postgres * fix: more migration and make datetime timezone aware in postgres * fix: change how database is get, and use contextvar to have difference instance between different loops * test: properly use client fixture that handle lifetime/database connection * fix: add missing client fixture parameters to test functions This commit fixes NameError issues where test functions were trying to use the 'client' fixture but didn't have it as a parameter. The changes include: 1. Added 'client' parameter to test functions in: - test_transcripts_audio_download.py (6 functions including fixture) - test_transcripts_speaker.py (3 functions) - test_transcripts_upload.py (1 function) - test_transcripts_rtc_ws.py (2 functions + appserver fixture) 2. Resolved naming conflicts in test_transcripts_rtc_ws.py where both HTTP client and StreamClient were using variable name 'client'. StreamClient instances are now named 'stream_client' to avoid conflicts. 3. Added missing 'from reflector.app import app' import in rtc_ws tests. Background: Previously implemented contextvars solution with get_database() function resolves asyncio event loop conflicts in Celery tasks. The global client fixture was also created to replace manual AsyncClient instances, ensuring proper FastAPI application lifecycle management and database connections during tests. All tests now pass except for 2 pre-existing RTC WebSocket test failures related to asyncpg connection issues unrelated to these fixes. * fix: ensure task are correctly closed * fix: make separate event loop for the live server * fix: make default settings pointing at postgres * build: remove pytest-docker deps out of dev, just tests group
This commit is contained in:
@@ -15,7 +15,7 @@ from sqlalchemy import Enum
|
||||
from sqlalchemy.dialects.postgresql import TSVECTOR
|
||||
from sqlalchemy.sql import false, or_
|
||||
|
||||
from reflector.db import database, metadata
|
||||
from reflector.db import get_database, metadata
|
||||
from reflector.db.rooms import rooms
|
||||
from reflector.db.utils import is_postgresql
|
||||
from reflector.processors.types import Word as ProcessorWord
|
||||
@@ -41,7 +41,7 @@ transcripts = sqlalchemy.Table(
|
||||
sqlalchemy.Column("status", sqlalchemy.String),
|
||||
sqlalchemy.Column("locked", sqlalchemy.Boolean),
|
||||
sqlalchemy.Column("duration", sqlalchemy.Float),
|
||||
sqlalchemy.Column("created_at", sqlalchemy.DateTime),
|
||||
sqlalchemy.Column("created_at", sqlalchemy.DateTime(timezone=True)),
|
||||
sqlalchemy.Column("title", sqlalchemy.String),
|
||||
sqlalchemy.Column("short_summary", sqlalchemy.String),
|
||||
sqlalchemy.Column("long_summary", sqlalchemy.String),
|
||||
@@ -421,7 +421,7 @@ class TranscriptController:
|
||||
if return_query:
|
||||
return query
|
||||
|
||||
results = await database.fetch_all(query)
|
||||
results = await get_database().fetch_all(query)
|
||||
return results
|
||||
|
||||
async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
|
||||
@@ -431,7 +431,7 @@ class TranscriptController:
|
||||
query = transcripts.select().where(transcripts.c.id == transcript_id)
|
||||
if "user_id" in kwargs:
|
||||
query = query.where(transcripts.c.user_id == kwargs["user_id"])
|
||||
result = await database.fetch_one(query)
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
return None
|
||||
return Transcript(**result)
|
||||
@@ -445,7 +445,7 @@ class TranscriptController:
|
||||
query = transcripts.select().where(transcripts.c.recording_id == recording_id)
|
||||
if "user_id" in kwargs:
|
||||
query = query.where(transcripts.c.user_id == kwargs["user_id"])
|
||||
result = await database.fetch_one(query)
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
return None
|
||||
return Transcript(**result)
|
||||
@@ -463,7 +463,7 @@ class TranscriptController:
|
||||
if order_by.startswith("-"):
|
||||
field = field.desc()
|
||||
query = query.order_by(field)
|
||||
results = await database.fetch_all(query)
|
||||
results = await get_database().fetch_all(query)
|
||||
return [Transcript(**result) for result in results]
|
||||
|
||||
async def get_by_id_for_http(
|
||||
@@ -481,7 +481,7 @@ class TranscriptController:
|
||||
to determine if the user can access the transcript.
|
||||
"""
|
||||
query = transcripts.select().where(transcripts.c.id == transcript_id)
|
||||
result = await database.fetch_one(query)
|
||||
result = await get_database().fetch_one(query)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
|
||||
@@ -534,7 +534,7 @@ class TranscriptController:
|
||||
room_id=room_id,
|
||||
)
|
||||
query = transcripts.insert().values(**transcript.model_dump())
|
||||
await database.execute(query)
|
||||
await get_database().execute(query)
|
||||
return transcript
|
||||
|
||||
# TODO investigate why mutate= is used. it's used in one place currently, maybe because of ORM field updates.
|
||||
@@ -553,7 +553,7 @@ class TranscriptController:
|
||||
.where(transcripts.c.id == transcript.id)
|
||||
.values(**values)
|
||||
)
|
||||
await database.execute(query)
|
||||
await get_database().execute(query)
|
||||
if mutate:
|
||||
for key, value in values.items():
|
||||
setattr(transcript, key, value)
|
||||
@@ -595,21 +595,21 @@ class TranscriptController:
|
||||
return
|
||||
transcript.unlink()
|
||||
query = transcripts.delete().where(transcripts.c.id == transcript_id)
|
||||
await database.execute(query)
|
||||
await get_database().execute(query)
|
||||
|
||||
async def remove_by_recording_id(self, recording_id: str):
|
||||
"""
|
||||
Remove a transcript by recording_id
|
||||
"""
|
||||
query = transcripts.delete().where(transcripts.c.recording_id == recording_id)
|
||||
await database.execute(query)
|
||||
await get_database().execute(query)
|
||||
|
||||
@asynccontextmanager
|
||||
async def transaction(self):
|
||||
"""
|
||||
A context manager for database transaction
|
||||
"""
|
||||
async with database.transaction(isolation="serializable"):
|
||||
async with get_database().transaction(isolation="serializable"):
|
||||
yield
|
||||
|
||||
async def append_event(
|
||||
|
||||
Reference in New Issue
Block a user