feat: postgresql migration and removal of sqlite in pytest (#546)

* feat: remove support of sqlite, 100% postgres

* fix: more migration and make datetime timezone aware in postgres

* fix: change how database is get, and use contextvar to have difference instance between different loops

* test: properly use client fixture that handle lifetime/database connection

* fix: add missing client fixture parameters to test functions

This commit fixes NameError issues where test functions were trying to use
the 'client' fixture but didn't have it as a parameter. The changes include:

1. Added 'client' parameter to test functions in:
   - test_transcripts_audio_download.py (6 functions including fixture)
   - test_transcripts_speaker.py (3 functions)
   - test_transcripts_upload.py (1 function)
   - test_transcripts_rtc_ws.py (2 functions + appserver fixture)

2. Resolved naming conflicts in test_transcripts_rtc_ws.py where both HTTP
   client and StreamClient were using variable name 'client'. StreamClient
   instances are now named 'stream_client' to avoid conflicts.

3. Added missing 'from reflector.app import app' import in rtc_ws tests.

Background: Previously implemented contextvars solution with get_database()
function resolves asyncio event loop conflicts in Celery tasks. The global
client fixture was also created to replace manual AsyncClient instances,
ensuring proper FastAPI application lifecycle management and database
connections during tests.

All tests now pass except for 2 pre-existing RTC WebSocket test failures
related to asyncpg connection issues unrelated to these fixes.

* fix: ensure task are correctly closed

* fix: make separate event loop for the live server

* fix: make default settings pointing at postgres

* build: remove pytest-docker deps out of dev, just tests group
This commit is contained in:
2025-08-14 11:40:52 -06:00
committed by GitHub
parent 6fb5cb21c2
commit 9eab952c63
41 changed files with 2570 additions and 2287 deletions

View File

@@ -17,10 +17,40 @@ on:
jobs: jobs:
test-migrations: test-migrations:
runs-on: ubuntu-latest runs-on: ubuntu-latest
services:
postgres:
image: postgres:17
env:
POSTGRES_USER: reflector
POSTGRES_PASSWORD: reflector
POSTGRES_DB: reflector
ports:
- 5432:5432
options: >-
--health-cmd pg_isready -h 127.0.0.1 -p 5432
--health-interval 10s
--health-timeout 5s
--health-retries 5
env:
DATABASE_URL: postgresql://reflector:reflector@localhost:5432/reflector
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Install PostgreSQL client
run: sudo apt-get update && sudo apt-get install -y postgresql-client | cat
- name: Wait for Postgres
run: |
for i in {1..30}; do
if pg_isready -h localhost -p 5432; then
echo "Postgres is ready"
break
fi
echo "Waiting for Postgres... ($i)" && sleep 1
done
- name: Install uv - name: Install uv
uses: astral-sh/setup-uv@v3 uses: astral-sh/setup-uv@v3
with: with:

View File

@@ -32,7 +32,7 @@ def upgrade() -> None:
sa.Column("user_id", sa.String(), nullable=True), sa.Column("user_id", sa.String(), nullable=True),
sa.Column("room_id", sa.String(), nullable=True), sa.Column("room_id", sa.String(), nullable=True),
sa.Column( sa.Column(
"is_locked", sa.Boolean(), server_default=sa.text("0"), nullable=False "is_locked", sa.Boolean(), server_default=sa.text("false"), nullable=False
), ),
sa.Column("room_mode", sa.String(), server_default="normal", nullable=False), sa.Column("room_mode", sa.String(), server_default="normal", nullable=False),
sa.Column( sa.Column(
@@ -53,12 +53,15 @@ def upgrade() -> None:
sa.Column("user_id", sa.String(), nullable=False), sa.Column("user_id", sa.String(), nullable=False),
sa.Column("created_at", sa.DateTime(), nullable=False), sa.Column("created_at", sa.DateTime(), nullable=False),
sa.Column( sa.Column(
"zulip_auto_post", sa.Boolean(), server_default=sa.text("0"), nullable=False "zulip_auto_post",
sa.Boolean(),
server_default=sa.text("false"),
nullable=False,
), ),
sa.Column("zulip_stream", sa.String(), nullable=True), sa.Column("zulip_stream", sa.String(), nullable=True),
sa.Column("zulip_topic", sa.String(), nullable=True), sa.Column("zulip_topic", sa.String(), nullable=True),
sa.Column( sa.Column(
"is_locked", sa.Boolean(), server_default=sa.text("0"), nullable=False "is_locked", sa.Boolean(), server_default=sa.text("false"), nullable=False
), ),
sa.Column("room_mode", sa.String(), server_default="normal", nullable=False), sa.Column("room_mode", sa.String(), server_default="normal", nullable=False),
sa.Column( sa.Column(

View File

@@ -20,11 +20,14 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
sourcekind_enum = sa.Enum("room", "live", "file", name="sourcekind")
sourcekind_enum.create(op.get_bind())
op.add_column( op.add_column(
"transcript", "transcript",
sa.Column( sa.Column(
"source_kind", "source_kind",
sa.Enum("ROOM", "LIVE", "FILE", name="sourcekind"), sourcekind_enum,
nullable=True, nullable=True,
), ),
) )
@@ -43,6 +46,8 @@ def upgrade() -> None:
def downgrade() -> None: def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.drop_column("transcript", "source_kind") op.drop_column("transcript", "source_kind")
sourcekind_enum = sa.Enum(name="sourcekind")
sourcekind_enum.drop(op.get_bind())
# ### end Alembic commands ### # ### end Alembic commands ###

View File

@@ -22,7 +22,7 @@ def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.execute( op.execute(
"UPDATE transcript SET events = " "UPDATE transcript SET events = "
'REPLACE(events, \'"event": "SUMMARY"\', \'"event": "LONG_SUMMARY"\');' 'REPLACE(events::text, \'"event": "SUMMARY"\', \'"event": "LONG_SUMMARY"\')::json;'
) )
op.alter_column("transcript", "summary", new_column_name="long_summary") op.alter_column("transcript", "summary", new_column_name="long_summary")
op.add_column("transcript", sa.Column("title", sa.String(), nullable=True)) op.add_column("transcript", sa.Column("title", sa.String(), nullable=True))
@@ -34,7 +34,7 @@ def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.execute( op.execute(
"UPDATE transcript SET events = " "UPDATE transcript SET events = "
'REPLACE(events, \'"event": "LONG_SUMMARY"\', \'"event": "SUMMARY"\');' 'REPLACE(events::text, \'"event": "LONG_SUMMARY"\', \'"event": "SUMMARY"\')::json;'
) )
with op.batch_alter_table("transcript", schema=None) as batch_op: with op.batch_alter_table("transcript", schema=None) as batch_op:
batch_op.alter_column("long_summary", nullable=True, new_column_name="summary") batch_op.alter_column("long_summary", nullable=True, new_column_name="summary")

View File

@@ -0,0 +1,121 @@
"""datetime timezone
Revision ID: 9f5c78d352d6
Revises: 8120ebc75366
Create Date: 2025-08-13 19:18:27.113593
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "9f5c78d352d6"
down_revision: Union[str, None] = "8120ebc75366"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("meeting", schema=None) as batch_op:
batch_op.alter_column(
"start_date",
existing_type=postgresql.TIMESTAMP(),
type_=sa.DateTime(timezone=True),
existing_nullable=True,
)
batch_op.alter_column(
"end_date",
existing_type=postgresql.TIMESTAMP(),
type_=sa.DateTime(timezone=True),
existing_nullable=True,
)
with op.batch_alter_table("meeting_consent", schema=None) as batch_op:
batch_op.alter_column(
"consent_timestamp",
existing_type=postgresql.TIMESTAMP(),
type_=sa.DateTime(timezone=True),
existing_nullable=False,
)
with op.batch_alter_table("recording", schema=None) as batch_op:
batch_op.alter_column(
"recorded_at",
existing_type=postgresql.TIMESTAMP(),
type_=sa.DateTime(timezone=True),
existing_nullable=False,
)
with op.batch_alter_table("room", schema=None) as batch_op:
batch_op.alter_column(
"created_at",
existing_type=postgresql.TIMESTAMP(),
type_=sa.DateTime(timezone=True),
existing_nullable=False,
)
with op.batch_alter_table("transcript", schema=None) as batch_op:
batch_op.alter_column(
"created_at",
existing_type=postgresql.TIMESTAMP(),
type_=sa.DateTime(timezone=True),
existing_nullable=True,
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("transcript", schema=None) as batch_op:
batch_op.alter_column(
"created_at",
existing_type=sa.DateTime(timezone=True),
type_=postgresql.TIMESTAMP(),
existing_nullable=True,
)
with op.batch_alter_table("room", schema=None) as batch_op:
batch_op.alter_column(
"created_at",
existing_type=sa.DateTime(timezone=True),
type_=postgresql.TIMESTAMP(),
existing_nullable=False,
)
with op.batch_alter_table("recording", schema=None) as batch_op:
batch_op.alter_column(
"recorded_at",
existing_type=sa.DateTime(timezone=True),
type_=postgresql.TIMESTAMP(),
existing_nullable=False,
)
with op.batch_alter_table("meeting_consent", schema=None) as batch_op:
batch_op.alter_column(
"consent_timestamp",
existing_type=sa.DateTime(timezone=True),
type_=postgresql.TIMESTAMP(),
existing_nullable=False,
)
with op.batch_alter_table("meeting", schema=None) as batch_op:
batch_op.alter_column(
"end_date",
existing_type=sa.DateTime(timezone=True),
type_=postgresql.TIMESTAMP(),
existing_nullable=True,
)
batch_op.alter_column(
"start_date",
existing_type=sa.DateTime(timezone=True),
type_=postgresql.TIMESTAMP(),
existing_nullable=True,
)
# ### end Alembic commands ###

View File

@@ -25,7 +25,7 @@ def upgrade() -> None:
sa.Column( sa.Column(
"is_shared", "is_shared",
sa.Boolean(), sa.Boolean(),
server_default=sa.text("0"), server_default=sa.text("false"),
nullable=False, nullable=False,
), ),
) )

View File

@@ -23,7 +23,10 @@ def upgrade() -> None:
with op.batch_alter_table("meeting", schema=None) as batch_op: with op.batch_alter_table("meeting", schema=None) as batch_op:
batch_op.add_column( batch_op.add_column(
sa.Column( sa.Column(
"is_active", sa.Boolean(), server_default=sa.text("1"), nullable=False "is_active",
sa.Boolean(),
server_default=sa.text("true"),
nullable=False,
) )
) )

View File

@@ -23,7 +23,7 @@ def upgrade() -> None:
op.add_column( op.add_column(
"transcript", "transcript",
sa.Column( sa.Column(
"reviewed", sa.Boolean(), server_default=sa.text("0"), nullable=False "reviewed", sa.Boolean(), server_default=sa.text("false"), nullable=False
), ),
) )
# ### end Alembic commands ### # ### end Alembic commands ###

View File

@@ -57,6 +57,8 @@ tests = [
"httpx-ws>=0.4.1", "httpx-ws>=0.4.1",
"pytest-httpx>=0.23.1", "pytest-httpx>=0.23.1",
"pytest-celery>=0.0.0", "pytest-celery>=0.0.0",
"pytest-docker>=3.2.3",
"asgi-lifespan>=2.1.0",
] ]
aws = ["aioboto3>=11.2.0"] aws = ["aioboto3>=11.2.0"]
evaluation = [ evaluation = [
@@ -86,7 +88,7 @@ source = ["reflector"]
[tool.pytest_env] [tool.pytest_env]
ENVIRONMENT = "pytest" ENVIRONMENT = "pytest"
DATABASE_URL = "sqlite:///test.sqlite" DATABASE_URL = "postgresql://test_user:test_password@localhost:15432/reflector_test"
[tool.pytest.ini_options] [tool.pytest.ini_options]
addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v" addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v"

View File

@@ -1,12 +1,28 @@
import contextvars
from typing import Optional
import databases import databases
import sqlalchemy import sqlalchemy
from reflector.events import subscribers_shutdown, subscribers_startup from reflector.events import subscribers_shutdown, subscribers_startup
from reflector.settings import settings from reflector.settings import settings
database = databases.Database(settings.DATABASE_URL)
metadata = sqlalchemy.MetaData() metadata = sqlalchemy.MetaData()
_database_context: contextvars.ContextVar[Optional[databases.Database]] = (
contextvars.ContextVar("database", default=None)
)
def get_database() -> databases.Database:
"""Get database instance for current asyncio context"""
db = _database_context.get()
if db is None:
db = databases.Database(settings.DATABASE_URL)
_database_context.set(db)
return db
# import models # import models
import reflector.db.meetings # noqa import reflector.db.meetings # noqa
import reflector.db.recordings # noqa import reflector.db.recordings # noqa
@@ -14,16 +30,18 @@ import reflector.db.rooms # noqa
import reflector.db.transcripts # noqa import reflector.db.transcripts # noqa
kwargs = {} kwargs = {}
if "sqlite" in settings.DATABASE_URL: if "postgres" not in settings.DATABASE_URL:
kwargs["connect_args"] = {"check_same_thread": False} raise Exception("Only postgres database is supported in reflector")
engine = sqlalchemy.create_engine(settings.DATABASE_URL, **kwargs) engine = sqlalchemy.create_engine(settings.DATABASE_URL, **kwargs)
@subscribers_startup.append @subscribers_startup.append
async def database_connect(_): async def database_connect(_):
database = get_database()
await database.connect() await database.connect()
@subscribers_shutdown.append @subscribers_shutdown.append
async def database_disconnect(_): async def database_disconnect(_):
database = get_database()
await database.disconnect() await database.disconnect()

View File

@@ -5,7 +5,7 @@ import sqlalchemy as sa
from fastapi import HTTPException from fastapi import HTTPException
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from reflector.db import database, metadata from reflector.db import get_database, metadata
from reflector.db.rooms import Room from reflector.db.rooms import Room
from reflector.utils import generate_uuid4 from reflector.utils import generate_uuid4
@@ -16,8 +16,8 @@ meetings = sa.Table(
sa.Column("room_name", sa.String), sa.Column("room_name", sa.String),
sa.Column("room_url", sa.String), sa.Column("room_url", sa.String),
sa.Column("host_room_url", sa.String), sa.Column("host_room_url", sa.String),
sa.Column("start_date", sa.DateTime), sa.Column("start_date", sa.DateTime(timezone=True)),
sa.Column("end_date", sa.DateTime), sa.Column("end_date", sa.DateTime(timezone=True)),
sa.Column("user_id", sa.String), sa.Column("user_id", sa.String),
sa.Column("room_id", sa.String), sa.Column("room_id", sa.String),
sa.Column("is_locked", sa.Boolean, nullable=False, server_default=sa.false()), sa.Column("is_locked", sa.Boolean, nullable=False, server_default=sa.false()),
@@ -42,6 +42,12 @@ meetings = sa.Table(
server_default=sa.true(), server_default=sa.true(),
), ),
sa.Index("idx_meeting_room_id", "room_id"), sa.Index("idx_meeting_room_id", "room_id"),
sa.Index(
"idx_one_active_meeting_per_room",
"room_id",
unique=True,
postgresql_where=sa.text("is_active = true"),
),
) )
meeting_consent = sa.Table( meeting_consent = sa.Table(
@@ -51,7 +57,7 @@ meeting_consent = sa.Table(
sa.Column("meeting_id", sa.String, sa.ForeignKey("meeting.id"), nullable=False), sa.Column("meeting_id", sa.String, sa.ForeignKey("meeting.id"), nullable=False),
sa.Column("user_id", sa.String), sa.Column("user_id", sa.String),
sa.Column("consent_given", sa.Boolean, nullable=False), sa.Column("consent_given", sa.Boolean, nullable=False),
sa.Column("consent_timestamp", sa.DateTime, nullable=False), sa.Column("consent_timestamp", sa.DateTime(timezone=True), nullable=False),
) )
@@ -111,7 +117,7 @@ class MeetingController:
recording_trigger=room.recording_trigger, recording_trigger=room.recording_trigger,
) )
query = meetings.insert().values(**meeting.model_dump()) query = meetings.insert().values(**meeting.model_dump())
await database.execute(query) await get_database().execute(query)
return meeting return meeting
async def get_all_active(self) -> list[Meeting]: async def get_all_active(self) -> list[Meeting]:
@@ -119,7 +125,7 @@ class MeetingController:
Get active meetings. Get active meetings.
""" """
query = meetings.select().where(meetings.c.is_active) query = meetings.select().where(meetings.c.is_active)
return await database.fetch_all(query) return await get_database().fetch_all(query)
async def get_by_room_name( async def get_by_room_name(
self, self,
@@ -129,7 +135,7 @@ class MeetingController:
Get a meeting by room name. Get a meeting by room name.
""" """
query = meetings.select().where(meetings.c.room_name == room_name) query = meetings.select().where(meetings.c.room_name == room_name)
result = await database.fetch_one(query) result = await get_database().fetch_one(query)
if not result: if not result:
return None return None
@@ -151,7 +157,7 @@ class MeetingController:
) )
.order_by(end_date.desc()) .order_by(end_date.desc())
) )
result = await database.fetch_one(query) result = await get_database().fetch_one(query)
if not result: if not result:
return None return None
@@ -162,7 +168,7 @@ class MeetingController:
Get a meeting by id Get a meeting by id
""" """
query = meetings.select().where(meetings.c.id == meeting_id) query = meetings.select().where(meetings.c.id == meeting_id)
result = await database.fetch_one(query) result = await get_database().fetch_one(query)
if not result: if not result:
return None return None
return Meeting(**result) return Meeting(**result)
@@ -174,7 +180,7 @@ class MeetingController:
If not found, it will raise a 404 error. If not found, it will raise a 404 error.
""" """
query = meetings.select().where(meetings.c.id == meeting_id) query = meetings.select().where(meetings.c.id == meeting_id)
result = await database.fetch_one(query) result = await get_database().fetch_one(query)
if not result: if not result:
raise HTTPException(status_code=404, detail="Meeting not found") raise HTTPException(status_code=404, detail="Meeting not found")
@@ -186,7 +192,7 @@ class MeetingController:
async def update_meeting(self, meeting_id: str, **kwargs): async def update_meeting(self, meeting_id: str, **kwargs):
query = meetings.update().where(meetings.c.id == meeting_id).values(**kwargs) query = meetings.update().where(meetings.c.id == meeting_id).values(**kwargs)
await database.execute(query) await get_database().execute(query)
class MeetingConsentController: class MeetingConsentController:
@@ -194,7 +200,7 @@ class MeetingConsentController:
query = meeting_consent.select().where( query = meeting_consent.select().where(
meeting_consent.c.meeting_id == meeting_id meeting_consent.c.meeting_id == meeting_id
) )
results = await database.fetch_all(query) results = await get_database().fetch_all(query)
return [MeetingConsent(**result) for result in results] return [MeetingConsent(**result) for result in results]
async def get_by_meeting_and_user( async def get_by_meeting_and_user(
@@ -205,7 +211,7 @@ class MeetingConsentController:
meeting_consent.c.meeting_id == meeting_id, meeting_consent.c.meeting_id == meeting_id,
meeting_consent.c.user_id == user_id, meeting_consent.c.user_id == user_id,
) )
result = await database.fetch_one(query) result = await get_database().fetch_one(query)
if result is None: if result is None:
return None return None
return MeetingConsent(**result) if result else None return MeetingConsent(**result) if result else None
@@ -227,14 +233,14 @@ class MeetingConsentController:
consent_timestamp=consent.consent_timestamp, consent_timestamp=consent.consent_timestamp,
) )
) )
await database.execute(query) await get_database().execute(query)
existing.consent_given = consent.consent_given existing.consent_given = consent.consent_given
existing.consent_timestamp = consent.consent_timestamp existing.consent_timestamp = consent.consent_timestamp
return existing return existing
query = meeting_consent.insert().values(**consent.model_dump()) query = meeting_consent.insert().values(**consent.model_dump())
await database.execute(query) await get_database().execute(query)
return consent return consent
async def has_any_denial(self, meeting_id: str) -> bool: async def has_any_denial(self, meeting_id: str) -> bool:
@@ -243,7 +249,7 @@ class MeetingConsentController:
meeting_consent.c.meeting_id == meeting_id, meeting_consent.c.meeting_id == meeting_id,
meeting_consent.c.consent_given.is_(False), meeting_consent.c.consent_given.is_(False),
) )
result = await database.fetch_one(query) result = await get_database().fetch_one(query)
return result is not None return result is not None

View File

@@ -4,7 +4,7 @@ from typing import Literal
import sqlalchemy as sa import sqlalchemy as sa
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from reflector.db import database, metadata from reflector.db import get_database, metadata
from reflector.utils import generate_uuid4 from reflector.utils import generate_uuid4
recordings = sa.Table( recordings = sa.Table(
@@ -13,7 +13,7 @@ recordings = sa.Table(
sa.Column("id", sa.String, primary_key=True), sa.Column("id", sa.String, primary_key=True),
sa.Column("bucket_name", sa.String, nullable=False), sa.Column("bucket_name", sa.String, nullable=False),
sa.Column("object_key", sa.String, nullable=False), sa.Column("object_key", sa.String, nullable=False),
sa.Column("recorded_at", sa.DateTime, nullable=False), sa.Column("recorded_at", sa.DateTime(timezone=True), nullable=False),
sa.Column( sa.Column(
"status", "status",
sa.String, sa.String,
@@ -37,12 +37,12 @@ class Recording(BaseModel):
class RecordingController: class RecordingController:
async def create(self, recording: Recording): async def create(self, recording: Recording):
query = recordings.insert().values(**recording.model_dump()) query = recordings.insert().values(**recording.model_dump())
await database.execute(query) await get_database().execute(query)
return recording return recording
async def get_by_id(self, id: str) -> Recording: async def get_by_id(self, id: str) -> Recording:
query = recordings.select().where(recordings.c.id == id) query = recordings.select().where(recordings.c.id == id)
result = await database.fetch_one(query) result = await get_database().fetch_one(query)
return Recording(**result) if result else None return Recording(**result) if result else None
async def get_by_object_key(self, bucket_name: str, object_key: str) -> Recording: async def get_by_object_key(self, bucket_name: str, object_key: str) -> Recording:
@@ -50,7 +50,7 @@ class RecordingController:
recordings.c.bucket_name == bucket_name, recordings.c.bucket_name == bucket_name,
recordings.c.object_key == object_key, recordings.c.object_key == object_key,
) )
result = await database.fetch_one(query) result = await get_database().fetch_one(query)
return Recording(**result) if result else None return Recording(**result) if result else None

View File

@@ -1,4 +1,4 @@
from datetime import datetime from datetime import datetime, timezone
from sqlite3 import IntegrityError from sqlite3 import IntegrityError
from typing import Literal from typing import Literal
@@ -7,7 +7,7 @@ from fastapi import HTTPException
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy.sql import false, or_ from sqlalchemy.sql import false, or_
from reflector.db import database, metadata from reflector.db import get_database, metadata
from reflector.utils import generate_uuid4 from reflector.utils import generate_uuid4
rooms = sqlalchemy.Table( rooms = sqlalchemy.Table(
@@ -16,7 +16,7 @@ rooms = sqlalchemy.Table(
sqlalchemy.Column("id", sqlalchemy.String, primary_key=True), sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
sqlalchemy.Column("name", sqlalchemy.String, nullable=False, unique=True), sqlalchemy.Column("name", sqlalchemy.String, nullable=False, unique=True),
sqlalchemy.Column("user_id", sqlalchemy.String, nullable=False), sqlalchemy.Column("user_id", sqlalchemy.String, nullable=False),
sqlalchemy.Column("created_at", sqlalchemy.DateTime, nullable=False), sqlalchemy.Column("created_at", sqlalchemy.DateTime(timezone=True), nullable=False),
sqlalchemy.Column( sqlalchemy.Column(
"zulip_auto_post", sqlalchemy.Boolean, nullable=False, server_default=false() "zulip_auto_post", sqlalchemy.Boolean, nullable=False, server_default=false()
), ),
@@ -48,7 +48,7 @@ class Room(BaseModel):
id: str = Field(default_factory=generate_uuid4) id: str = Field(default_factory=generate_uuid4)
name: str name: str
user_id: str user_id: str
created_at: datetime = Field(default_factory=datetime.utcnow) created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
zulip_auto_post: bool = False zulip_auto_post: bool = False
zulip_stream: str = "" zulip_stream: str = ""
zulip_topic: str = "" zulip_topic: str = ""
@@ -92,7 +92,7 @@ class RoomController:
if return_query: if return_query:
return query return query
results = await database.fetch_all(query) results = await get_database().fetch_all(query)
return results return results
async def add( async def add(
@@ -125,7 +125,7 @@ class RoomController:
) )
query = rooms.insert().values(**room.model_dump()) query = rooms.insert().values(**room.model_dump())
try: try:
await database.execute(query) await get_database().execute(query)
except IntegrityError: except IntegrityError:
raise HTTPException(status_code=400, detail="Room name is not unique") raise HTTPException(status_code=400, detail="Room name is not unique")
return room return room
@@ -136,7 +136,7 @@ class RoomController:
""" """
query = rooms.update().where(rooms.c.id == room.id).values(**values) query = rooms.update().where(rooms.c.id == room.id).values(**values)
try: try:
await database.execute(query) await get_database().execute(query)
except IntegrityError: except IntegrityError:
raise HTTPException(status_code=400, detail="Room name is not unique") raise HTTPException(status_code=400, detail="Room name is not unique")
@@ -151,7 +151,7 @@ class RoomController:
query = rooms.select().where(rooms.c.id == room_id) query = rooms.select().where(rooms.c.id == room_id)
if "user_id" in kwargs: if "user_id" in kwargs:
query = query.where(rooms.c.user_id == kwargs["user_id"]) query = query.where(rooms.c.user_id == kwargs["user_id"])
result = await database.fetch_one(query) result = await get_database().fetch_one(query)
if not result: if not result:
return None return None
return Room(**result) return Room(**result)
@@ -163,7 +163,7 @@ class RoomController:
query = rooms.select().where(rooms.c.name == room_name) query = rooms.select().where(rooms.c.name == room_name)
if "user_id" in kwargs: if "user_id" in kwargs:
query = query.where(rooms.c.user_id == kwargs["user_id"]) query = query.where(rooms.c.user_id == kwargs["user_id"])
result = await database.fetch_one(query) result = await get_database().fetch_one(query)
if not result: if not result:
return None return None
return Room(**result) return Room(**result)
@@ -175,7 +175,7 @@ class RoomController:
If not found, it will raise a 404 error. If not found, it will raise a 404 error.
""" """
query = rooms.select().where(rooms.c.id == meeting_id) query = rooms.select().where(rooms.c.id == meeting_id)
result = await database.fetch_one(query) result = await get_database().fetch_one(query)
if not result: if not result:
raise HTTPException(status_code=404, detail="Room not found") raise HTTPException(status_code=404, detail="Room not found")
@@ -197,7 +197,7 @@ class RoomController:
if user_id is not None and room.user_id != user_id: if user_id is not None and room.user_id != user_id:
return return
query = rooms.delete().where(rooms.c.id == room_id) query = rooms.delete().where(rooms.c.id == room_id)
await database.execute(query) await get_database().execute(query)
rooms_controller = RoomController() rooms_controller = RoomController()

View File

@@ -9,7 +9,7 @@ import sqlalchemy
import webvtt import webvtt
from pydantic import BaseModel, Field, constr, field_serializer from pydantic import BaseModel, Field, constr, field_serializer
from reflector.db import database from reflector.db import get_database
from reflector.db.transcripts import SourceKind, transcripts from reflector.db.transcripts import SourceKind, transcripts
from reflector.db.utils import is_postgresql from reflector.db.utils import is_postgresql
@@ -207,12 +207,12 @@ class SearchController:
.limit(params.limit) .limit(params.limit)
.offset(params.offset) .offset(params.offset)
) )
rs = await database.fetch_all(query) rs = await get_database().fetch_all(query)
count_query = sqlalchemy.select([sqlalchemy.func.count()]).select_from( count_query = sqlalchemy.select([sqlalchemy.func.count()]).select_from(
base_query.alias("search_results") base_query.alias("search_results")
) )
total = await database.fetch_val(count_query) total = await get_database().fetch_val(count_query)
def _process_result(r) -> SearchResult: def _process_result(r) -> SearchResult:
r_dict: Dict[str, Any] = dict(r) r_dict: Dict[str, Any] = dict(r)

View File

@@ -15,7 +15,7 @@ from sqlalchemy import Enum
from sqlalchemy.dialects.postgresql import TSVECTOR from sqlalchemy.dialects.postgresql import TSVECTOR
from sqlalchemy.sql import false, or_ from sqlalchemy.sql import false, or_
from reflector.db import database, metadata from reflector.db import get_database, metadata
from reflector.db.rooms import rooms from reflector.db.rooms import rooms
from reflector.db.utils import is_postgresql from reflector.db.utils import is_postgresql
from reflector.processors.types import Word as ProcessorWord from reflector.processors.types import Word as ProcessorWord
@@ -41,7 +41,7 @@ transcripts = sqlalchemy.Table(
sqlalchemy.Column("status", sqlalchemy.String), sqlalchemy.Column("status", sqlalchemy.String),
sqlalchemy.Column("locked", sqlalchemy.Boolean), sqlalchemy.Column("locked", sqlalchemy.Boolean),
sqlalchemy.Column("duration", sqlalchemy.Float), sqlalchemy.Column("duration", sqlalchemy.Float),
sqlalchemy.Column("created_at", sqlalchemy.DateTime), sqlalchemy.Column("created_at", sqlalchemy.DateTime(timezone=True)),
sqlalchemy.Column("title", sqlalchemy.String), sqlalchemy.Column("title", sqlalchemy.String),
sqlalchemy.Column("short_summary", sqlalchemy.String), sqlalchemy.Column("short_summary", sqlalchemy.String),
sqlalchemy.Column("long_summary", sqlalchemy.String), sqlalchemy.Column("long_summary", sqlalchemy.String),
@@ -421,7 +421,7 @@ class TranscriptController:
if return_query: if return_query:
return query return query
results = await database.fetch_all(query) results = await get_database().fetch_all(query)
return results return results
async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None: async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
@@ -431,7 +431,7 @@ class TranscriptController:
query = transcripts.select().where(transcripts.c.id == transcript_id) query = transcripts.select().where(transcripts.c.id == transcript_id)
if "user_id" in kwargs: if "user_id" in kwargs:
query = query.where(transcripts.c.user_id == kwargs["user_id"]) query = query.where(transcripts.c.user_id == kwargs["user_id"])
result = await database.fetch_one(query) result = await get_database().fetch_one(query)
if not result: if not result:
return None return None
return Transcript(**result) return Transcript(**result)
@@ -445,7 +445,7 @@ class TranscriptController:
query = transcripts.select().where(transcripts.c.recording_id == recording_id) query = transcripts.select().where(transcripts.c.recording_id == recording_id)
if "user_id" in kwargs: if "user_id" in kwargs:
query = query.where(transcripts.c.user_id == kwargs["user_id"]) query = query.where(transcripts.c.user_id == kwargs["user_id"])
result = await database.fetch_one(query) result = await get_database().fetch_one(query)
if not result: if not result:
return None return None
return Transcript(**result) return Transcript(**result)
@@ -463,7 +463,7 @@ class TranscriptController:
if order_by.startswith("-"): if order_by.startswith("-"):
field = field.desc() field = field.desc()
query = query.order_by(field) query = query.order_by(field)
results = await database.fetch_all(query) results = await get_database().fetch_all(query)
return [Transcript(**result) for result in results] return [Transcript(**result) for result in results]
async def get_by_id_for_http( async def get_by_id_for_http(
@@ -481,7 +481,7 @@ class TranscriptController:
to determine if the user can access the transcript. to determine if the user can access the transcript.
""" """
query = transcripts.select().where(transcripts.c.id == transcript_id) query = transcripts.select().where(transcripts.c.id == transcript_id)
result = await database.fetch_one(query) result = await get_database().fetch_one(query)
if not result: if not result:
raise HTTPException(status_code=404, detail="Transcript not found") raise HTTPException(status_code=404, detail="Transcript not found")
@@ -534,7 +534,7 @@ class TranscriptController:
room_id=room_id, room_id=room_id,
) )
query = transcripts.insert().values(**transcript.model_dump()) query = transcripts.insert().values(**transcript.model_dump())
await database.execute(query) await get_database().execute(query)
return transcript return transcript
# TODO investigate why mutate= is used. it's used in one place currently, maybe because of ORM field updates. # TODO investigate why mutate= is used. it's used in one place currently, maybe because of ORM field updates.
@@ -553,7 +553,7 @@ class TranscriptController:
.where(transcripts.c.id == transcript.id) .where(transcripts.c.id == transcript.id)
.values(**values) .values(**values)
) )
await database.execute(query) await get_database().execute(query)
if mutate: if mutate:
for key, value in values.items(): for key, value in values.items():
setattr(transcript, key, value) setattr(transcript, key, value)
@@ -595,21 +595,21 @@ class TranscriptController:
return return
transcript.unlink() transcript.unlink()
query = transcripts.delete().where(transcripts.c.id == transcript_id) query = transcripts.delete().where(transcripts.c.id == transcript_id)
await database.execute(query) await get_database().execute(query)
async def remove_by_recording_id(self, recording_id: str): async def remove_by_recording_id(self, recording_id: str):
""" """
Remove a transcript by recording_id Remove a transcript by recording_id
""" """
query = transcripts.delete().where(transcripts.c.recording_id == recording_id) query = transcripts.delete().where(transcripts.c.recording_id == recording_id)
await database.execute(query) await get_database().execute(query)
@asynccontextmanager @asynccontextmanager
async def transaction(self): async def transaction(self):
""" """
A context manager for database transaction A context manager for database transaction
""" """
async with database.transaction(isolation="serializable"): async with get_database().transaction(isolation="serializable"):
yield yield
async def append_event( async def append_event(

View File

@@ -1,7 +1,9 @@
"""Database utility functions.""" """Database utility functions."""
from reflector.db import database from reflector.db import get_database
def is_postgresql() -> bool: def is_postgresql() -> bool:
return database.url.scheme and database.url.scheme.startswith('postgresql') return get_database().url.scheme and get_database().url.scheme.startswith(
"postgresql"
)

View File

@@ -22,7 +22,7 @@ from celery import chord, current_task, group, shared_task
from pydantic import BaseModel from pydantic import BaseModel
from structlog import BoundLogger as Logger from structlog import BoundLogger as Logger
from reflector.db import database from reflector.db import get_database
from reflector.db.meetings import meeting_consent_controller, meetings_controller from reflector.db.meetings import meeting_consent_controller, meetings_controller
from reflector.db.recordings import recordings_controller from reflector.db.recordings import recordings_controller
from reflector.db.rooms import rooms_controller from reflector.db.rooms import rooms_controller
@@ -72,6 +72,7 @@ def asynctask(f):
@functools.wraps(f) @functools.wraps(f)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
async def run_with_db(): async def run_with_db():
database = get_database()
await database.connect() await database.connect()
try: try:
return await f(*args, **kwargs) return await f(*args, **kwargs)

View File

@@ -6,7 +6,7 @@ This script is used to generate a summary of a meeting notes transcript.
import asyncio import asyncio
import sys import sys
from datetime import datetime from datetime import datetime, timezone
from enum import Enum from enum import Enum
from textwrap import dedent from textwrap import dedent
from typing import Type, TypeVar from typing import Type, TypeVar
@@ -474,7 +474,7 @@ if __name__ == "__main__":
if args.save: if args.save:
# write the summary to a file, on the format summary-<iso date>.md # write the summary to a file, on the format summary-<iso date>.md
filename = f"summary-{datetime.now().isoformat()}.md" filename = f"summary-{datetime.now(timezone.utc).isoformat()}.md"
with open(filename, "w", encoding="utf-8") as f: with open(filename, "w", encoding="utf-8") as f:
f.write(sm.as_markdown()) f.write(sm.as_markdown())

View File

@@ -14,7 +14,9 @@ class Settings(BaseSettings):
CORS_ALLOW_CREDENTIALS: bool = False CORS_ALLOW_CREDENTIALS: bool = False
# Database # Database
DATABASE_URL: str = "sqlite:///./reflector.sqlite3" DATABASE_URL: str = (
"postgresql+asyncpg://reflector:reflector@localhost:5432/reflector"
)
# local data directory # local data directory
DATA_DIR: str = "./data" DATA_DIR: str = "./data"

View File

@@ -9,8 +9,9 @@ async def export_db(filename: str) -> None:
filename = pathlib.Path(filename).resolve() filename = pathlib.Path(filename).resolve()
settings.DATABASE_URL = f"sqlite:///{filename}" settings.DATABASE_URL = f"sqlite:///{filename}"
from reflector.db import database, transcripts from reflector.db import get_database, transcripts
database = get_database()
await database.connect() await database.connect()
transcripts = await database.fetch_all(transcripts.select()) transcripts = await database.fetch_all(transcripts.select())
await database.disconnect() await database.disconnect()

View File

@@ -8,8 +8,9 @@ async def export_db(filename: str) -> None:
filename = pathlib.Path(filename).resolve() filename = pathlib.Path(filename).resolve()
settings.DATABASE_URL = f"sqlite:///{filename}" settings.DATABASE_URL = f"sqlite:///{filename}"
from reflector.db import database, transcripts from reflector.db import get_database, transcripts
database = get_database()
await database.connect() await database.connect()
transcripts = await database.fetch_all(transcripts.select()) transcripts = await database.fetch_all(transcripts.select())
await database.disconnect() await database.disconnect()

View File

@@ -155,7 +155,7 @@ async def process_audio_file_with_diarization(
# For Modal backend, we need to upload the file to S3 first # For Modal backend, we need to upload the file to S3 first
if diarization_backend == "modal": if diarization_backend == "modal":
from datetime import datetime from datetime import datetime, timezone
from reflector.storage import get_transcripts_storage from reflector.storage import get_transcripts_storage
from reflector.utils.s3_temp_file import S3TemporaryFile from reflector.utils.s3_temp_file import S3TemporaryFile
@@ -163,7 +163,7 @@ async def process_audio_file_with_diarization(
storage = get_transcripts_storage() storage = get_transcripts_storage()
# Generate a unique filename in evaluation folder # Generate a unique filename in evaluation folder
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S") timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
audio_filename = f"evaluation/diarization_temp/{timestamp}_{uuid.uuid4().hex}.wav" audio_filename = f"evaluation/diarization_temp/{timestamp}_{uuid.uuid4().hex}.wav"
# Use context manager for automatic cleanup # Use context manager for automatic cleanup

View File

@@ -1,4 +1,4 @@
from datetime import datetime from datetime import datetime, timezone
from typing import Annotated, Optional from typing import Annotated, Optional
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, Request
@@ -35,7 +35,7 @@ async def meeting_audio_consent(
meeting_id=meeting_id, meeting_id=meeting_id,
user_id=user_id, user_id=user_id,
consent_given=request.consent_given, consent_given=request.consent_given,
consent_timestamp=datetime.utcnow(), consent_timestamp=datetime.now(timezone.utc),
) )
updated_consent = await meeting_consent_controller.upsert(consent) updated_consent = await meeting_consent_controller.upsert(consent)

View File

@@ -1,16 +1,16 @@
import logging import logging
import sqlite3 import sqlite3
from datetime import datetime, timedelta from datetime import datetime, timedelta, timezone
from typing import Annotated, Literal, Optional from typing import Annotated, Literal, Optional
import asyncpg.exceptions import asyncpg.exceptions
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from fastapi_pagination import Page from fastapi_pagination import Page
from fastapi_pagination.ext.databases import paginate from fastapi_pagination.ext.databases import apaginate
from pydantic import BaseModel from pydantic import BaseModel
import reflector.auth as auth import reflector.auth as auth
from reflector.db import database from reflector.db import get_database
from reflector.db.meetings import meetings_controller from reflector.db.meetings import meetings_controller
from reflector.db.rooms import rooms_controller from reflector.db.rooms import rooms_controller
from reflector.settings import settings from reflector.settings import settings
@@ -21,6 +21,14 @@ logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
def parse_datetime_with_timezone(iso_string: str) -> datetime:
"""Parse ISO datetime string and ensure timezone awareness (defaults to UTC if naive)."""
dt = datetime.fromisoformat(iso_string)
if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc)
return dt
class Room(BaseModel): class Room(BaseModel):
id: str id: str
name: str name: str
@@ -83,8 +91,8 @@ async def rooms_list(
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
return await paginate( return await apaginate(
database, get_database(),
await rooms_controller.get_all( await rooms_controller.get_all(
user_id=user_id, order_by="-created_at", return_query=True user_id=user_id, order_by="-created_at", return_query=True
), ),
@@ -150,7 +158,7 @@ async def rooms_create_meeting(
if not room: if not room:
raise HTTPException(status_code=404, detail="Room not found") raise HTTPException(status_code=404, detail="Room not found")
current_time = datetime.utcnow() current_time = datetime.now(timezone.utc)
meeting = await meetings_controller.get_active(room=room, current_time=current_time) meeting = await meetings_controller.get_active(room=room, current_time=current_time)
if meeting is None: if meeting is None:
@@ -166,8 +174,8 @@ async def rooms_create_meeting(
room_name=whereby_meeting["roomName"], room_name=whereby_meeting["roomName"],
room_url=whereby_meeting["roomUrl"], room_url=whereby_meeting["roomUrl"],
host_room_url=whereby_meeting["hostRoomUrl"], host_room_url=whereby_meeting["hostRoomUrl"],
start_date=datetime.fromisoformat(whereby_meeting["startDate"]), start_date=parse_datetime_with_timezone(whereby_meeting["startDate"]),
end_date=datetime.fromisoformat(whereby_meeting["endDate"]), end_date=parse_datetime_with_timezone(whereby_meeting["endDate"]),
user_id=user_id, user_id=user_id,
room=room, room=room,
) )

View File

@@ -3,12 +3,12 @@ from typing import Annotated, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi_pagination import Page from fastapi_pagination import Page
from fastapi_pagination.ext.databases import paginate from fastapi_pagination.ext.databases import apaginate
from jose import jwt from jose import jwt
from pydantic import BaseModel, Field, field_serializer from pydantic import BaseModel, Field, field_serializer
import reflector.auth as auth import reflector.auth as auth
from reflector.db import database from reflector.db import get_database
from reflector.db.meetings import meetings_controller from reflector.db.meetings import meetings_controller
from reflector.db.rooms import rooms_controller from reflector.db.rooms import rooms_controller
from reflector.db.search import ( from reflector.db.search import (
@@ -48,7 +48,7 @@ DOWNLOAD_EXPIRE_MINUTES = 60
def create_access_token(data: dict, expires_delta: timedelta): def create_access_token(data: dict, expires_delta: timedelta):
to_encode = data.copy() to_encode = data.copy()
expire = datetime.utcnow() + expires_delta expire = datetime.now(timezone.utc) + expires_delta
to_encode.update({"exp": expire}) to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM) encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt return encoded_jwt
@@ -141,8 +141,8 @@ async def transcripts_list(
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
return await paginate( return await apaginate(
database, get_database(),
await transcripts_controller.get_all( await transcripts_controller.get_all(
user_id=user_id, user_id=user_id,
source_kind=SourceKind(source_kind) if source_kind else None, source_kind=SourceKind(source_kind) if source_kind else None,

View File

@@ -21,6 +21,14 @@ from reflector.whereby import get_room_sessions
logger = structlog.wrap_logger(get_task_logger(__name__)) logger = structlog.wrap_logger(get_task_logger(__name__))
def parse_datetime_with_timezone(iso_string: str) -> datetime:
"""Parse ISO datetime string and ensure timezone awareness (defaults to UTC if naive)."""
dt = datetime.fromisoformat(iso_string)
if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc)
return dt
@shared_task @shared_task
def process_messages(): def process_messages():
queue_url = settings.AWS_PROCESS_RECORDING_QUEUE_URL queue_url = settings.AWS_PROCESS_RECORDING_QUEUE_URL
@@ -69,7 +77,7 @@ async def process_recording(bucket_name: str, object_key: str):
# extract a guid and a datetime from the object key # extract a guid and a datetime from the object key
room_name = f"/{object_key[:36]}" room_name = f"/{object_key[:36]}"
recorded_at = datetime.fromisoformat(object_key[37:57]) recorded_at = parse_datetime_with_timezone(object_key[37:57])
meeting = await meetings_controller.get_by_room_name(room_name) meeting = await meetings_controller.get_by_room_name(room_name)
room = await rooms_controller.get_by_id(meeting.room_id) room = await rooms_controller.get_by_id(meeting.room_id)

View File

@@ -62,6 +62,7 @@ class RedisPubSubManager:
class WebsocketManager: class WebsocketManager:
def __init__(self, pubsub_client: RedisPubSubManager = None): def __init__(self, pubsub_client: RedisPubSubManager = None):
self.rooms: dict = {} self.rooms: dict = {}
self.tasks: dict = {}
self.pubsub_client = pubsub_client self.pubsub_client = pubsub_client
async def add_user_to_room(self, room_id: str, websocket: WebSocket) -> None: async def add_user_to_room(self, room_id: str, websocket: WebSocket) -> None:
@@ -74,13 +75,17 @@ class WebsocketManager:
await self.pubsub_client.connect() await self.pubsub_client.connect()
pubsub_subscriber = await self.pubsub_client.subscribe(room_id) pubsub_subscriber = await self.pubsub_client.subscribe(room_id)
asyncio.create_task(self._pubsub_data_reader(pubsub_subscriber)) task = asyncio.create_task(self._pubsub_data_reader(pubsub_subscriber))
self.tasks[id(websocket)] = task
async def send_json(self, room_id: str, message: dict) -> None: async def send_json(self, room_id: str, message: dict) -> None:
await self.pubsub_client.send_json(room_id, message) await self.pubsub_client.send_json(room_id, message)
async def remove_user_from_room(self, room_id: str, websocket: WebSocket) -> None: async def remove_user_from_room(self, room_id: str, websocket: WebSocket) -> None:
self.rooms[room_id].remove(websocket) self.rooms[room_id].remove(websocket)
task = self.tasks.pop(id(websocket), None)
if task:
task.cancel()
if len(self.rooms[room_id]) == 0: if len(self.rooms[room_id]) == 0:
del self.rooms[room_id] del self.rooms[room_id]

View File

@@ -1,17 +1,63 @@
import os
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
# Pytest-docker configuration
@pytest.fixture(scope="session")
def docker_compose_file(pytestconfig):
return os.path.join(str(pytestconfig.rootdir), "tests", "docker-compose.test.yml")
@pytest.fixture(scope="session")
def postgres_service(docker_ip, docker_services):
"""Ensure that PostgreSQL service is up and responsive."""
port = docker_services.port_for("postgres_test", 5432)
def is_responsive():
try:
import psycopg2
conn = psycopg2.connect(
host=docker_ip,
port=port,
dbname="reflector_test",
user="test_user",
password="test_password",
)
conn.close()
return True
except Exception:
return False
docker_services.wait_until_responsive(timeout=30.0, pause=0.1, check=is_responsive)
# Return connection parameters
return {
"host": docker_ip,
"port": port,
"dbname": "reflector_test",
"user": "test_user",
"password": "test_password",
}
@pytest.fixture(scope="function", autouse=True) @pytest.fixture(scope="function", autouse=True)
@pytest.mark.asyncio @pytest.mark.asyncio
async def setup_database(): async def setup_database(postgres_service):
from reflector.db import engine, metadata # noqa from reflector.db import engine, metadata, get_database # noqa
metadata.drop_all(bind=engine) metadata.drop_all(bind=engine)
metadata.create_all(bind=engine) metadata.create_all(bind=engine)
database = get_database()
try:
await database.connect()
yield yield
finally:
await database.disconnect()
@pytest.fixture @pytest.fixture
@@ -46,6 +92,20 @@ def dummy_processors():
) # noqa ) # noqa
@pytest.fixture
async def whisper_transcript():
from reflector.processors.audio_transcript_whisper import (
AudioTranscriptWhisperProcessor,
)
with patch(
"reflector.processors.audio_transcript_auto"
".AudioTranscriptAutoProcessor.__new__"
) as mock_audio:
mock_audio.return_value = AudioTranscriptWhisperProcessor()
yield
@pytest.fixture @pytest.fixture
async def dummy_transcript(): async def dummy_transcript():
from reflector.processors.audio_transcript import AudioTranscriptProcessor from reflector.processors.audio_transcript import AudioTranscriptProcessor
@@ -181,6 +241,16 @@ def celery_includes():
return ["reflector.pipelines.main_live_pipeline"] return ["reflector.pipelines.main_live_pipeline"]
@pytest.fixture
async def client():
from httpx import AsyncClient
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
yield ac
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def fake_mp3_upload(): def fake_mp3_upload():
with patch( with patch(
@@ -191,13 +261,10 @@ def fake_mp3_upload():
@pytest.fixture @pytest.fixture
async def fake_transcript_with_topics(tmpdir): async def fake_transcript_with_topics(tmpdir, client):
import shutil import shutil
from pathlib import Path from pathlib import Path
from httpx import AsyncClient
from reflector.app import app
from reflector.db.transcripts import TranscriptTopic from reflector.db.transcripts import TranscriptTopic
from reflector.processors.types import Word from reflector.processors.types import Word
from reflector.settings import settings from reflector.settings import settings
@@ -206,8 +273,7 @@ async def fake_transcript_with_topics(tmpdir):
settings.DATA_DIR = Path(tmpdir) settings.DATA_DIR = Path(tmpdir)
# create a transcript # create a transcript
ac = AsyncClient(app=app, base_url="http://test/v1") response = await client.post("/transcripts", json={"name": "Test audio download"})
response = await ac.post("/transcripts", json={"name": "Test audio download"})
assert response.status_code == 200 assert response.status_code == 200
tid = response.json()["id"] tid = response.json()["id"]

View File

@@ -0,0 +1,13 @@
version: '3.8'
services:
postgres_test:
image: postgres:15
environment:
POSTGRES_DB: reflector_test
POSTGRES_USER: test_user
POSTGRES_PASSWORD: test_password
ports:
- "15432:5432"
command: postgres -c fsync=off -c synchronous_commit=off -c full_page_writes=off
tmpfs:
- /var/lib/postgresql/data:rw,noexec,nosuid,size=1g

View File

@@ -1,22 +1,18 @@
"""Tests for full-text search functionality.""" """Tests for full-text search functionality."""
import json import json
from datetime import datetime from datetime import datetime, timezone
import pytest import pytest
from pydantic import ValidationError from pydantic import ValidationError
from reflector.db import database from reflector.db import get_database
from reflector.db.search import SearchParameters, search_controller from reflector.db.search import SearchParameters, search_controller
from reflector.db.transcripts import transcripts from reflector.db.transcripts import transcripts
from reflector.db.utils import is_postgresql
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_search_postgresql_only(): async def test_search_postgresql_only():
await database.connect()
try:
params = SearchParameters(query_text="any query here") params = SearchParameters(query_text="any query here")
results, total = await search_controller.search_transcripts(params) results, total = await search_controller.search_transcripts(params)
assert results == [] assert results == []
@@ -35,15 +31,9 @@ async def test_search_postgresql_only():
except ValidationError: except ValidationError:
pass # Expected pass # Expected
finally:
await database.disconnect()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_search_input_validation(): async def test_search_input_validation():
await database.connect()
try:
try: try:
SearchParameters(query_text="") SearchParameters(query_text="")
assert False, "Should have raised ValidationError" assert False, "Should have raised ValidationError"
@@ -56,27 +46,17 @@ async def test_search_input_validation():
assert False, "Should have raised ValidationError" assert False, "Should have raised ValidationError"
except ValidationError: except ValidationError:
pass # Expected pass # Expected
finally:
await database.disconnect()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_postgresql_search_with_data(): async def test_postgresql_search_with_data():
"""Test full-text search with actual data in PostgreSQL.
Example how to run: DATABASE_URL=postgresql://reflector:reflector@localhost:5432/reflector_test uv run pytest tests/test_search.py::test_postgresql_search_with_data -v -p no:env
"""
# Skip if not PostgreSQL
if not is_postgresql():
pytest.skip("Test requires PostgreSQL. Set DATABASE_URL=postgresql://...")
await database.connect()
# collision is improbable # collision is improbable
test_id = "test-search-e2e-7f3a9b2c" test_id = "test-search-e2e-7f3a9b2c"
try: try:
await database.execute(transcripts.delete().where(transcripts.c.id == test_id)) await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
)
test_data = { test_data = {
"id": test_id, "id": test_id,
@@ -85,7 +65,7 @@ async def test_postgresql_search_with_data():
"status": "completed", "status": "completed",
"locked": False, "locked": False,
"duration": 1800.0, "duration": 1800.0,
"created_at": datetime.now(), "created_at": datetime.now(timezone.utc),
"short_summary": "Team discussed search implementation", "short_summary": "Team discussed search implementation",
"long_summary": "The engineering team met to plan the search feature", "long_summary": "The engineering team met to plan the search feature",
"topics": json.dumps([]), "topics": json.dumps([]),
@@ -112,7 +92,7 @@ The search feature should support complex queries with ranking.
We need to implement PostgreSQL tsvector for better performance.""", We need to implement PostgreSQL tsvector for better performance.""",
} }
await database.execute(transcripts.insert().values(**test_data)) await get_database().execute(transcripts.insert().values(**test_data))
# Test 1: Search for a word in title # Test 1: Search for a word in title
params = SearchParameters(query_text="planning") params = SearchParameters(query_text="planning")
@@ -141,7 +121,6 @@ We need to implement PostgreSQL tsvector for better performance.""",
assert test_result.title == "Engineering Planning Meeting Q4 2024" assert test_result.title == "Engineering Planning Meeting Q4 2024"
assert test_result.status == "completed" assert test_result.status == "completed"
assert test_result.duration == 1800.0 assert test_result.duration == 1800.0
assert test_result.source_kind == "room"
assert 0 <= test_result.rank <= 1, "Rank should be normalized to 0-1" assert 0 <= test_result.rank <= 1, "Rank should be normalized to 0-1"
# Test 5: Search with OR operator # Test 5: Search with OR operator
@@ -159,5 +138,7 @@ We need to implement PostgreSQL tsvector for better performance.""",
assert found, "Should find test transcript by exact phrase" assert found, "Should find test transcript by exact phrase"
finally: finally:
await database.execute(transcripts.delete().where(transcripts.c.id == test_id)) await get_database().execute(
await database.disconnect() transcripts.delete().where(transcripts.c.id == test_id)
)
await get_database().disconnect()

View File

@@ -1,15 +1,11 @@
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import pytest import pytest
from httpx import AsyncClient
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transcript_create(): async def test_transcript_create(client):
from reflector.app import app response = await client.post("/transcripts", json={"name": "test"})
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["name"] == "test" assert response.json()["name"] == "test"
assert response.json()["status"] == "idle" assert response.json()["status"] == "idle"
@@ -23,71 +19,62 @@ async def test_transcript_create():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transcript_get_update_name(): async def test_transcript_get_update_name(client):
from reflector.app import app response = await client.post("/transcripts", json={"name": "test"})
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["name"] == "test" assert response.json()["name"] == "test"
tid = response.json()["id"] tid = response.json()["id"]
response = await ac.get(f"/transcripts/{tid}") response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["name"] == "test" assert response.json()["name"] == "test"
response = await ac.patch(f"/transcripts/{tid}", json={"name": "test2"}) response = await client.patch(f"/transcripts/{tid}", json={"name": "test2"})
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["name"] == "test2" assert response.json()["name"] == "test2"
response = await ac.get(f"/transcripts/{tid}") response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["name"] == "test2" assert response.json()["name"] == "test2"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transcript_get_update_locked(): async def test_transcript_get_update_locked(client):
from reflector.app import app response = await client.post("/transcripts", json={"name": "test"})
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["locked"] is False assert response.json()["locked"] is False
tid = response.json()["id"] tid = response.json()["id"]
response = await ac.get(f"/transcripts/{tid}") response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["locked"] is False assert response.json()["locked"] is False
response = await ac.patch(f"/transcripts/{tid}", json={"locked": True}) response = await client.patch(f"/transcripts/{tid}", json={"locked": True})
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["locked"] is True assert response.json()["locked"] is True
response = await ac.get(f"/transcripts/{tid}") response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["locked"] is True assert response.json()["locked"] is True
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transcript_get_update_summary(): async def test_transcript_get_update_summary(client):
from reflector.app import app response = await client.post("/transcripts", json={"name": "test"})
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["long_summary"] is None assert response.json()["long_summary"] is None
assert response.json()["short_summary"] is None assert response.json()["short_summary"] is None
tid = response.json()["id"] tid = response.json()["id"]
response = await ac.get(f"/transcripts/{tid}") response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["long_summary"] is None assert response.json()["long_summary"] is None
assert response.json()["short_summary"] is None assert response.json()["short_summary"] is None
response = await ac.patch( response = await client.patch(
f"/transcripts/{tid}", f"/transcripts/{tid}",
json={"long_summary": "test_long", "short_summary": "test_short"}, json={"long_summary": "test_long", "short_summary": "test_short"},
) )
@@ -95,52 +82,46 @@ async def test_transcript_get_update_summary():
assert response.json()["long_summary"] == "test_long" assert response.json()["long_summary"] == "test_long"
assert response.json()["short_summary"] == "test_short" assert response.json()["short_summary"] == "test_short"
response = await ac.get(f"/transcripts/{tid}") response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["long_summary"] == "test_long" assert response.json()["long_summary"] == "test_long"
assert response.json()["short_summary"] == "test_short" assert response.json()["short_summary"] == "test_short"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transcript_get_update_title(): async def test_transcript_get_update_title(client):
from reflector.app import app response = await client.post("/transcripts", json={"name": "test"})
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["title"] is None assert response.json()["title"] is None
tid = response.json()["id"] tid = response.json()["id"]
response = await ac.get(f"/transcripts/{tid}") response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["title"] is None assert response.json()["title"] is None
response = await ac.patch(f"/transcripts/{tid}", json={"title": "test_title"}) response = await client.patch(f"/transcripts/{tid}", json={"title": "test_title"})
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["title"] == "test_title" assert response.json()["title"] == "test_title"
response = await ac.get(f"/transcripts/{tid}") response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["title"] == "test_title" assert response.json()["title"] == "test_title"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transcripts_list_anonymous(): async def test_transcripts_list_anonymous(client):
# XXX this test is a bit fragile, as it depends on the storage which # XXX this test is a bit fragile, as it depends on the storage which
# is shared between tests # is shared between tests
from reflector.app import app
from reflector.settings import settings from reflector.settings import settings
async with AsyncClient(app=app, base_url="http://test/v1") as ac: response = await client.get("/transcripts")
response = await ac.get("/transcripts")
assert response.status_code == 401 assert response.status_code == 401
# if public mode, it should be allowed # if public mode, it should be allowed
try: try:
settings.PUBLIC_MODE = True settings.PUBLIC_MODE = True
async with AsyncClient(app=app, base_url="http://test/v1") as ac: response = await client.get("/transcripts")
response = await ac.get("/transcripts")
assert response.status_code == 200 assert response.status_code == 200
finally: finally:
settings.PUBLIC_MODE = False settings.PUBLIC_MODE = False
@@ -197,21 +178,19 @@ async def authenticated_client2():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transcripts_list_authenticated(authenticated_client): async def test_transcripts_list_authenticated(authenticated_client, client):
# XXX this test is a bit fragile, as it depends on the storage which # XXX this test is a bit fragile, as it depends on the storage which
# is shared between tests # is shared between tests
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac: response = await client.post("/transcripts", json={"name": "testxx1"})
response = await ac.post("/transcripts", json={"name": "testxx1"})
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["name"] == "testxx1" assert response.json()["name"] == "testxx1"
response = await ac.post("/transcripts", json={"name": "testxx2"}) response = await client.post("/transcripts", json={"name": "testxx2"})
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["name"] == "testxx2" assert response.json()["name"] == "testxx2"
response = await ac.get("/transcripts") response = await client.get("/transcripts")
assert response.status_code == 200 assert response.status_code == 200
assert len(response.json()["items"]) >= 2 assert len(response.json()["items"]) >= 2
names = [t["name"] for t in response.json()["items"]] names = [t["name"] for t in response.json()["items"]]
@@ -220,44 +199,38 @@ async def test_transcripts_list_authenticated(authenticated_client):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transcript_delete(): async def test_transcript_delete(client):
from reflector.app import app response = await client.post("/transcripts", json={"name": "testdel1"})
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "testdel1"})
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["name"] == "testdel1" assert response.json()["name"] == "testdel1"
tid = response.json()["id"] tid = response.json()["id"]
response = await ac.delete(f"/transcripts/{tid}") response = await client.delete(f"/transcripts/{tid}")
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["status"] == "ok" assert response.json()["status"] == "ok"
response = await ac.get(f"/transcripts/{tid}") response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 404 assert response.status_code == 404
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transcript_mark_reviewed(): async def test_transcript_mark_reviewed(client):
from reflector.app import app response = await client.post("/transcripts", json={"name": "test"})
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["name"] == "test" assert response.json()["name"] == "test"
assert response.json()["reviewed"] is False assert response.json()["reviewed"] is False
tid = response.json()["id"] tid = response.json()["id"]
response = await ac.get(f"/transcripts/{tid}") response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["name"] == "test" assert response.json()["name"] == "test"
assert response.json()["reviewed"] is False assert response.json()["reviewed"] is False
response = await ac.patch(f"/transcripts/{tid}", json={"reviewed": True}) response = await client.patch(f"/transcripts/{tid}", json={"reviewed": True})
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["reviewed"] is True assert response.json()["reviewed"] is True
response = await ac.get(f"/transcripts/{tid}") response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["reviewed"] is True assert response.json()["reviewed"] is True

View File

@@ -2,20 +2,17 @@ import shutil
from pathlib import Path from pathlib import Path
import pytest import pytest
from httpx import AsyncClient
@pytest.fixture @pytest.fixture
async def fake_transcript(tmpdir): async def fake_transcript(tmpdir, client):
from reflector.app import app
from reflector.settings import settings from reflector.settings import settings
from reflector.views.transcripts import transcripts_controller from reflector.views.transcripts import transcripts_controller
settings.DATA_DIR = Path(tmpdir) settings.DATA_DIR = Path(tmpdir)
# create a transcript # create a transcript
ac = AsyncClient(app=app, base_url="http://test/v1") response = await client.post("/transcripts", json={"name": "Test audio download"})
response = await ac.post("/transcripts", json={"name": "Test audio download"})
assert response.status_code == 200 assert response.status_code == 200
tid = response.json()["id"] tid = response.json()["id"]
@@ -39,17 +36,17 @@ async def fake_transcript(tmpdir):
["/mp3", "audio/mpeg"], ["/mp3", "audio/mpeg"],
], ],
) )
async def test_transcript_audio_download(fake_transcript, url_suffix, content_type): async def test_transcript_audio_download(
from reflector.app import app fake_transcript, url_suffix, content_type, client
):
ac = AsyncClient(app=app, base_url="http://test/v1") response = await client.get(f"/transcripts/{fake_transcript.id}/audio{url_suffix}")
response = await ac.get(f"/transcripts/{fake_transcript.id}/audio{url_suffix}")
assert response.status_code == 200 assert response.status_code == 200
assert response.headers["content-type"] == content_type assert response.headers["content-type"] == content_type
# test get 404 # test get 404
ac = AsyncClient(app=app, base_url="http://test/v1") response = await client.get(
response = await ac.get(f"/transcripts/{fake_transcript.id}XXX/audio{url_suffix}") f"/transcripts/{fake_transcript.id}XXX/audio{url_suffix}"
)
assert response.status_code == 404 assert response.status_code == 404
@@ -61,18 +58,16 @@ async def test_transcript_audio_download(fake_transcript, url_suffix, content_ty
], ],
) )
async def test_transcript_audio_download_head( async def test_transcript_audio_download_head(
fake_transcript, url_suffix, content_type fake_transcript, url_suffix, content_type, client
): ):
from reflector.app import app response = await client.head(f"/transcripts/{fake_transcript.id}/audio{url_suffix}")
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.head(f"/transcripts/{fake_transcript.id}/audio{url_suffix}")
assert response.status_code == 200 assert response.status_code == 200
assert response.headers["content-type"] == content_type assert response.headers["content-type"] == content_type
# test head 404 # test head 404
ac = AsyncClient(app=app, base_url="http://test/v1") response = await client.head(
response = await ac.head(f"/transcripts/{fake_transcript.id}XXX/audio{url_suffix}") f"/transcripts/{fake_transcript.id}XXX/audio{url_suffix}"
)
assert response.status_code == 404 assert response.status_code == 404
@@ -84,12 +79,9 @@ async def test_transcript_audio_download_head(
], ],
) )
async def test_transcript_audio_download_range( async def test_transcript_audio_download_range(
fake_transcript, url_suffix, content_type fake_transcript, url_suffix, content_type, client
): ):
from reflector.app import app response = await client.get(
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.get(
f"/transcripts/{fake_transcript.id}/audio{url_suffix}", f"/transcripts/{fake_transcript.id}/audio{url_suffix}",
headers={"range": "bytes=0-100"}, headers={"range": "bytes=0-100"},
) )
@@ -107,12 +99,9 @@ async def test_transcript_audio_download_range(
], ],
) )
async def test_transcript_audio_download_range_with_seek( async def test_transcript_audio_download_range_with_seek(
fake_transcript, url_suffix, content_type fake_transcript, url_suffix, content_type, client
): ):
from reflector.app import app response = await client.get(
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.get(
f"/transcripts/{fake_transcript.id}/audio{url_suffix}", f"/transcripts/{fake_transcript.id}/audio{url_suffix}",
headers={"range": "bytes=100-"}, headers={"range": "bytes=100-"},
) )
@@ -122,13 +111,10 @@ async def test_transcript_audio_download_range_with_seek(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transcript_delete_with_audio(fake_transcript): async def test_transcript_delete_with_audio(fake_transcript, client):
from reflector.app import app response = await client.delete(f"/transcripts/{fake_transcript.id}")
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.delete(f"/transcripts/{fake_transcript.id}")
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["status"] == "ok" assert response.json()["status"] == "ok"
response = await ac.get(f"/transcripts/{fake_transcript.id}") response = await client.get(f"/transcripts/{fake_transcript.id}")
assert response.status_code == 404 assert response.status_code == 404

View File

@@ -1,19 +1,15 @@
import pytest import pytest
from httpx import AsyncClient
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transcript_participants(): async def test_transcript_participants(client):
from reflector.app import app response = await client.post("/transcripts", json={"name": "test"})
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["participants"] == [] assert response.json()["participants"] == []
# create a participant # create a participant
transcript_id = response.json()["id"] transcript_id = response.json()["id"]
response = await ac.post( response = await client.post(
f"/transcripts/{transcript_id}/participants", json={"name": "test"} f"/transcripts/{transcript_id}/participants", json={"name": "test"}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -22,7 +18,7 @@ async def test_transcript_participants():
assert response.json()["name"] == "test" assert response.json()["name"] == "test"
# create another one with a speaker # create another one with a speaker
response = await ac.post( response = await client.post(
f"/transcripts/{transcript_id}/participants", f"/transcripts/{transcript_id}/participants",
json={"name": "test2", "speaker": 1}, json={"name": "test2", "speaker": 1},
) )
@@ -32,28 +28,25 @@ async def test_transcript_participants():
assert response.json()["name"] == "test2" assert response.json()["name"] == "test2"
# get all participants via transcript # get all participants via transcript
response = await ac.get(f"/transcripts/{transcript_id}") response = await client.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200 assert response.status_code == 200
assert len(response.json()["participants"]) == 2 assert len(response.json()["participants"]) == 2
# get participants via participants endpoint # get participants via participants endpoint
response = await ac.get(f"/transcripts/{transcript_id}/participants") response = await client.get(f"/transcripts/{transcript_id}/participants")
assert response.status_code == 200 assert response.status_code == 200
assert len(response.json()) == 2 assert len(response.json()) == 2
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transcript_participants_same_speaker(): async def test_transcript_participants_same_speaker(client):
from reflector.app import app response = await client.post("/transcripts", json={"name": "test"})
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["participants"] == [] assert response.json()["participants"] == []
transcript_id = response.json()["id"] transcript_id = response.json()["id"]
# create a participant # create a participant
response = await ac.post( response = await client.post(
f"/transcripts/{transcript_id}/participants", f"/transcripts/{transcript_id}/participants",
json={"name": "test", "speaker": 1}, json={"name": "test", "speaker": 1},
) )
@@ -61,7 +54,7 @@ async def test_transcript_participants_same_speaker():
assert response.json()["speaker"] == 1 assert response.json()["speaker"] == 1
# create another one with the same speaker # create another one with the same speaker
response = await ac.post( response = await client.post(
f"/transcripts/{transcript_id}/participants", f"/transcripts/{transcript_id}/participants",
json={"name": "test2", "speaker": 1}, json={"name": "test2", "speaker": 1},
) )
@@ -69,17 +62,14 @@ async def test_transcript_participants_same_speaker():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transcript_participants_update_name(): async def test_transcript_participants_update_name(client):
from reflector.app import app response = await client.post("/transcripts", json={"name": "test"})
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["participants"] == [] assert response.json()["participants"] == []
transcript_id = response.json()["id"] transcript_id = response.json()["id"]
# create a participant # create a participant
response = await ac.post( response = await client.post(
f"/transcripts/{transcript_id}/participants", f"/transcripts/{transcript_id}/participants",
json={"name": "test", "speaker": 1}, json={"name": "test", "speaker": 1},
) )
@@ -88,7 +78,7 @@ async def test_transcript_participants_update_name():
# update the participant # update the participant
participant_id = response.json()["id"] participant_id = response.json()["id"]
response = await ac.patch( response = await client.patch(
f"/transcripts/{transcript_id}/participants/{participant_id}", f"/transcripts/{transcript_id}/participants/{participant_id}",
json={"name": "test2"}, json={"name": "test2"},
) )
@@ -96,31 +86,28 @@ async def test_transcript_participants_update_name():
assert response.json()["name"] == "test2" assert response.json()["name"] == "test2"
# verify the participant was updated # verify the participant was updated
response = await ac.get( response = await client.get(
f"/transcripts/{transcript_id}/participants/{participant_id}" f"/transcripts/{transcript_id}/participants/{participant_id}"
) )
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["name"] == "test2" assert response.json()["name"] == "test2"
# verify the participant was updated in transcript # verify the participant was updated in transcript
response = await ac.get(f"/transcripts/{transcript_id}") response = await client.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200 assert response.status_code == 200
assert len(response.json()["participants"]) == 1 assert len(response.json()["participants"]) == 1
assert response.json()["participants"][0]["name"] == "test2" assert response.json()["participants"][0]["name"] == "test2"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transcript_participants_update_speaker(): async def test_transcript_participants_update_speaker(client):
from reflector.app import app response = await client.post("/transcripts", json={"name": "test"})
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["participants"] == [] assert response.json()["participants"] == []
transcript_id = response.json()["id"] transcript_id = response.json()["id"]
# create a participant # create a participant
response = await ac.post( response = await client.post(
f"/transcripts/{transcript_id}/participants", f"/transcripts/{transcript_id}/participants",
json={"name": "test", "speaker": 1}, json={"name": "test", "speaker": 1},
) )
@@ -128,7 +115,7 @@ async def test_transcript_participants_update_speaker():
participant1_id = response.json()["id"] participant1_id = response.json()["id"]
# create another participant # create another participant
response = await ac.post( response = await client.post(
f"/transcripts/{transcript_id}/participants", f"/transcripts/{transcript_id}/participants",
json={"name": "test2", "speaker": 2}, json={"name": "test2", "speaker": 2},
) )
@@ -136,27 +123,27 @@ async def test_transcript_participants_update_speaker():
participant2_id = response.json()["id"] participant2_id = response.json()["id"]
# update the participant, refused as speaker is already taken # update the participant, refused as speaker is already taken
response = await ac.patch( response = await client.patch(
f"/transcripts/{transcript_id}/participants/{participant2_id}", f"/transcripts/{transcript_id}/participants/{participant2_id}",
json={"speaker": 1}, json={"speaker": 1},
) )
assert response.status_code == 400 assert response.status_code == 400
# delete the participant 1 # delete the participant 1
response = await ac.delete( response = await client.delete(
f"/transcripts/{transcript_id}/participants/{participant1_id}" f"/transcripts/{transcript_id}/participants/{participant1_id}"
) )
assert response.status_code == 200 assert response.status_code == 200
# update the participant 2 again, should be accepted now # update the participant 2 again, should be accepted now
response = await ac.patch( response = await client.patch(
f"/transcripts/{transcript_id}/participants/{participant2_id}", f"/transcripts/{transcript_id}/participants/{participant2_id}",
json={"speaker": 1}, json={"speaker": 1},
) )
assert response.status_code == 200 assert response.status_code == 200
# ensure participant2 name is still there # ensure participant2 name is still there
response = await ac.get( response = await client.get(
f"/transcripts/{transcript_id}/participants/{participant2_id}" f"/transcripts/{transcript_id}/participants/{participant2_id}"
) )
assert response.status_code == 200 assert response.status_code == 200

View File

@@ -2,7 +2,25 @@ import asyncio
import time import time
import pytest import pytest
from httpx import AsyncClient from httpx import ASGITransport, AsyncClient
@pytest.fixture
async def app_lifespan():
from asgi_lifespan import LifespanManager
from reflector.app import app
async with LifespanManager(app) as manager:
yield manager.app
@pytest.fixture
async def client(app_lifespan):
yield AsyncClient(
transport=ASGITransport(app=app_lifespan),
base_url="http://test/v1",
)
@pytest.mark.usefixtures("setup_database") @pytest.mark.usefixtures("setup_database")
@@ -11,23 +29,21 @@ from httpx import AsyncClient
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transcript_process( async def test_transcript_process(
tmpdir, tmpdir,
whisper_transcript,
dummy_llm, dummy_llm,
dummy_processors, dummy_processors,
dummy_diarization, dummy_diarization,
dummy_storage, dummy_storage,
client,
): ):
from reflector.app import app
ac = AsyncClient(app=app, base_url="http://test/v1")
# create a transcript # create a transcript
response = await ac.post("/transcripts", json={"name": "test"}) response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["status"] == "idle" assert response.json()["status"] == "idle"
tid = response.json()["id"] tid = response.json()["id"]
# upload mp3 # upload mp3
response = await ac.post( response = await client.post(
f"/transcripts/{tid}/record/upload?chunk_number=0&total_chunks=1", f"/transcripts/{tid}/record/upload?chunk_number=0&total_chunks=1",
files={ files={
"chunk": ( "chunk": (
@@ -45,7 +61,7 @@ async def test_transcript_process(
start_time = time.monotonic() start_time = time.monotonic()
while (time.monotonic() - start_time) < timeout_seconds: while (time.monotonic() - start_time) < timeout_seconds:
# fetch the transcript and check if it is ended # fetch the transcript and check if it is ended
resp = await ac.get(f"/transcripts/{tid}") resp = await client.get(f"/transcripts/{tid}")
assert resp.status_code == 200 assert resp.status_code == 200
if resp.json()["status"] in ("ended", "error"): if resp.json()["status"] in ("ended", "error"):
break break
@@ -54,7 +70,7 @@ async def test_transcript_process(
pytest.fail(f"Initial processing timed out after {timeout_seconds} seconds") pytest.fail(f"Initial processing timed out after {timeout_seconds} seconds")
# restart the processing # restart the processing
response = await ac.post( response = await client.post(
f"/transcripts/{tid}/process", f"/transcripts/{tid}/process",
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -65,7 +81,7 @@ async def test_transcript_process(
start_time = time.monotonic() start_time = time.monotonic()
while (time.monotonic() - start_time) < timeout_seconds: while (time.monotonic() - start_time) < timeout_seconds:
# fetch the transcript and check if it is ended # fetch the transcript and check if it is ended
resp = await ac.get(f"/transcripts/{tid}") resp = await client.get(f"/transcripts/{tid}")
assert resp.status_code == 200 assert resp.status_code == 200
if resp.json()["status"] in ("ended", "error"): if resp.json()["status"] in ("ended", "error"):
break break
@@ -80,7 +96,7 @@ async def test_transcript_process(
assert transcript["title"] == "Llm Title" assert transcript["title"] == "Llm Title"
# check topics and transcript # check topics and transcript
response = await ac.get(f"/transcripts/{tid}/topics") response = await client.get(f"/transcripts/{tid}/topics")
assert response.status_code == 200 assert response.status_code == 200
assert len(response.json()) == 1 assert len(response.json()) == 1
assert "want to share" in response.json()[0]["transcript"] assert "want to share" in response.json()[0]["transcript"]

View File

@@ -10,7 +10,6 @@ import time
from pathlib import Path from pathlib import Path
import pytest import pytest
from httpx import AsyncClient
from httpx_ws import aconnect_ws from httpx_ws import aconnect_ws
from uvicorn import Config, Server from uvicorn import Config, Server
@@ -50,23 +49,69 @@ class ThreadedUvicorn:
@pytest.fixture @pytest.fixture
async def appserver(tmpdir, setup_database, celery_session_app, celery_session_worker): def appserver(tmpdir, setup_database, celery_session_app, celery_session_worker):
import threading
from reflector.app import app from reflector.app import app
from reflector.db import get_database
from reflector.settings import settings from reflector.settings import settings
DATA_DIR = settings.DATA_DIR DATA_DIR = settings.DATA_DIR
settings.DATA_DIR = Path(tmpdir) settings.DATA_DIR = Path(tmpdir)
# start server # start server in a separate thread with its own event loop
host = "127.0.0.1" host = "127.0.0.1"
port = 1255 port = 1255
config = Config(app=app, host=host, port=port) server_started = threading.Event()
server = ThreadedUvicorn(config) server_exception = None
await server.start() server_instance = None
yield (server, host, port) def run_server():
nonlocal server_exception, server_instance
try:
# Create a new event loop for this thread
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
config = Config(app=app, host=host, port=port, loop=loop)
server_instance = Server(config)
async def start_server():
# Initialize database connection in this event loop
database = get_database()
await database.connect()
try:
await server_instance.serve()
finally:
await database.disconnect()
# Signal that server is starting
server_started.set()
loop.run_until_complete(start_server())
except Exception as e:
server_exception = e
server_started.set()
finally:
loop.close()
server_thread = threading.Thread(target=run_server, daemon=True)
server_thread.start()
# Wait for server to start
server_started.wait(timeout=30)
if server_exception:
raise server_exception
# Wait a bit more for the server to be fully ready
time.sleep(1)
yield server_instance, host, port
# Stop server
if server_instance:
server_instance.should_exit = True
server_thread.join(timeout=30)
server.stop()
settings.DATA_DIR = DATA_DIR settings.DATA_DIR = DATA_DIR
@@ -89,6 +134,7 @@ async def test_transcript_rtc_and_websocket(
dummy_storage, dummy_storage,
fake_mp3_upload, fake_mp3_upload,
appserver, appserver,
client,
): ):
# goal: start the server, exchange RTC, receive websocket events # goal: start the server, exchange RTC, receive websocket events
# because of that, we need to start the server in a thread # because of that, we need to start the server in a thread
@@ -97,8 +143,7 @@ async def test_transcript_rtc_and_websocket(
# create a transcript # create a transcript
base_url = f"http://{host}:{port}/v1" base_url = f"http://{host}:{port}/v1"
ac = AsyncClient(base_url=base_url) response = await client.post("/transcripts", json={"name": "Test RTC"})
response = await ac.post("/transcripts", json={"name": "Test RTC"})
assert response.status_code == 200 assert response.status_code == 200
tid = response.json()["id"] tid = response.json()["id"]
@@ -143,11 +188,11 @@ async def test_transcript_rtc_and_websocket(
url = f"{base_url}/transcripts/{tid}/record/webrtc" url = f"{base_url}/transcripts/{tid}/record/webrtc"
path = Path(__file__).parent / "records" / "test_short.wav" path = Path(__file__).parent / "records" / "test_short.wav"
client = StreamClient(signaling, url=url, play_from=path.as_posix()) stream_client = StreamClient(signaling, url=url, play_from=path.as_posix())
await client.start() await stream_client.start()
timeout = 20 timeout = 120
while not client.is_ended(): while not stream_client.is_ended():
await asyncio.sleep(1) await asyncio.sleep(1)
timeout -= 1 timeout -= 1
if timeout < 0: if timeout < 0:
@@ -155,14 +200,14 @@ async def test_transcript_rtc_and_websocket(
# XXX aiortc is long to close the connection # XXX aiortc is long to close the connection
# instead of waiting a long time, we just send a STOP # instead of waiting a long time, we just send a STOP
client.channel.send(json.dumps({"cmd": "STOP"})) stream_client.channel.send(json.dumps({"cmd": "STOP"}))
await client.stop() await stream_client.stop()
# wait the processing to finish # wait the processing to finish
timeout = 20 timeout = 120
while True: while True:
# fetch the transcript and check if it is ended # fetch the transcript and check if it is ended
resp = await ac.get(f"/transcripts/{tid}") resp = await client.get(f"/transcripts/{tid}")
assert resp.status_code == 200 assert resp.status_code == 200
if resp.json()["status"] in ("ended", "error"): if resp.json()["status"] in ("ended", "error"):
break break
@@ -215,7 +260,7 @@ async def test_transcript_rtc_and_websocket(
ev = events[eventnames.index("WAVEFORM")] ev = events[eventnames.index("WAVEFORM")]
assert isinstance(ev["data"]["waveform"], list) assert isinstance(ev["data"]["waveform"], list)
assert len(ev["data"]["waveform"]) >= 250 assert len(ev["data"]["waveform"]) >= 250
waveform_resp = await ac.get(f"/transcripts/{tid}/audio/waveform") waveform_resp = await client.get(f"/transcripts/{tid}/audio/waveform")
assert waveform_resp.status_code == 200 assert waveform_resp.status_code == 200
assert waveform_resp.headers["content-type"] == "application/json" assert waveform_resp.headers["content-type"] == "application/json"
assert isinstance(waveform_resp.json()["data"], list) assert isinstance(waveform_resp.json()["data"], list)
@@ -235,7 +280,7 @@ async def test_transcript_rtc_and_websocket(
assert "DURATION" in eventnames assert "DURATION" in eventnames
# check that audio/mp3 is available # check that audio/mp3 is available
audio_resp = await ac.get(f"/transcripts/{tid}/audio/mp3") audio_resp = await client.get(f"/transcripts/{tid}/audio/mp3")
assert audio_resp.status_code == 200 assert audio_resp.status_code == 200
assert audio_resp.headers["Content-Type"] == "audio/mpeg" assert audio_resp.headers["Content-Type"] == "audio/mpeg"
@@ -254,6 +299,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
dummy_storage, dummy_storage,
fake_mp3_upload, fake_mp3_upload,
appserver, appserver,
client,
): ):
# goal: start the server, exchange RTC, receive websocket events # goal: start the server, exchange RTC, receive websocket events
# because of that, we need to start the server in a thread # because of that, we need to start the server in a thread
@@ -263,8 +309,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
# create a transcript # create a transcript
base_url = f"http://{host}:{port}/v1" base_url = f"http://{host}:{port}/v1"
ac = AsyncClient(base_url=base_url) response = await client.post(
response = await ac.post(
"/transcripts", json={"name": "Test RTC", "target_language": "fr"} "/transcripts", json={"name": "Test RTC", "target_language": "fr"}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -311,11 +356,11 @@ async def test_transcript_rtc_and_websocket_and_fr(
url = f"{base_url}/transcripts/{tid}/record/webrtc" url = f"{base_url}/transcripts/{tid}/record/webrtc"
path = Path(__file__).parent / "records" / "test_short.wav" path = Path(__file__).parent / "records" / "test_short.wav"
client = StreamClient(signaling, url=url, play_from=path.as_posix()) stream_client = StreamClient(signaling, url=url, play_from=path.as_posix())
await client.start() await stream_client.start()
timeout = 20 timeout = 120
while not client.is_ended(): while not stream_client.is_ended():
await asyncio.sleep(1) await asyncio.sleep(1)
timeout -= 1 timeout -= 1
if timeout < 0: if timeout < 0:
@@ -323,18 +368,18 @@ async def test_transcript_rtc_and_websocket_and_fr(
# XXX aiortc is long to close the connection # XXX aiortc is long to close the connection
# instead of waiting a long time, we just send a STOP # instead of waiting a long time, we just send a STOP
client.channel.send(json.dumps({"cmd": "STOP"})) stream_client.channel.send(json.dumps({"cmd": "STOP"}))
# wait the processing to finish # wait the processing to finish
await asyncio.sleep(2) await asyncio.sleep(2)
await client.stop() await stream_client.stop()
# wait the processing to finish # wait the processing to finish
timeout = 20 timeout = 120
while True: while True:
# fetch the transcript and check if it is ended # fetch the transcript and check if it is ended
resp = await ac.get(f"/transcripts/{tid}") resp = await client.get(f"/transcripts/{tid}")
assert resp.status_code == 200 assert resp.status_code == 200
if resp.json()["status"] == "ended": if resp.json()["status"] == "ended":
break break

View File

@@ -1,20 +1,16 @@
import pytest import pytest
from httpx import AsyncClient
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transcript_reassign_speaker(fake_transcript_with_topics): async def test_transcript_reassign_speaker(fake_transcript_with_topics, client):
from reflector.app import app
transcript_id = fake_transcript_with_topics.id transcript_id = fake_transcript_with_topics.id
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
# check the transcript exists # check the transcript exists
response = await ac.get(f"/transcripts/{transcript_id}") response = await client.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200 assert response.status_code == 200
# check initial topics of the transcript # check initial topics of the transcript
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words") response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200 assert response.status_code == 200
topics = response.json() topics = response.json()
assert len(topics) == 2 assert len(topics) == 2
@@ -31,7 +27,7 @@ async def test_transcript_reassign_speaker(fake_transcript_with_topics):
assert topics[1]["segments"][0]["speaker"] == 0 assert topics[1]["segments"][0]["speaker"] == 0
# reassign speaker # reassign speaker
response = await ac.patch( response = await client.patch(
f"/transcripts/{transcript_id}/speaker/assign", f"/transcripts/{transcript_id}/speaker/assign",
json={ json={
"speaker": 1, "speaker": 1,
@@ -42,7 +38,7 @@ async def test_transcript_reassign_speaker(fake_transcript_with_topics):
assert response.status_code == 200 assert response.status_code == 200
# check topics again # check topics again
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words") response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200 assert response.status_code == 200
topics = response.json() topics = response.json()
assert len(topics) == 2 assert len(topics) == 2
@@ -59,7 +55,7 @@ async def test_transcript_reassign_speaker(fake_transcript_with_topics):
assert topics[1]["segments"][0]["speaker"] == 0 assert topics[1]["segments"][0]["speaker"] == 0
# reassign speaker, middle of 2 topics # reassign speaker, middle of 2 topics
response = await ac.patch( response = await client.patch(
f"/transcripts/{transcript_id}/speaker/assign", f"/transcripts/{transcript_id}/speaker/assign",
json={ json={
"speaker": 2, "speaker": 2,
@@ -70,7 +66,7 @@ async def test_transcript_reassign_speaker(fake_transcript_with_topics):
assert response.status_code == 200 assert response.status_code == 200
# check topics again # check topics again
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words") response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200 assert response.status_code == 200
topics = response.json() topics = response.json()
assert len(topics) == 2 assert len(topics) == 2
@@ -89,7 +85,7 @@ async def test_transcript_reassign_speaker(fake_transcript_with_topics):
assert topics[1]["segments"][1]["speaker"] == 0 assert topics[1]["segments"][1]["speaker"] == 0
# reassign speaker, everything # reassign speaker, everything
response = await ac.patch( response = await client.patch(
f"/transcripts/{transcript_id}/speaker/assign", f"/transcripts/{transcript_id}/speaker/assign",
json={ json={
"speaker": 4, "speaker": 4,
@@ -100,7 +96,7 @@ async def test_transcript_reassign_speaker(fake_transcript_with_topics):
assert response.status_code == 200 assert response.status_code == 200
# check topics again # check topics again
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words") response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200 assert response.status_code == 200
topics = response.json() topics = response.json()
assert len(topics) == 2 assert len(topics) == 2
@@ -118,18 +114,15 @@ async def test_transcript_reassign_speaker(fake_transcript_with_topics):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transcript_merge_speaker(fake_transcript_with_topics): async def test_transcript_merge_speaker(fake_transcript_with_topics, client):
from reflector.app import app
transcript_id = fake_transcript_with_topics.id transcript_id = fake_transcript_with_topics.id
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
# check the transcript exists # check the transcript exists
response = await ac.get(f"/transcripts/{transcript_id}") response = await client.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200 assert response.status_code == 200
# check initial topics of the transcript # check initial topics of the transcript
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words") response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200 assert response.status_code == 200
topics = response.json() topics = response.json()
assert len(topics) == 2 assert len(topics) == 2
@@ -141,7 +134,7 @@ async def test_transcript_merge_speaker(fake_transcript_with_topics):
assert topics[1]["words"][1]["speaker"] == 0 assert topics[1]["words"][1]["speaker"] == 0
# reassign speaker # reassign speaker
response = await ac.patch( response = await client.patch(
f"/transcripts/{transcript_id}/speaker/assign", f"/transcripts/{transcript_id}/speaker/assign",
json={ json={
"speaker": 1, "speaker": 1,
@@ -152,7 +145,7 @@ async def test_transcript_merge_speaker(fake_transcript_with_topics):
assert response.status_code == 200 assert response.status_code == 200
# check topics again # check topics again
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words") response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200 assert response.status_code == 200
topics = response.json() topics = response.json()
assert len(topics) == 2 assert len(topics) == 2
@@ -164,7 +157,7 @@ async def test_transcript_merge_speaker(fake_transcript_with_topics):
assert topics[1]["words"][1]["speaker"] == 0 assert topics[1]["words"][1]["speaker"] == 0
# merge speakers # merge speakers
response = await ac.patch( response = await client.patch(
f"/transcripts/{transcript_id}/speaker/merge", f"/transcripts/{transcript_id}/speaker/merge",
json={ json={
"speaker_from": 1, "speaker_from": 1,
@@ -174,7 +167,7 @@ async def test_transcript_merge_speaker(fake_transcript_with_topics):
assert response.status_code == 200 assert response.status_code == 200
# check topics again # check topics again
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words") response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200 assert response.status_code == 200
topics = response.json() topics = response.json()
assert len(topics) == 2 assert len(topics) == 2
@@ -187,20 +180,19 @@ async def test_transcript_merge_speaker(fake_transcript_with_topics):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transcript_reassign_with_participant(fake_transcript_with_topics): async def test_transcript_reassign_with_participant(
from reflector.app import app fake_transcript_with_topics, client
):
transcript_id = fake_transcript_with_topics.id transcript_id = fake_transcript_with_topics.id
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
# check the transcript exists # check the transcript exists
response = await ac.get(f"/transcripts/{transcript_id}") response = await client.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200 assert response.status_code == 200
transcript = response.json() transcript = response.json()
assert len(transcript["participants"]) == 0 assert len(transcript["participants"]) == 0
# create 2 participants # create 2 participants
response = await ac.post( response = await client.post(
f"/transcripts/{transcript_id}/participants", f"/transcripts/{transcript_id}/participants",
json={ json={
"name": "Participant 1", "name": "Participant 1",
@@ -209,7 +201,7 @@ async def test_transcript_reassign_with_participant(fake_transcript_with_topics)
assert response.status_code == 200 assert response.status_code == 200
participant1_id = response.json()["id"] participant1_id = response.json()["id"]
response = await ac.post( response = await client.post(
f"/transcripts/{transcript_id}/participants", f"/transcripts/{transcript_id}/participants",
json={ json={
"name": "Participant 2", "name": "Participant 2",
@@ -219,7 +211,7 @@ async def test_transcript_reassign_with_participant(fake_transcript_with_topics)
participant2_id = response.json()["id"] participant2_id = response.json()["id"]
# check participants speakers # check participants speakers
response = await ac.get(f"/transcripts/{transcript_id}/participants") response = await client.get(f"/transcripts/{transcript_id}/participants")
assert response.status_code == 200 assert response.status_code == 200
participants = response.json() participants = response.json()
assert len(participants) == 2 assert len(participants) == 2
@@ -229,7 +221,7 @@ async def test_transcript_reassign_with_participant(fake_transcript_with_topics)
assert participants[1]["speaker"] is None assert participants[1]["speaker"] is None
# check initial topics of the transcript # check initial topics of the transcript
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words") response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200 assert response.status_code == 200
topics = response.json() topics = response.json()
assert len(topics) == 2 assert len(topics) == 2
@@ -246,7 +238,7 @@ async def test_transcript_reassign_with_participant(fake_transcript_with_topics)
assert topics[1]["segments"][0]["speaker"] == 0 assert topics[1]["segments"][0]["speaker"] == 0
# reassign speaker from a participant # reassign speaker from a participant
response = await ac.patch( response = await client.patch(
f"/transcripts/{transcript_id}/speaker/assign", f"/transcripts/{transcript_id}/speaker/assign",
json={ json={
"participant": participant1_id, "participant": participant1_id,
@@ -258,7 +250,7 @@ async def test_transcript_reassign_with_participant(fake_transcript_with_topics)
# check participants if speaker has been assigned # check participants if speaker has been assigned
# first participant should have 1, because it's not used yet. # first participant should have 1, because it's not used yet.
response = await ac.get(f"/transcripts/{transcript_id}/participants") response = await client.get(f"/transcripts/{transcript_id}/participants")
assert response.status_code == 200 assert response.status_code == 200
participants = response.json() participants = response.json()
assert len(participants) == 2 assert len(participants) == 2
@@ -268,7 +260,7 @@ async def test_transcript_reassign_with_participant(fake_transcript_with_topics)
assert participants[1]["speaker"] is None assert participants[1]["speaker"] is None
# check topics again # check topics again
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words") response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200 assert response.status_code == 200
topics = response.json() topics = response.json()
assert len(topics) == 2 assert len(topics) == 2
@@ -285,7 +277,7 @@ async def test_transcript_reassign_with_participant(fake_transcript_with_topics)
assert topics[1]["segments"][0]["speaker"] == 0 assert topics[1]["segments"][0]["speaker"] == 0
# reassign participant, middle of 2 topics # reassign participant, middle of 2 topics
response = await ac.patch( response = await client.patch(
f"/transcripts/{transcript_id}/speaker/assign", f"/transcripts/{transcript_id}/speaker/assign",
json={ json={
"participant": participant2_id, "participant": participant2_id,
@@ -297,7 +289,7 @@ async def test_transcript_reassign_with_participant(fake_transcript_with_topics)
# check participants if speaker has been assigned # check participants if speaker has been assigned
# first participant should have 1, because it's not used yet. # first participant should have 1, because it's not used yet.
response = await ac.get(f"/transcripts/{transcript_id}/participants") response = await client.get(f"/transcripts/{transcript_id}/participants")
assert response.status_code == 200 assert response.status_code == 200
participants = response.json() participants = response.json()
assert len(participants) == 2 assert len(participants) == 2
@@ -307,7 +299,7 @@ async def test_transcript_reassign_with_participant(fake_transcript_with_topics)
assert participants[1]["speaker"] == 2 assert participants[1]["speaker"] == 2
# check topics again # check topics again
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words") response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200 assert response.status_code == 200
topics = response.json() topics = response.json()
assert len(topics) == 2 assert len(topics) == 2
@@ -326,7 +318,7 @@ async def test_transcript_reassign_with_participant(fake_transcript_with_topics)
assert topics[1]["segments"][1]["speaker"] == 0 assert topics[1]["segments"][1]["speaker"] == 0
# reassign speaker, everything # reassign speaker, everything
response = await ac.patch( response = await client.patch(
f"/transcripts/{transcript_id}/speaker/assign", f"/transcripts/{transcript_id}/speaker/assign",
json={ json={
"participant": participant1_id, "participant": participant1_id,
@@ -337,7 +329,7 @@ async def test_transcript_reassign_with_participant(fake_transcript_with_topics)
assert response.status_code == 200 assert response.status_code == 200
# check topics again # check topics again
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words") response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200 assert response.status_code == 200
topics = response.json() topics = response.json()
assert len(topics) == 2 assert len(topics) == 2
@@ -355,20 +347,17 @@ async def test_transcript_reassign_with_participant(fake_transcript_with_topics)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transcript_reassign_edge_cases(fake_transcript_with_topics): async def test_transcript_reassign_edge_cases(fake_transcript_with_topics, client):
from reflector.app import app
transcript_id = fake_transcript_with_topics.id transcript_id = fake_transcript_with_topics.id
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
# check the transcript exists # check the transcript exists
response = await ac.get(f"/transcripts/{transcript_id}") response = await client.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200 assert response.status_code == 200
transcript = response.json() transcript = response.json()
assert len(transcript["participants"]) == 0 assert len(transcript["participants"]) == 0
# try reassign without any participant_id or speaker # try reassign without any participant_id or speaker
response = await ac.patch( response = await client.patch(
f"/transcripts/{transcript_id}/speaker/assign", f"/transcripts/{transcript_id}/speaker/assign",
json={ json={
"timestamp_from": 0, "timestamp_from": 0,
@@ -378,7 +367,7 @@ async def test_transcript_reassign_edge_cases(fake_transcript_with_topics):
assert response.status_code == 400 assert response.status_code == 400
# try reassing with both participant_id and speaker # try reassing with both participant_id and speaker
response = await ac.patch( response = await client.patch(
f"/transcripts/{transcript_id}/speaker/assign", f"/transcripts/{transcript_id}/speaker/assign",
json={ json={
"participant": "123", "participant": "123",
@@ -390,7 +379,7 @@ async def test_transcript_reassign_edge_cases(fake_transcript_with_topics):
assert response.status_code == 400 assert response.status_code == 400
# try reassing with non-existing participant_id # try reassing with non-existing participant_id
response = await ac.patch( response = await client.patch(
f"/transcripts/{transcript_id}/speaker/assign", f"/transcripts/{transcript_id}/speaker/assign",
json={ json={
"participant": "123", "participant": "123",

View File

@@ -1,22 +1,18 @@
import pytest import pytest
from httpx import AsyncClient
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transcript_topics(fake_transcript_with_topics): async def test_transcript_topics(fake_transcript_with_topics, client):
from reflector.app import app
transcript_id = fake_transcript_with_topics.id transcript_id = fake_transcript_with_topics.id
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
# check the transcript exists # check the transcript exists
response = await ac.get(f"/transcripts/{transcript_id}/topics") response = await client.get(f"/transcripts/{transcript_id}/topics")
assert response.status_code == 200 assert response.status_code == 200
assert len(response.json()) == 2 assert len(response.json()) == 2
topic_id = response.json()[0]["id"] topic_id = response.json()[0]["id"]
# get words per speakers # get words per speakers
response = await ac.get( response = await client.get(
f"/transcripts/{transcript_id}/topics/{topic_id}/words-per-speaker" f"/transcripts/{transcript_id}/topics/{topic_id}/words-per-speaker"
) )
assert response.status_code == 200 assert response.status_code == 200

View File

@@ -1,20 +1,16 @@
import pytest import pytest
from httpx import AsyncClient
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transcript_create_default_translation(): async def test_transcript_create_default_translation(client):
from reflector.app import app response = await client.post("/transcripts", json={"name": "test en"})
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test en"})
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["name"] == "test en" assert response.json()["name"] == "test en"
assert response.json()["source_language"] == "en" assert response.json()["source_language"] == "en"
assert response.json()["target_language"] == "en" assert response.json()["target_language"] == "en"
tid = response.json()["id"] tid = response.json()["id"]
response = await ac.get(f"/transcripts/{tid}") response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["name"] == "test en" assert response.json()["name"] == "test en"
assert response.json()["source_language"] == "en" assert response.json()["source_language"] == "en"
@@ -22,11 +18,8 @@ async def test_transcript_create_default_translation():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transcript_create_en_fr_translation(): async def test_transcript_create_en_fr_translation(client):
from reflector.app import app response = await client.post(
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post(
"/transcripts", json={"name": "test en/fr", "target_language": "fr"} "/transcripts", json={"name": "test en/fr", "target_language": "fr"}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -35,7 +28,7 @@ async def test_transcript_create_en_fr_translation():
assert response.json()["target_language"] == "fr" assert response.json()["target_language"] == "fr"
tid = response.json()["id"] tid = response.json()["id"]
response = await ac.get(f"/transcripts/{tid}") response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["name"] == "test en/fr" assert response.json()["name"] == "test en/fr"
assert response.json()["source_language"] == "en" assert response.json()["source_language"] == "en"
@@ -43,11 +36,8 @@ async def test_transcript_create_en_fr_translation():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transcript_create_fr_en_translation(): async def test_transcript_create_fr_en_translation(client):
from reflector.app import app response = await client.post(
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post(
"/transcripts", json={"name": "test fr/en", "source_language": "fr"} "/transcripts", json={"name": "test fr/en", "source_language": "fr"}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -56,7 +46,7 @@ async def test_transcript_create_fr_en_translation():
assert response.json()["target_language"] == "en" assert response.json()["target_language"] == "en"
tid = response.json()["id"] tid = response.json()["id"]
response = await ac.get(f"/transcripts/{tid}") response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["name"] == "test fr/en" assert response.json()["name"] == "test fr/en"
assert response.json()["source_language"] == "fr" assert response.json()["source_language"] == "fr"

View File

@@ -2,7 +2,6 @@ import asyncio
import time import time
import pytest import pytest
from httpx import AsyncClient
@pytest.mark.usefixtures("setup_database") @pytest.mark.usefixtures("setup_database")
@@ -15,19 +14,16 @@ async def test_transcript_upload_file(
dummy_processors, dummy_processors,
dummy_diarization, dummy_diarization,
dummy_storage, dummy_storage,
client,
): ):
from reflector.app import app
ac = AsyncClient(app=app, base_url="http://test/v1")
# create a transcript # create a transcript
response = await ac.post("/transcripts", json={"name": "test"}) response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["status"] == "idle" assert response.json()["status"] == "idle"
tid = response.json()["id"] tid = response.json()["id"]
# upload mp3 # upload mp3
response = await ac.post( response = await client.post(
f"/transcripts/{tid}/record/upload?chunk_number=0&total_chunks=1", f"/transcripts/{tid}/record/upload?chunk_number=0&total_chunks=1",
files={ files={
"chunk": ( "chunk": (
@@ -45,7 +41,7 @@ async def test_transcript_upload_file(
start_time = time.monotonic() start_time = time.monotonic()
while (time.monotonic() - start_time) < timeout_seconds: while (time.monotonic() - start_time) < timeout_seconds:
# fetch the transcript and check if it is ended # fetch the transcript and check if it is ended
resp = await ac.get(f"/transcripts/{tid}") resp = await client.get(f"/transcripts/{tid}")
assert resp.status_code == 200 assert resp.status_code == 200
if resp.json()["status"] in ("ended", "error"): if resp.json()["status"] in ("ended", "error"):
break break
@@ -60,7 +56,7 @@ async def test_transcript_upload_file(
assert transcript["title"] == "Llm Title" assert transcript["title"] == "Llm Title"
# check topics and transcript # check topics and transcript
response = await ac.get(f"/transcripts/{tid}/topics") response = await client.get(f"/transcripts/{tid}/topics")
assert response.status_code == 200 assert response.status_code == 200
assert len(response.json()) == 1 assert len(response.json()) == 1
assert "want to share" in response.json()[0]["transcript"] assert "want to share" in response.json()[0]["transcript"]

View File

@@ -2,7 +2,7 @@
import pytest import pytest
from reflector.db import database from reflector.db import get_database
from reflector.db.transcripts import ( from reflector.db.transcripts import (
SourceKind, SourceKind,
TranscriptController, TranscriptController,
@@ -26,7 +26,7 @@ class TestWebVTTAutoUpdate:
) )
try: try:
result = await database.fetch_one( result = await get_database().fetch_one(
transcripts.select().where(transcripts.c.id == transcript.id) transcripts.select().where(transcripts.c.id == transcript.id)
) )
@@ -58,7 +58,7 @@ class TestWebVTTAutoUpdate:
await controller.upsert_topic(transcript, topic) await controller.upsert_topic(transcript, topic)
result = await database.fetch_one( result = await get_database().fetch_one(
transcripts.select().where(transcripts.c.id == transcript.id) transcripts.select().where(transcripts.c.id == transcript.id)
) )
@@ -99,7 +99,7 @@ class TestWebVTTAutoUpdate:
await controller.update(transcript, {"topics": topics_data}) await controller.update(transcript, {"topics": topics_data})
# Fetch from DB # Fetch from DB
result = await database.fetch_one( result = await get_database().fetch_one(
transcripts.select().where(transcripts.c.id == transcript.id) transcripts.select().where(transcripts.c.id == transcript.id)
) )
@@ -141,7 +141,7 @@ class TestWebVTTAutoUpdate:
await controller.update(transcript, values) await controller.update(transcript, values)
# Fetch from DB # Fetch from DB
result = await database.fetch_one( result = await get_database().fetch_one(
transcripts.select().where(transcripts.c.id == transcript.id) transcripts.select().where(transcripts.c.id == transcript.id)
) )
@@ -216,7 +216,7 @@ class TestWebVTTAutoUpdate:
await controller.update(transcript, values) await controller.update(transcript, values)
# Fetch from DB # Fetch from DB
result = await database.fetch_one( result = await get_database().fetch_one(
transcripts.select().where(transcripts.c.id == transcript.id) transcripts.select().where(transcripts.c.id == transcript.id)
) )

2739
server/uv.lock generated

File diff suppressed because it is too large Load Diff