mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
feat: postgresql migration and removal of sqlite in pytest (#546)
* feat: remove support of sqlite, 100% postgres * fix: more migration and make datetime timezone aware in postgres * fix: change how database is get, and use contextvar to have difference instance between different loops * test: properly use client fixture that handle lifetime/database connection * fix: add missing client fixture parameters to test functions This commit fixes NameError issues where test functions were trying to use the 'client' fixture but didn't have it as a parameter. The changes include: 1. Added 'client' parameter to test functions in: - test_transcripts_audio_download.py (6 functions including fixture) - test_transcripts_speaker.py (3 functions) - test_transcripts_upload.py (1 function) - test_transcripts_rtc_ws.py (2 functions + appserver fixture) 2. Resolved naming conflicts in test_transcripts_rtc_ws.py where both HTTP client and StreamClient were using variable name 'client'. StreamClient instances are now named 'stream_client' to avoid conflicts. 3. Added missing 'from reflector.app import app' import in rtc_ws tests. Background: Previously implemented contextvars solution with get_database() function resolves asyncio event loop conflicts in Celery tasks. The global client fixture was also created to replace manual AsyncClient instances, ensuring proper FastAPI application lifecycle management and database connections during tests. All tests now pass except for 2 pre-existing RTC WebSocket test failures related to asyncpg connection issues unrelated to these fixes. * fix: ensure task are correctly closed * fix: make separate event loop for the live server * fix: make default settings pointing at postgres * build: remove pytest-docker deps out of dev, just tests group
This commit is contained in:
30
.github/workflows/db_migrations.yml
vendored
30
.github/workflows/db_migrations.yml
vendored
@@ -17,10 +17,40 @@ on:
|
|||||||
jobs:
|
jobs:
|
||||||
test-migrations:
|
test-migrations:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
services:
|
||||||
|
postgres:
|
||||||
|
image: postgres:17
|
||||||
|
env:
|
||||||
|
POSTGRES_USER: reflector
|
||||||
|
POSTGRES_PASSWORD: reflector
|
||||||
|
POSTGRES_DB: reflector
|
||||||
|
ports:
|
||||||
|
- 5432:5432
|
||||||
|
options: >-
|
||||||
|
--health-cmd pg_isready -h 127.0.0.1 -p 5432
|
||||||
|
--health-interval 10s
|
||||||
|
--health-timeout 5s
|
||||||
|
--health-retries 5
|
||||||
|
|
||||||
|
env:
|
||||||
|
DATABASE_URL: postgresql://reflector:reflector@localhost:5432/reflector
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install PostgreSQL client
|
||||||
|
run: sudo apt-get update && sudo apt-get install -y postgresql-client | cat
|
||||||
|
|
||||||
|
- name: Wait for Postgres
|
||||||
|
run: |
|
||||||
|
for i in {1..30}; do
|
||||||
|
if pg_isready -h localhost -p 5432; then
|
||||||
|
echo "Postgres is ready"
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
echo "Waiting for Postgres... ($i)" && sleep 1
|
||||||
|
done
|
||||||
|
|
||||||
- name: Install uv
|
- name: Install uv
|
||||||
uses: astral-sh/setup-uv@v3
|
uses: astral-sh/setup-uv@v3
|
||||||
with:
|
with:
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ def upgrade() -> None:
|
|||||||
sa.Column("user_id", sa.String(), nullable=True),
|
sa.Column("user_id", sa.String(), nullable=True),
|
||||||
sa.Column("room_id", sa.String(), nullable=True),
|
sa.Column("room_id", sa.String(), nullable=True),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"is_locked", sa.Boolean(), server_default=sa.text("0"), nullable=False
|
"is_locked", sa.Boolean(), server_default=sa.text("false"), nullable=False
|
||||||
),
|
),
|
||||||
sa.Column("room_mode", sa.String(), server_default="normal", nullable=False),
|
sa.Column("room_mode", sa.String(), server_default="normal", nullable=False),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
@@ -53,12 +53,15 @@ def upgrade() -> None:
|
|||||||
sa.Column("user_id", sa.String(), nullable=False),
|
sa.Column("user_id", sa.String(), nullable=False),
|
||||||
sa.Column("created_at", sa.DateTime(), nullable=False),
|
sa.Column("created_at", sa.DateTime(), nullable=False),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"zulip_auto_post", sa.Boolean(), server_default=sa.text("0"), nullable=False
|
"zulip_auto_post",
|
||||||
|
sa.Boolean(),
|
||||||
|
server_default=sa.text("false"),
|
||||||
|
nullable=False,
|
||||||
),
|
),
|
||||||
sa.Column("zulip_stream", sa.String(), nullable=True),
|
sa.Column("zulip_stream", sa.String(), nullable=True),
|
||||||
sa.Column("zulip_topic", sa.String(), nullable=True),
|
sa.Column("zulip_topic", sa.String(), nullable=True),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"is_locked", sa.Boolean(), server_default=sa.text("0"), nullable=False
|
"is_locked", sa.Boolean(), server_default=sa.text("false"), nullable=False
|
||||||
),
|
),
|
||||||
sa.Column("room_mode", sa.String(), server_default="normal", nullable=False),
|
sa.Column("room_mode", sa.String(), server_default="normal", nullable=False),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
|
|||||||
@@ -20,11 +20,14 @@ depends_on: Union[str, Sequence[str], None] = None
|
|||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
sourcekind_enum = sa.Enum("room", "live", "file", name="sourcekind")
|
||||||
|
sourcekind_enum.create(op.get_bind())
|
||||||
|
|
||||||
op.add_column(
|
op.add_column(
|
||||||
"transcript",
|
"transcript",
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"source_kind",
|
"source_kind",
|
||||||
sa.Enum("ROOM", "LIVE", "FILE", name="sourcekind"),
|
sourcekind_enum,
|
||||||
nullable=True,
|
nullable=True,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -43,6 +46,8 @@ def upgrade() -> None:
|
|||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
op.drop_column("transcript", "source_kind")
|
op.drop_column("transcript", "source_kind")
|
||||||
|
sourcekind_enum = sa.Enum(name="sourcekind")
|
||||||
|
sourcekind_enum.drop(op.get_bind())
|
||||||
|
|
||||||
|
|
||||||
# ### end Alembic commands ###
|
# ### end Alembic commands ###
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ def upgrade() -> None:
|
|||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
op.execute(
|
op.execute(
|
||||||
"UPDATE transcript SET events = "
|
"UPDATE transcript SET events = "
|
||||||
'REPLACE(events, \'"event": "SUMMARY"\', \'"event": "LONG_SUMMARY"\');'
|
'REPLACE(events::text, \'"event": "SUMMARY"\', \'"event": "LONG_SUMMARY"\')::json;'
|
||||||
)
|
)
|
||||||
op.alter_column("transcript", "summary", new_column_name="long_summary")
|
op.alter_column("transcript", "summary", new_column_name="long_summary")
|
||||||
op.add_column("transcript", sa.Column("title", sa.String(), nullable=True))
|
op.add_column("transcript", sa.Column("title", sa.String(), nullable=True))
|
||||||
@@ -34,7 +34,7 @@ def downgrade() -> None:
|
|||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
op.execute(
|
op.execute(
|
||||||
"UPDATE transcript SET events = "
|
"UPDATE transcript SET events = "
|
||||||
'REPLACE(events, \'"event": "LONG_SUMMARY"\', \'"event": "SUMMARY"\');'
|
'REPLACE(events::text, \'"event": "LONG_SUMMARY"\', \'"event": "SUMMARY"\')::json;'
|
||||||
)
|
)
|
||||||
with op.batch_alter_table("transcript", schema=None) as batch_op:
|
with op.batch_alter_table("transcript", schema=None) as batch_op:
|
||||||
batch_op.alter_column("long_summary", nullable=True, new_column_name="summary")
|
batch_op.alter_column("long_summary", nullable=True, new_column_name="summary")
|
||||||
|
|||||||
121
server/migrations/versions/9f5c78d352d6_datetime_timezone.py
Normal file
121
server/migrations/versions/9f5c78d352d6_datetime_timezone.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
"""datetime timezone
|
||||||
|
|
||||||
|
Revision ID: 9f5c78d352d6
|
||||||
|
Revises: 8120ebc75366
|
||||||
|
Create Date: 2025-08-13 19:18:27.113593
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "9f5c78d352d6"
|
||||||
|
down_revision: Union[str, None] = "8120ebc75366"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table("meeting", schema=None) as batch_op:
|
||||||
|
batch_op.alter_column(
|
||||||
|
"start_date",
|
||||||
|
existing_type=postgresql.TIMESTAMP(),
|
||||||
|
type_=sa.DateTime(timezone=True),
|
||||||
|
existing_nullable=True,
|
||||||
|
)
|
||||||
|
batch_op.alter_column(
|
||||||
|
"end_date",
|
||||||
|
existing_type=postgresql.TIMESTAMP(),
|
||||||
|
type_=sa.DateTime(timezone=True),
|
||||||
|
existing_nullable=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
with op.batch_alter_table("meeting_consent", schema=None) as batch_op:
|
||||||
|
batch_op.alter_column(
|
||||||
|
"consent_timestamp",
|
||||||
|
existing_type=postgresql.TIMESTAMP(),
|
||||||
|
type_=sa.DateTime(timezone=True),
|
||||||
|
existing_nullable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
with op.batch_alter_table("recording", schema=None) as batch_op:
|
||||||
|
batch_op.alter_column(
|
||||||
|
"recorded_at",
|
||||||
|
existing_type=postgresql.TIMESTAMP(),
|
||||||
|
type_=sa.DateTime(timezone=True),
|
||||||
|
existing_nullable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
with op.batch_alter_table("room", schema=None) as batch_op:
|
||||||
|
batch_op.alter_column(
|
||||||
|
"created_at",
|
||||||
|
existing_type=postgresql.TIMESTAMP(),
|
||||||
|
type_=sa.DateTime(timezone=True),
|
||||||
|
existing_nullable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
with op.batch_alter_table("transcript", schema=None) as batch_op:
|
||||||
|
batch_op.alter_column(
|
||||||
|
"created_at",
|
||||||
|
existing_type=postgresql.TIMESTAMP(),
|
||||||
|
type_=sa.DateTime(timezone=True),
|
||||||
|
existing_nullable=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table("transcript", schema=None) as batch_op:
|
||||||
|
batch_op.alter_column(
|
||||||
|
"created_at",
|
||||||
|
existing_type=sa.DateTime(timezone=True),
|
||||||
|
type_=postgresql.TIMESTAMP(),
|
||||||
|
existing_nullable=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
with op.batch_alter_table("room", schema=None) as batch_op:
|
||||||
|
batch_op.alter_column(
|
||||||
|
"created_at",
|
||||||
|
existing_type=sa.DateTime(timezone=True),
|
||||||
|
type_=postgresql.TIMESTAMP(),
|
||||||
|
existing_nullable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
with op.batch_alter_table("recording", schema=None) as batch_op:
|
||||||
|
batch_op.alter_column(
|
||||||
|
"recorded_at",
|
||||||
|
existing_type=sa.DateTime(timezone=True),
|
||||||
|
type_=postgresql.TIMESTAMP(),
|
||||||
|
existing_nullable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
with op.batch_alter_table("meeting_consent", schema=None) as batch_op:
|
||||||
|
batch_op.alter_column(
|
||||||
|
"consent_timestamp",
|
||||||
|
existing_type=sa.DateTime(timezone=True),
|
||||||
|
type_=postgresql.TIMESTAMP(),
|
||||||
|
existing_nullable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
with op.batch_alter_table("meeting", schema=None) as batch_op:
|
||||||
|
batch_op.alter_column(
|
||||||
|
"end_date",
|
||||||
|
existing_type=sa.DateTime(timezone=True),
|
||||||
|
type_=postgresql.TIMESTAMP(),
|
||||||
|
existing_nullable=True,
|
||||||
|
)
|
||||||
|
batch_op.alter_column(
|
||||||
|
"start_date",
|
||||||
|
existing_type=sa.DateTime(timezone=True),
|
||||||
|
type_=postgresql.TIMESTAMP(),
|
||||||
|
existing_nullable=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -25,7 +25,7 @@ def upgrade() -> None:
|
|||||||
sa.Column(
|
sa.Column(
|
||||||
"is_shared",
|
"is_shared",
|
||||||
sa.Boolean(),
|
sa.Boolean(),
|
||||||
server_default=sa.text("0"),
|
server_default=sa.text("false"),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -23,7 +23,10 @@ def upgrade() -> None:
|
|||||||
with op.batch_alter_table("meeting", schema=None) as batch_op:
|
with op.batch_alter_table("meeting", schema=None) as batch_op:
|
||||||
batch_op.add_column(
|
batch_op.add_column(
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"is_active", sa.Boolean(), server_default=sa.text("1"), nullable=False
|
"is_active",
|
||||||
|
sa.Boolean(),
|
||||||
|
server_default=sa.text("true"),
|
||||||
|
nullable=False,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ def upgrade() -> None:
|
|||||||
op.add_column(
|
op.add_column(
|
||||||
"transcript",
|
"transcript",
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"reviewed", sa.Boolean(), server_default=sa.text("0"), nullable=False
|
"reviewed", sa.Boolean(), server_default=sa.text("false"), nullable=False
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
# ### end Alembic commands ###
|
# ### end Alembic commands ###
|
||||||
|
|||||||
@@ -57,6 +57,8 @@ tests = [
|
|||||||
"httpx-ws>=0.4.1",
|
"httpx-ws>=0.4.1",
|
||||||
"pytest-httpx>=0.23.1",
|
"pytest-httpx>=0.23.1",
|
||||||
"pytest-celery>=0.0.0",
|
"pytest-celery>=0.0.0",
|
||||||
|
"pytest-docker>=3.2.3",
|
||||||
|
"asgi-lifespan>=2.1.0",
|
||||||
]
|
]
|
||||||
aws = ["aioboto3>=11.2.0"]
|
aws = ["aioboto3>=11.2.0"]
|
||||||
evaluation = [
|
evaluation = [
|
||||||
@@ -86,7 +88,7 @@ source = ["reflector"]
|
|||||||
|
|
||||||
[tool.pytest_env]
|
[tool.pytest_env]
|
||||||
ENVIRONMENT = "pytest"
|
ENVIRONMENT = "pytest"
|
||||||
DATABASE_URL = "sqlite:///test.sqlite"
|
DATABASE_URL = "postgresql://test_user:test_password@localhost:15432/reflector_test"
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v"
|
addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v"
|
||||||
|
|||||||
@@ -1,12 +1,28 @@
|
|||||||
|
import contextvars
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import databases
|
import databases
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
|
||||||
from reflector.events import subscribers_shutdown, subscribers_startup
|
from reflector.events import subscribers_shutdown, subscribers_startup
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
|
|
||||||
database = databases.Database(settings.DATABASE_URL)
|
|
||||||
metadata = sqlalchemy.MetaData()
|
metadata = sqlalchemy.MetaData()
|
||||||
|
|
||||||
|
_database_context: contextvars.ContextVar[Optional[databases.Database]] = (
|
||||||
|
contextvars.ContextVar("database", default=None)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_database() -> databases.Database:
|
||||||
|
"""Get database instance for current asyncio context"""
|
||||||
|
db = _database_context.get()
|
||||||
|
if db is None:
|
||||||
|
db = databases.Database(settings.DATABASE_URL)
|
||||||
|
_database_context.set(db)
|
||||||
|
return db
|
||||||
|
|
||||||
|
|
||||||
# import models
|
# import models
|
||||||
import reflector.db.meetings # noqa
|
import reflector.db.meetings # noqa
|
||||||
import reflector.db.recordings # noqa
|
import reflector.db.recordings # noqa
|
||||||
@@ -14,16 +30,18 @@ import reflector.db.rooms # noqa
|
|||||||
import reflector.db.transcripts # noqa
|
import reflector.db.transcripts # noqa
|
||||||
|
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if "sqlite" in settings.DATABASE_URL:
|
if "postgres" not in settings.DATABASE_URL:
|
||||||
kwargs["connect_args"] = {"check_same_thread": False}
|
raise Exception("Only postgres database is supported in reflector")
|
||||||
engine = sqlalchemy.create_engine(settings.DATABASE_URL, **kwargs)
|
engine = sqlalchemy.create_engine(settings.DATABASE_URL, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@subscribers_startup.append
|
@subscribers_startup.append
|
||||||
async def database_connect(_):
|
async def database_connect(_):
|
||||||
|
database = get_database()
|
||||||
await database.connect()
|
await database.connect()
|
||||||
|
|
||||||
|
|
||||||
@subscribers_shutdown.append
|
@subscribers_shutdown.append
|
||||||
async def database_disconnect(_):
|
async def database_disconnect(_):
|
||||||
|
database = get_database()
|
||||||
await database.disconnect()
|
await database.disconnect()
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import sqlalchemy as sa
|
|||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from reflector.db import database, metadata
|
from reflector.db import get_database, metadata
|
||||||
from reflector.db.rooms import Room
|
from reflector.db.rooms import Room
|
||||||
from reflector.utils import generate_uuid4
|
from reflector.utils import generate_uuid4
|
||||||
|
|
||||||
@@ -16,8 +16,8 @@ meetings = sa.Table(
|
|||||||
sa.Column("room_name", sa.String),
|
sa.Column("room_name", sa.String),
|
||||||
sa.Column("room_url", sa.String),
|
sa.Column("room_url", sa.String),
|
||||||
sa.Column("host_room_url", sa.String),
|
sa.Column("host_room_url", sa.String),
|
||||||
sa.Column("start_date", sa.DateTime),
|
sa.Column("start_date", sa.DateTime(timezone=True)),
|
||||||
sa.Column("end_date", sa.DateTime),
|
sa.Column("end_date", sa.DateTime(timezone=True)),
|
||||||
sa.Column("user_id", sa.String),
|
sa.Column("user_id", sa.String),
|
||||||
sa.Column("room_id", sa.String),
|
sa.Column("room_id", sa.String),
|
||||||
sa.Column("is_locked", sa.Boolean, nullable=False, server_default=sa.false()),
|
sa.Column("is_locked", sa.Boolean, nullable=False, server_default=sa.false()),
|
||||||
@@ -42,6 +42,12 @@ meetings = sa.Table(
|
|||||||
server_default=sa.true(),
|
server_default=sa.true(),
|
||||||
),
|
),
|
||||||
sa.Index("idx_meeting_room_id", "room_id"),
|
sa.Index("idx_meeting_room_id", "room_id"),
|
||||||
|
sa.Index(
|
||||||
|
"idx_one_active_meeting_per_room",
|
||||||
|
"room_id",
|
||||||
|
unique=True,
|
||||||
|
postgresql_where=sa.text("is_active = true"),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
meeting_consent = sa.Table(
|
meeting_consent = sa.Table(
|
||||||
@@ -51,7 +57,7 @@ meeting_consent = sa.Table(
|
|||||||
sa.Column("meeting_id", sa.String, sa.ForeignKey("meeting.id"), nullable=False),
|
sa.Column("meeting_id", sa.String, sa.ForeignKey("meeting.id"), nullable=False),
|
||||||
sa.Column("user_id", sa.String),
|
sa.Column("user_id", sa.String),
|
||||||
sa.Column("consent_given", sa.Boolean, nullable=False),
|
sa.Column("consent_given", sa.Boolean, nullable=False),
|
||||||
sa.Column("consent_timestamp", sa.DateTime, nullable=False),
|
sa.Column("consent_timestamp", sa.DateTime(timezone=True), nullable=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -111,7 +117,7 @@ class MeetingController:
|
|||||||
recording_trigger=room.recording_trigger,
|
recording_trigger=room.recording_trigger,
|
||||||
)
|
)
|
||||||
query = meetings.insert().values(**meeting.model_dump())
|
query = meetings.insert().values(**meeting.model_dump())
|
||||||
await database.execute(query)
|
await get_database().execute(query)
|
||||||
return meeting
|
return meeting
|
||||||
|
|
||||||
async def get_all_active(self) -> list[Meeting]:
|
async def get_all_active(self) -> list[Meeting]:
|
||||||
@@ -119,7 +125,7 @@ class MeetingController:
|
|||||||
Get active meetings.
|
Get active meetings.
|
||||||
"""
|
"""
|
||||||
query = meetings.select().where(meetings.c.is_active)
|
query = meetings.select().where(meetings.c.is_active)
|
||||||
return await database.fetch_all(query)
|
return await get_database().fetch_all(query)
|
||||||
|
|
||||||
async def get_by_room_name(
|
async def get_by_room_name(
|
||||||
self,
|
self,
|
||||||
@@ -129,7 +135,7 @@ class MeetingController:
|
|||||||
Get a meeting by room name.
|
Get a meeting by room name.
|
||||||
"""
|
"""
|
||||||
query = meetings.select().where(meetings.c.room_name == room_name)
|
query = meetings.select().where(meetings.c.room_name == room_name)
|
||||||
result = await database.fetch_one(query)
|
result = await get_database().fetch_one(query)
|
||||||
if not result:
|
if not result:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -151,7 +157,7 @@ class MeetingController:
|
|||||||
)
|
)
|
||||||
.order_by(end_date.desc())
|
.order_by(end_date.desc())
|
||||||
)
|
)
|
||||||
result = await database.fetch_one(query)
|
result = await get_database().fetch_one(query)
|
||||||
if not result:
|
if not result:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -162,7 +168,7 @@ class MeetingController:
|
|||||||
Get a meeting by id
|
Get a meeting by id
|
||||||
"""
|
"""
|
||||||
query = meetings.select().where(meetings.c.id == meeting_id)
|
query = meetings.select().where(meetings.c.id == meeting_id)
|
||||||
result = await database.fetch_one(query)
|
result = await get_database().fetch_one(query)
|
||||||
if not result:
|
if not result:
|
||||||
return None
|
return None
|
||||||
return Meeting(**result)
|
return Meeting(**result)
|
||||||
@@ -174,7 +180,7 @@ class MeetingController:
|
|||||||
If not found, it will raise a 404 error.
|
If not found, it will raise a 404 error.
|
||||||
"""
|
"""
|
||||||
query = meetings.select().where(meetings.c.id == meeting_id)
|
query = meetings.select().where(meetings.c.id == meeting_id)
|
||||||
result = await database.fetch_one(query)
|
result = await get_database().fetch_one(query)
|
||||||
if not result:
|
if not result:
|
||||||
raise HTTPException(status_code=404, detail="Meeting not found")
|
raise HTTPException(status_code=404, detail="Meeting not found")
|
||||||
|
|
||||||
@@ -186,7 +192,7 @@ class MeetingController:
|
|||||||
|
|
||||||
async def update_meeting(self, meeting_id: str, **kwargs):
|
async def update_meeting(self, meeting_id: str, **kwargs):
|
||||||
query = meetings.update().where(meetings.c.id == meeting_id).values(**kwargs)
|
query = meetings.update().where(meetings.c.id == meeting_id).values(**kwargs)
|
||||||
await database.execute(query)
|
await get_database().execute(query)
|
||||||
|
|
||||||
|
|
||||||
class MeetingConsentController:
|
class MeetingConsentController:
|
||||||
@@ -194,7 +200,7 @@ class MeetingConsentController:
|
|||||||
query = meeting_consent.select().where(
|
query = meeting_consent.select().where(
|
||||||
meeting_consent.c.meeting_id == meeting_id
|
meeting_consent.c.meeting_id == meeting_id
|
||||||
)
|
)
|
||||||
results = await database.fetch_all(query)
|
results = await get_database().fetch_all(query)
|
||||||
return [MeetingConsent(**result) for result in results]
|
return [MeetingConsent(**result) for result in results]
|
||||||
|
|
||||||
async def get_by_meeting_and_user(
|
async def get_by_meeting_and_user(
|
||||||
@@ -205,7 +211,7 @@ class MeetingConsentController:
|
|||||||
meeting_consent.c.meeting_id == meeting_id,
|
meeting_consent.c.meeting_id == meeting_id,
|
||||||
meeting_consent.c.user_id == user_id,
|
meeting_consent.c.user_id == user_id,
|
||||||
)
|
)
|
||||||
result = await database.fetch_one(query)
|
result = await get_database().fetch_one(query)
|
||||||
if result is None:
|
if result is None:
|
||||||
return None
|
return None
|
||||||
return MeetingConsent(**result) if result else None
|
return MeetingConsent(**result) if result else None
|
||||||
@@ -227,14 +233,14 @@ class MeetingConsentController:
|
|||||||
consent_timestamp=consent.consent_timestamp,
|
consent_timestamp=consent.consent_timestamp,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
await database.execute(query)
|
await get_database().execute(query)
|
||||||
|
|
||||||
existing.consent_given = consent.consent_given
|
existing.consent_given = consent.consent_given
|
||||||
existing.consent_timestamp = consent.consent_timestamp
|
existing.consent_timestamp = consent.consent_timestamp
|
||||||
return existing
|
return existing
|
||||||
|
|
||||||
query = meeting_consent.insert().values(**consent.model_dump())
|
query = meeting_consent.insert().values(**consent.model_dump())
|
||||||
await database.execute(query)
|
await get_database().execute(query)
|
||||||
return consent
|
return consent
|
||||||
|
|
||||||
async def has_any_denial(self, meeting_id: str) -> bool:
|
async def has_any_denial(self, meeting_id: str) -> bool:
|
||||||
@@ -243,7 +249,7 @@ class MeetingConsentController:
|
|||||||
meeting_consent.c.meeting_id == meeting_id,
|
meeting_consent.c.meeting_id == meeting_id,
|
||||||
meeting_consent.c.consent_given.is_(False),
|
meeting_consent.c.consent_given.is_(False),
|
||||||
)
|
)
|
||||||
result = await database.fetch_one(query)
|
result = await get_database().fetch_one(query)
|
||||||
return result is not None
|
return result is not None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from typing import Literal
|
|||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from reflector.db import database, metadata
|
from reflector.db import get_database, metadata
|
||||||
from reflector.utils import generate_uuid4
|
from reflector.utils import generate_uuid4
|
||||||
|
|
||||||
recordings = sa.Table(
|
recordings = sa.Table(
|
||||||
@@ -13,7 +13,7 @@ recordings = sa.Table(
|
|||||||
sa.Column("id", sa.String, primary_key=True),
|
sa.Column("id", sa.String, primary_key=True),
|
||||||
sa.Column("bucket_name", sa.String, nullable=False),
|
sa.Column("bucket_name", sa.String, nullable=False),
|
||||||
sa.Column("object_key", sa.String, nullable=False),
|
sa.Column("object_key", sa.String, nullable=False),
|
||||||
sa.Column("recorded_at", sa.DateTime, nullable=False),
|
sa.Column("recorded_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"status",
|
"status",
|
||||||
sa.String,
|
sa.String,
|
||||||
@@ -37,12 +37,12 @@ class Recording(BaseModel):
|
|||||||
class RecordingController:
|
class RecordingController:
|
||||||
async def create(self, recording: Recording):
|
async def create(self, recording: Recording):
|
||||||
query = recordings.insert().values(**recording.model_dump())
|
query = recordings.insert().values(**recording.model_dump())
|
||||||
await database.execute(query)
|
await get_database().execute(query)
|
||||||
return recording
|
return recording
|
||||||
|
|
||||||
async def get_by_id(self, id: str) -> Recording:
|
async def get_by_id(self, id: str) -> Recording:
|
||||||
query = recordings.select().where(recordings.c.id == id)
|
query = recordings.select().where(recordings.c.id == id)
|
||||||
result = await database.fetch_one(query)
|
result = await get_database().fetch_one(query)
|
||||||
return Recording(**result) if result else None
|
return Recording(**result) if result else None
|
||||||
|
|
||||||
async def get_by_object_key(self, bucket_name: str, object_key: str) -> Recording:
|
async def get_by_object_key(self, bucket_name: str, object_key: str) -> Recording:
|
||||||
@@ -50,7 +50,7 @@ class RecordingController:
|
|||||||
recordings.c.bucket_name == bucket_name,
|
recordings.c.bucket_name == bucket_name,
|
||||||
recordings.c.object_key == object_key,
|
recordings.c.object_key == object_key,
|
||||||
)
|
)
|
||||||
result = await database.fetch_one(query)
|
result = await get_database().fetch_one(query)
|
||||||
return Recording(**result) if result else None
|
return Recording(**result) if result else None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
from sqlite3 import IntegrityError
|
from sqlite3 import IntegrityError
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
@@ -7,7 +7,7 @@ from fastapi import HTTPException
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy.sql import false, or_
|
from sqlalchemy.sql import false, or_
|
||||||
|
|
||||||
from reflector.db import database, metadata
|
from reflector.db import get_database, metadata
|
||||||
from reflector.utils import generate_uuid4
|
from reflector.utils import generate_uuid4
|
||||||
|
|
||||||
rooms = sqlalchemy.Table(
|
rooms = sqlalchemy.Table(
|
||||||
@@ -16,7 +16,7 @@ rooms = sqlalchemy.Table(
|
|||||||
sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
|
sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
|
||||||
sqlalchemy.Column("name", sqlalchemy.String, nullable=False, unique=True),
|
sqlalchemy.Column("name", sqlalchemy.String, nullable=False, unique=True),
|
||||||
sqlalchemy.Column("user_id", sqlalchemy.String, nullable=False),
|
sqlalchemy.Column("user_id", sqlalchemy.String, nullable=False),
|
||||||
sqlalchemy.Column("created_at", sqlalchemy.DateTime, nullable=False),
|
sqlalchemy.Column("created_at", sqlalchemy.DateTime(timezone=True), nullable=False),
|
||||||
sqlalchemy.Column(
|
sqlalchemy.Column(
|
||||||
"zulip_auto_post", sqlalchemy.Boolean, nullable=False, server_default=false()
|
"zulip_auto_post", sqlalchemy.Boolean, nullable=False, server_default=false()
|
||||||
),
|
),
|
||||||
@@ -48,7 +48,7 @@ class Room(BaseModel):
|
|||||||
id: str = Field(default_factory=generate_uuid4)
|
id: str = Field(default_factory=generate_uuid4)
|
||||||
name: str
|
name: str
|
||||||
user_id: str
|
user_id: str
|
||||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
zulip_auto_post: bool = False
|
zulip_auto_post: bool = False
|
||||||
zulip_stream: str = ""
|
zulip_stream: str = ""
|
||||||
zulip_topic: str = ""
|
zulip_topic: str = ""
|
||||||
@@ -92,7 +92,7 @@ class RoomController:
|
|||||||
if return_query:
|
if return_query:
|
||||||
return query
|
return query
|
||||||
|
|
||||||
results = await database.fetch_all(query)
|
results = await get_database().fetch_all(query)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
async def add(
|
async def add(
|
||||||
@@ -125,7 +125,7 @@ class RoomController:
|
|||||||
)
|
)
|
||||||
query = rooms.insert().values(**room.model_dump())
|
query = rooms.insert().values(**room.model_dump())
|
||||||
try:
|
try:
|
||||||
await database.execute(query)
|
await get_database().execute(query)
|
||||||
except IntegrityError:
|
except IntegrityError:
|
||||||
raise HTTPException(status_code=400, detail="Room name is not unique")
|
raise HTTPException(status_code=400, detail="Room name is not unique")
|
||||||
return room
|
return room
|
||||||
@@ -136,7 +136,7 @@ class RoomController:
|
|||||||
"""
|
"""
|
||||||
query = rooms.update().where(rooms.c.id == room.id).values(**values)
|
query = rooms.update().where(rooms.c.id == room.id).values(**values)
|
||||||
try:
|
try:
|
||||||
await database.execute(query)
|
await get_database().execute(query)
|
||||||
except IntegrityError:
|
except IntegrityError:
|
||||||
raise HTTPException(status_code=400, detail="Room name is not unique")
|
raise HTTPException(status_code=400, detail="Room name is not unique")
|
||||||
|
|
||||||
@@ -151,7 +151,7 @@ class RoomController:
|
|||||||
query = rooms.select().where(rooms.c.id == room_id)
|
query = rooms.select().where(rooms.c.id == room_id)
|
||||||
if "user_id" in kwargs:
|
if "user_id" in kwargs:
|
||||||
query = query.where(rooms.c.user_id == kwargs["user_id"])
|
query = query.where(rooms.c.user_id == kwargs["user_id"])
|
||||||
result = await database.fetch_one(query)
|
result = await get_database().fetch_one(query)
|
||||||
if not result:
|
if not result:
|
||||||
return None
|
return None
|
||||||
return Room(**result)
|
return Room(**result)
|
||||||
@@ -163,7 +163,7 @@ class RoomController:
|
|||||||
query = rooms.select().where(rooms.c.name == room_name)
|
query = rooms.select().where(rooms.c.name == room_name)
|
||||||
if "user_id" in kwargs:
|
if "user_id" in kwargs:
|
||||||
query = query.where(rooms.c.user_id == kwargs["user_id"])
|
query = query.where(rooms.c.user_id == kwargs["user_id"])
|
||||||
result = await database.fetch_one(query)
|
result = await get_database().fetch_one(query)
|
||||||
if not result:
|
if not result:
|
||||||
return None
|
return None
|
||||||
return Room(**result)
|
return Room(**result)
|
||||||
@@ -175,7 +175,7 @@ class RoomController:
|
|||||||
If not found, it will raise a 404 error.
|
If not found, it will raise a 404 error.
|
||||||
"""
|
"""
|
||||||
query = rooms.select().where(rooms.c.id == meeting_id)
|
query = rooms.select().where(rooms.c.id == meeting_id)
|
||||||
result = await database.fetch_one(query)
|
result = await get_database().fetch_one(query)
|
||||||
if not result:
|
if not result:
|
||||||
raise HTTPException(status_code=404, detail="Room not found")
|
raise HTTPException(status_code=404, detail="Room not found")
|
||||||
|
|
||||||
@@ -197,7 +197,7 @@ class RoomController:
|
|||||||
if user_id is not None and room.user_id != user_id:
|
if user_id is not None and room.user_id != user_id:
|
||||||
return
|
return
|
||||||
query = rooms.delete().where(rooms.c.id == room_id)
|
query = rooms.delete().where(rooms.c.id == room_id)
|
||||||
await database.execute(query)
|
await get_database().execute(query)
|
||||||
|
|
||||||
|
|
||||||
rooms_controller = RoomController()
|
rooms_controller = RoomController()
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import sqlalchemy
|
|||||||
import webvtt
|
import webvtt
|
||||||
from pydantic import BaseModel, Field, constr, field_serializer
|
from pydantic import BaseModel, Field, constr, field_serializer
|
||||||
|
|
||||||
from reflector.db import database
|
from reflector.db import get_database
|
||||||
from reflector.db.transcripts import SourceKind, transcripts
|
from reflector.db.transcripts import SourceKind, transcripts
|
||||||
from reflector.db.utils import is_postgresql
|
from reflector.db.utils import is_postgresql
|
||||||
|
|
||||||
@@ -207,12 +207,12 @@ class SearchController:
|
|||||||
.limit(params.limit)
|
.limit(params.limit)
|
||||||
.offset(params.offset)
|
.offset(params.offset)
|
||||||
)
|
)
|
||||||
rs = await database.fetch_all(query)
|
rs = await get_database().fetch_all(query)
|
||||||
|
|
||||||
count_query = sqlalchemy.select([sqlalchemy.func.count()]).select_from(
|
count_query = sqlalchemy.select([sqlalchemy.func.count()]).select_from(
|
||||||
base_query.alias("search_results")
|
base_query.alias("search_results")
|
||||||
)
|
)
|
||||||
total = await database.fetch_val(count_query)
|
total = await get_database().fetch_val(count_query)
|
||||||
|
|
||||||
def _process_result(r) -> SearchResult:
|
def _process_result(r) -> SearchResult:
|
||||||
r_dict: Dict[str, Any] = dict(r)
|
r_dict: Dict[str, Any] = dict(r)
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from sqlalchemy import Enum
|
|||||||
from sqlalchemy.dialects.postgresql import TSVECTOR
|
from sqlalchemy.dialects.postgresql import TSVECTOR
|
||||||
from sqlalchemy.sql import false, or_
|
from sqlalchemy.sql import false, or_
|
||||||
|
|
||||||
from reflector.db import database, metadata
|
from reflector.db import get_database, metadata
|
||||||
from reflector.db.rooms import rooms
|
from reflector.db.rooms import rooms
|
||||||
from reflector.db.utils import is_postgresql
|
from reflector.db.utils import is_postgresql
|
||||||
from reflector.processors.types import Word as ProcessorWord
|
from reflector.processors.types import Word as ProcessorWord
|
||||||
@@ -41,7 +41,7 @@ transcripts = sqlalchemy.Table(
|
|||||||
sqlalchemy.Column("status", sqlalchemy.String),
|
sqlalchemy.Column("status", sqlalchemy.String),
|
||||||
sqlalchemy.Column("locked", sqlalchemy.Boolean),
|
sqlalchemy.Column("locked", sqlalchemy.Boolean),
|
||||||
sqlalchemy.Column("duration", sqlalchemy.Float),
|
sqlalchemy.Column("duration", sqlalchemy.Float),
|
||||||
sqlalchemy.Column("created_at", sqlalchemy.DateTime),
|
sqlalchemy.Column("created_at", sqlalchemy.DateTime(timezone=True)),
|
||||||
sqlalchemy.Column("title", sqlalchemy.String),
|
sqlalchemy.Column("title", sqlalchemy.String),
|
||||||
sqlalchemy.Column("short_summary", sqlalchemy.String),
|
sqlalchemy.Column("short_summary", sqlalchemy.String),
|
||||||
sqlalchemy.Column("long_summary", sqlalchemy.String),
|
sqlalchemy.Column("long_summary", sqlalchemy.String),
|
||||||
@@ -421,7 +421,7 @@ class TranscriptController:
|
|||||||
if return_query:
|
if return_query:
|
||||||
return query
|
return query
|
||||||
|
|
||||||
results = await database.fetch_all(query)
|
results = await get_database().fetch_all(query)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
|
async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
|
||||||
@@ -431,7 +431,7 @@ class TranscriptController:
|
|||||||
query = transcripts.select().where(transcripts.c.id == transcript_id)
|
query = transcripts.select().where(transcripts.c.id == transcript_id)
|
||||||
if "user_id" in kwargs:
|
if "user_id" in kwargs:
|
||||||
query = query.where(transcripts.c.user_id == kwargs["user_id"])
|
query = query.where(transcripts.c.user_id == kwargs["user_id"])
|
||||||
result = await database.fetch_one(query)
|
result = await get_database().fetch_one(query)
|
||||||
if not result:
|
if not result:
|
||||||
return None
|
return None
|
||||||
return Transcript(**result)
|
return Transcript(**result)
|
||||||
@@ -445,7 +445,7 @@ class TranscriptController:
|
|||||||
query = transcripts.select().where(transcripts.c.recording_id == recording_id)
|
query = transcripts.select().where(transcripts.c.recording_id == recording_id)
|
||||||
if "user_id" in kwargs:
|
if "user_id" in kwargs:
|
||||||
query = query.where(transcripts.c.user_id == kwargs["user_id"])
|
query = query.where(transcripts.c.user_id == kwargs["user_id"])
|
||||||
result = await database.fetch_one(query)
|
result = await get_database().fetch_one(query)
|
||||||
if not result:
|
if not result:
|
||||||
return None
|
return None
|
||||||
return Transcript(**result)
|
return Transcript(**result)
|
||||||
@@ -463,7 +463,7 @@ class TranscriptController:
|
|||||||
if order_by.startswith("-"):
|
if order_by.startswith("-"):
|
||||||
field = field.desc()
|
field = field.desc()
|
||||||
query = query.order_by(field)
|
query = query.order_by(field)
|
||||||
results = await database.fetch_all(query)
|
results = await get_database().fetch_all(query)
|
||||||
return [Transcript(**result) for result in results]
|
return [Transcript(**result) for result in results]
|
||||||
|
|
||||||
async def get_by_id_for_http(
|
async def get_by_id_for_http(
|
||||||
@@ -481,7 +481,7 @@ class TranscriptController:
|
|||||||
to determine if the user can access the transcript.
|
to determine if the user can access the transcript.
|
||||||
"""
|
"""
|
||||||
query = transcripts.select().where(transcripts.c.id == transcript_id)
|
query = transcripts.select().where(transcripts.c.id == transcript_id)
|
||||||
result = await database.fetch_one(query)
|
result = await get_database().fetch_one(query)
|
||||||
if not result:
|
if not result:
|
||||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||||
|
|
||||||
@@ -534,7 +534,7 @@ class TranscriptController:
|
|||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
)
|
)
|
||||||
query = transcripts.insert().values(**transcript.model_dump())
|
query = transcripts.insert().values(**transcript.model_dump())
|
||||||
await database.execute(query)
|
await get_database().execute(query)
|
||||||
return transcript
|
return transcript
|
||||||
|
|
||||||
# TODO investigate why mutate= is used. it's used in one place currently, maybe because of ORM field updates.
|
# TODO investigate why mutate= is used. it's used in one place currently, maybe because of ORM field updates.
|
||||||
@@ -553,7 +553,7 @@ class TranscriptController:
|
|||||||
.where(transcripts.c.id == transcript.id)
|
.where(transcripts.c.id == transcript.id)
|
||||||
.values(**values)
|
.values(**values)
|
||||||
)
|
)
|
||||||
await database.execute(query)
|
await get_database().execute(query)
|
||||||
if mutate:
|
if mutate:
|
||||||
for key, value in values.items():
|
for key, value in values.items():
|
||||||
setattr(transcript, key, value)
|
setattr(transcript, key, value)
|
||||||
@@ -595,21 +595,21 @@ class TranscriptController:
|
|||||||
return
|
return
|
||||||
transcript.unlink()
|
transcript.unlink()
|
||||||
query = transcripts.delete().where(transcripts.c.id == transcript_id)
|
query = transcripts.delete().where(transcripts.c.id == transcript_id)
|
||||||
await database.execute(query)
|
await get_database().execute(query)
|
||||||
|
|
||||||
async def remove_by_recording_id(self, recording_id: str):
|
async def remove_by_recording_id(self, recording_id: str):
|
||||||
"""
|
"""
|
||||||
Remove a transcript by recording_id
|
Remove a transcript by recording_id
|
||||||
"""
|
"""
|
||||||
query = transcripts.delete().where(transcripts.c.recording_id == recording_id)
|
query = transcripts.delete().where(transcripts.c.recording_id == recording_id)
|
||||||
await database.execute(query)
|
await get_database().execute(query)
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def transaction(self):
|
async def transaction(self):
|
||||||
"""
|
"""
|
||||||
A context manager for database transaction
|
A context manager for database transaction
|
||||||
"""
|
"""
|
||||||
async with database.transaction(isolation="serializable"):
|
async with get_database().transaction(isolation="serializable"):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
async def append_event(
|
async def append_event(
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
"""Database utility functions."""
|
"""Database utility functions."""
|
||||||
|
|
||||||
from reflector.db import database
|
from reflector.db import get_database
|
||||||
|
|
||||||
|
|
||||||
def is_postgresql() -> bool:
|
def is_postgresql() -> bool:
|
||||||
return database.url.scheme and database.url.scheme.startswith('postgresql')
|
return get_database().url.scheme and get_database().url.scheme.startswith(
|
||||||
|
"postgresql"
|
||||||
|
)
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from celery import chord, current_task, group, shared_task
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from structlog import BoundLogger as Logger
|
from structlog import BoundLogger as Logger
|
||||||
|
|
||||||
from reflector.db import database
|
from reflector.db import get_database
|
||||||
from reflector.db.meetings import meeting_consent_controller, meetings_controller
|
from reflector.db.meetings import meeting_consent_controller, meetings_controller
|
||||||
from reflector.db.recordings import recordings_controller
|
from reflector.db.recordings import recordings_controller
|
||||||
from reflector.db.rooms import rooms_controller
|
from reflector.db.rooms import rooms_controller
|
||||||
@@ -72,6 +72,7 @@ def asynctask(f):
|
|||||||
@functools.wraps(f)
|
@functools.wraps(f)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
async def run_with_db():
|
async def run_with_db():
|
||||||
|
database = get_database()
|
||||||
await database.connect()
|
await database.connect()
|
||||||
try:
|
try:
|
||||||
return await f(*args, **kwargs)
|
return await f(*args, **kwargs)
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ This script is used to generate a summary of a meeting notes transcript.
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import sys
|
import sys
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
from typing import Type, TypeVar
|
from typing import Type, TypeVar
|
||||||
@@ -474,7 +474,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
if args.save:
|
if args.save:
|
||||||
# write the summary to a file, on the format summary-<iso date>.md
|
# write the summary to a file, on the format summary-<iso date>.md
|
||||||
filename = f"summary-{datetime.now().isoformat()}.md"
|
filename = f"summary-{datetime.now(timezone.utc).isoformat()}.md"
|
||||||
with open(filename, "w", encoding="utf-8") as f:
|
with open(filename, "w", encoding="utf-8") as f:
|
||||||
f.write(sm.as_markdown())
|
f.write(sm.as_markdown())
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,9 @@ class Settings(BaseSettings):
|
|||||||
CORS_ALLOW_CREDENTIALS: bool = False
|
CORS_ALLOW_CREDENTIALS: bool = False
|
||||||
|
|
||||||
# Database
|
# Database
|
||||||
DATABASE_URL: str = "sqlite:///./reflector.sqlite3"
|
DATABASE_URL: str = (
|
||||||
|
"postgresql+asyncpg://reflector:reflector@localhost:5432/reflector"
|
||||||
|
)
|
||||||
|
|
||||||
# local data directory
|
# local data directory
|
||||||
DATA_DIR: str = "./data"
|
DATA_DIR: str = "./data"
|
||||||
|
|||||||
@@ -9,8 +9,9 @@ async def export_db(filename: str) -> None:
|
|||||||
filename = pathlib.Path(filename).resolve()
|
filename = pathlib.Path(filename).resolve()
|
||||||
settings.DATABASE_URL = f"sqlite:///{filename}"
|
settings.DATABASE_URL = f"sqlite:///{filename}"
|
||||||
|
|
||||||
from reflector.db import database, transcripts
|
from reflector.db import get_database, transcripts
|
||||||
|
|
||||||
|
database = get_database()
|
||||||
await database.connect()
|
await database.connect()
|
||||||
transcripts = await database.fetch_all(transcripts.select())
|
transcripts = await database.fetch_all(transcripts.select())
|
||||||
await database.disconnect()
|
await database.disconnect()
|
||||||
|
|||||||
@@ -8,8 +8,9 @@ async def export_db(filename: str) -> None:
|
|||||||
filename = pathlib.Path(filename).resolve()
|
filename = pathlib.Path(filename).resolve()
|
||||||
settings.DATABASE_URL = f"sqlite:///{filename}"
|
settings.DATABASE_URL = f"sqlite:///{filename}"
|
||||||
|
|
||||||
from reflector.db import database, transcripts
|
from reflector.db import get_database, transcripts
|
||||||
|
|
||||||
|
database = get_database()
|
||||||
await database.connect()
|
await database.connect()
|
||||||
transcripts = await database.fetch_all(transcripts.select())
|
transcripts = await database.fetch_all(transcripts.select())
|
||||||
await database.disconnect()
|
await database.disconnect()
|
||||||
|
|||||||
@@ -155,7 +155,7 @@ async def process_audio_file_with_diarization(
|
|||||||
|
|
||||||
# For Modal backend, we need to upload the file to S3 first
|
# For Modal backend, we need to upload the file to S3 first
|
||||||
if diarization_backend == "modal":
|
if diarization_backend == "modal":
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from reflector.storage import get_transcripts_storage
|
from reflector.storage import get_transcripts_storage
|
||||||
from reflector.utils.s3_temp_file import S3TemporaryFile
|
from reflector.utils.s3_temp_file import S3TemporaryFile
|
||||||
@@ -163,7 +163,7 @@ async def process_audio_file_with_diarization(
|
|||||||
storage = get_transcripts_storage()
|
storage = get_transcripts_storage()
|
||||||
|
|
||||||
# Generate a unique filename in evaluation folder
|
# Generate a unique filename in evaluation folder
|
||||||
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
|
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
|
||||||
audio_filename = f"evaluation/diarization_temp/{timestamp}_{uuid.uuid4().hex}.wav"
|
audio_filename = f"evaluation/diarization_temp/{timestamp}_{uuid.uuid4().hex}.wav"
|
||||||
|
|
||||||
# Use context manager for automatic cleanup
|
# Use context manager for automatic cleanup
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
from typing import Annotated, Optional
|
from typing import Annotated, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
@@ -35,7 +35,7 @@ async def meeting_audio_consent(
|
|||||||
meeting_id=meeting_id,
|
meeting_id=meeting_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
consent_given=request.consent_given,
|
consent_given=request.consent_given,
|
||||||
consent_timestamp=datetime.utcnow(),
|
consent_timestamp=datetime.now(timezone.utc),
|
||||||
)
|
)
|
||||||
|
|
||||||
updated_consent = await meeting_consent_controller.upsert(consent)
|
updated_consent = await meeting_consent_controller.upsert(consent)
|
||||||
|
|||||||
@@ -1,16 +1,16 @@
|
|||||||
import logging
|
import logging
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Annotated, Literal, Optional
|
from typing import Annotated, Literal, Optional
|
||||||
|
|
||||||
import asyncpg.exceptions
|
import asyncpg.exceptions
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from fastapi_pagination import Page
|
from fastapi_pagination import Page
|
||||||
from fastapi_pagination.ext.databases import paginate
|
from fastapi_pagination.ext.databases import apaginate
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
import reflector.auth as auth
|
import reflector.auth as auth
|
||||||
from reflector.db import database
|
from reflector.db import get_database
|
||||||
from reflector.db.meetings import meetings_controller
|
from reflector.db.meetings import meetings_controller
|
||||||
from reflector.db.rooms import rooms_controller
|
from reflector.db.rooms import rooms_controller
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
@@ -21,6 +21,14 @@ logger = logging.getLogger(__name__)
|
|||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
def parse_datetime_with_timezone(iso_string: str) -> datetime:
|
||||||
|
"""Parse ISO datetime string and ensure timezone awareness (defaults to UTC if naive)."""
|
||||||
|
dt = datetime.fromisoformat(iso_string)
|
||||||
|
if dt.tzinfo is None:
|
||||||
|
dt = dt.replace(tzinfo=timezone.utc)
|
||||||
|
return dt
|
||||||
|
|
||||||
|
|
||||||
class Room(BaseModel):
|
class Room(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
@@ -83,8 +91,8 @@ async def rooms_list(
|
|||||||
|
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
|
|
||||||
return await paginate(
|
return await apaginate(
|
||||||
database,
|
get_database(),
|
||||||
await rooms_controller.get_all(
|
await rooms_controller.get_all(
|
||||||
user_id=user_id, order_by="-created_at", return_query=True
|
user_id=user_id, order_by="-created_at", return_query=True
|
||||||
),
|
),
|
||||||
@@ -150,7 +158,7 @@ async def rooms_create_meeting(
|
|||||||
if not room:
|
if not room:
|
||||||
raise HTTPException(status_code=404, detail="Room not found")
|
raise HTTPException(status_code=404, detail="Room not found")
|
||||||
|
|
||||||
current_time = datetime.utcnow()
|
current_time = datetime.now(timezone.utc)
|
||||||
meeting = await meetings_controller.get_active(room=room, current_time=current_time)
|
meeting = await meetings_controller.get_active(room=room, current_time=current_time)
|
||||||
|
|
||||||
if meeting is None:
|
if meeting is None:
|
||||||
@@ -166,8 +174,8 @@ async def rooms_create_meeting(
|
|||||||
room_name=whereby_meeting["roomName"],
|
room_name=whereby_meeting["roomName"],
|
||||||
room_url=whereby_meeting["roomUrl"],
|
room_url=whereby_meeting["roomUrl"],
|
||||||
host_room_url=whereby_meeting["hostRoomUrl"],
|
host_room_url=whereby_meeting["hostRoomUrl"],
|
||||||
start_date=datetime.fromisoformat(whereby_meeting["startDate"]),
|
start_date=parse_datetime_with_timezone(whereby_meeting["startDate"]),
|
||||||
end_date=datetime.fromisoformat(whereby_meeting["endDate"]),
|
end_date=parse_datetime_with_timezone(whereby_meeting["endDate"]),
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
room=room,
|
room=room,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -3,12 +3,12 @@ from typing import Annotated, Literal, Optional
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
from fastapi_pagination import Page
|
from fastapi_pagination import Page
|
||||||
from fastapi_pagination.ext.databases import paginate
|
from fastapi_pagination.ext.databases import apaginate
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
from pydantic import BaseModel, Field, field_serializer
|
from pydantic import BaseModel, Field, field_serializer
|
||||||
|
|
||||||
import reflector.auth as auth
|
import reflector.auth as auth
|
||||||
from reflector.db import database
|
from reflector.db import get_database
|
||||||
from reflector.db.meetings import meetings_controller
|
from reflector.db.meetings import meetings_controller
|
||||||
from reflector.db.rooms import rooms_controller
|
from reflector.db.rooms import rooms_controller
|
||||||
from reflector.db.search import (
|
from reflector.db.search import (
|
||||||
@@ -48,7 +48,7 @@ DOWNLOAD_EXPIRE_MINUTES = 60
|
|||||||
|
|
||||||
def create_access_token(data: dict, expires_delta: timedelta):
|
def create_access_token(data: dict, expires_delta: timedelta):
|
||||||
to_encode = data.copy()
|
to_encode = data.copy()
|
||||||
expire = datetime.utcnow() + expires_delta
|
expire = datetime.now(timezone.utc) + expires_delta
|
||||||
to_encode.update({"exp": expire})
|
to_encode.update({"exp": expire})
|
||||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
|
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
|
||||||
return encoded_jwt
|
return encoded_jwt
|
||||||
@@ -141,8 +141,8 @@ async def transcripts_list(
|
|||||||
|
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
|
|
||||||
return await paginate(
|
return await apaginate(
|
||||||
database,
|
get_database(),
|
||||||
await transcripts_controller.get_all(
|
await transcripts_controller.get_all(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
source_kind=SourceKind(source_kind) if source_kind else None,
|
source_kind=SourceKind(source_kind) if source_kind else None,
|
||||||
|
|||||||
@@ -21,6 +21,14 @@ from reflector.whereby import get_room_sessions
|
|||||||
logger = structlog.wrap_logger(get_task_logger(__name__))
|
logger = structlog.wrap_logger(get_task_logger(__name__))
|
||||||
|
|
||||||
|
|
||||||
|
def parse_datetime_with_timezone(iso_string: str) -> datetime:
|
||||||
|
"""Parse ISO datetime string and ensure timezone awareness (defaults to UTC if naive)."""
|
||||||
|
dt = datetime.fromisoformat(iso_string)
|
||||||
|
if dt.tzinfo is None:
|
||||||
|
dt = dt.replace(tzinfo=timezone.utc)
|
||||||
|
return dt
|
||||||
|
|
||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
def process_messages():
|
def process_messages():
|
||||||
queue_url = settings.AWS_PROCESS_RECORDING_QUEUE_URL
|
queue_url = settings.AWS_PROCESS_RECORDING_QUEUE_URL
|
||||||
@@ -69,7 +77,7 @@ async def process_recording(bucket_name: str, object_key: str):
|
|||||||
|
|
||||||
# extract a guid and a datetime from the object key
|
# extract a guid and a datetime from the object key
|
||||||
room_name = f"/{object_key[:36]}"
|
room_name = f"/{object_key[:36]}"
|
||||||
recorded_at = datetime.fromisoformat(object_key[37:57])
|
recorded_at = parse_datetime_with_timezone(object_key[37:57])
|
||||||
|
|
||||||
meeting = await meetings_controller.get_by_room_name(room_name)
|
meeting = await meetings_controller.get_by_room_name(room_name)
|
||||||
room = await rooms_controller.get_by_id(meeting.room_id)
|
room = await rooms_controller.get_by_id(meeting.room_id)
|
||||||
|
|||||||
@@ -62,6 +62,7 @@ class RedisPubSubManager:
|
|||||||
class WebsocketManager:
|
class WebsocketManager:
|
||||||
def __init__(self, pubsub_client: RedisPubSubManager = None):
|
def __init__(self, pubsub_client: RedisPubSubManager = None):
|
||||||
self.rooms: dict = {}
|
self.rooms: dict = {}
|
||||||
|
self.tasks: dict = {}
|
||||||
self.pubsub_client = pubsub_client
|
self.pubsub_client = pubsub_client
|
||||||
|
|
||||||
async def add_user_to_room(self, room_id: str, websocket: WebSocket) -> None:
|
async def add_user_to_room(self, room_id: str, websocket: WebSocket) -> None:
|
||||||
@@ -74,13 +75,17 @@ class WebsocketManager:
|
|||||||
|
|
||||||
await self.pubsub_client.connect()
|
await self.pubsub_client.connect()
|
||||||
pubsub_subscriber = await self.pubsub_client.subscribe(room_id)
|
pubsub_subscriber = await self.pubsub_client.subscribe(room_id)
|
||||||
asyncio.create_task(self._pubsub_data_reader(pubsub_subscriber))
|
task = asyncio.create_task(self._pubsub_data_reader(pubsub_subscriber))
|
||||||
|
self.tasks[id(websocket)] = task
|
||||||
|
|
||||||
async def send_json(self, room_id: str, message: dict) -> None:
|
async def send_json(self, room_id: str, message: dict) -> None:
|
||||||
await self.pubsub_client.send_json(room_id, message)
|
await self.pubsub_client.send_json(room_id, message)
|
||||||
|
|
||||||
async def remove_user_from_room(self, room_id: str, websocket: WebSocket) -> None:
|
async def remove_user_from_room(self, room_id: str, websocket: WebSocket) -> None:
|
||||||
self.rooms[room_id].remove(websocket)
|
self.rooms[room_id].remove(websocket)
|
||||||
|
task = self.tasks.pop(id(websocket), None)
|
||||||
|
if task:
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
if len(self.rooms[room_id]) == 0:
|
if len(self.rooms[room_id]) == 0:
|
||||||
del self.rooms[room_id]
|
del self.rooms[room_id]
|
||||||
|
|||||||
@@ -1,17 +1,63 @@
|
|||||||
|
import os
|
||||||
from tempfile import NamedTemporaryFile
|
from tempfile import NamedTemporaryFile
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
# Pytest-docker configuration
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def docker_compose_file(pytestconfig):
|
||||||
|
return os.path.join(str(pytestconfig.rootdir), "tests", "docker-compose.test.yml")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def postgres_service(docker_ip, docker_services):
|
||||||
|
"""Ensure that PostgreSQL service is up and responsive."""
|
||||||
|
port = docker_services.port_for("postgres_test", 5432)
|
||||||
|
|
||||||
|
def is_responsive():
|
||||||
|
try:
|
||||||
|
import psycopg2
|
||||||
|
|
||||||
|
conn = psycopg2.connect(
|
||||||
|
host=docker_ip,
|
||||||
|
port=port,
|
||||||
|
dbname="reflector_test",
|
||||||
|
user="test_user",
|
||||||
|
password="test_password",
|
||||||
|
)
|
||||||
|
conn.close()
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
docker_services.wait_until_responsive(timeout=30.0, pause=0.1, check=is_responsive)
|
||||||
|
|
||||||
|
# Return connection parameters
|
||||||
|
return {
|
||||||
|
"host": docker_ip,
|
||||||
|
"port": port,
|
||||||
|
"dbname": "reflector_test",
|
||||||
|
"user": "test_user",
|
||||||
|
"password": "test_password",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function", autouse=True)
|
@pytest.fixture(scope="function", autouse=True)
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def setup_database():
|
async def setup_database(postgres_service):
|
||||||
from reflector.db import engine, metadata # noqa
|
from reflector.db import engine, metadata, get_database # noqa
|
||||||
|
|
||||||
metadata.drop_all(bind=engine)
|
metadata.drop_all(bind=engine)
|
||||||
metadata.create_all(bind=engine)
|
metadata.create_all(bind=engine)
|
||||||
yield
|
database = get_database()
|
||||||
|
|
||||||
|
try:
|
||||||
|
await database.connect()
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
await database.disconnect()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -46,6 +92,20 @@ def dummy_processors():
|
|||||||
) # noqa
|
) # noqa
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def whisper_transcript():
|
||||||
|
from reflector.processors.audio_transcript_whisper import (
|
||||||
|
AudioTranscriptWhisperProcessor,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"reflector.processors.audio_transcript_auto"
|
||||||
|
".AudioTranscriptAutoProcessor.__new__"
|
||||||
|
) as mock_audio:
|
||||||
|
mock_audio.return_value = AudioTranscriptWhisperProcessor()
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def dummy_transcript():
|
async def dummy_transcript():
|
||||||
from reflector.processors.audio_transcript import AudioTranscriptProcessor
|
from reflector.processors.audio_transcript import AudioTranscriptProcessor
|
||||||
@@ -181,6 +241,16 @@ def celery_includes():
|
|||||||
return ["reflector.pipelines.main_live_pipeline"]
|
return ["reflector.pipelines.main_live_pipeline"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def client():
|
||||||
|
from httpx import AsyncClient
|
||||||
|
|
||||||
|
from reflector.app import app
|
||||||
|
|
||||||
|
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
||||||
|
yield ac
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def fake_mp3_upload():
|
def fake_mp3_upload():
|
||||||
with patch(
|
with patch(
|
||||||
@@ -191,13 +261,10 @@ def fake_mp3_upload():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def fake_transcript_with_topics(tmpdir):
|
async def fake_transcript_with_topics(tmpdir, client):
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from httpx import AsyncClient
|
|
||||||
|
|
||||||
from reflector.app import app
|
|
||||||
from reflector.db.transcripts import TranscriptTopic
|
from reflector.db.transcripts import TranscriptTopic
|
||||||
from reflector.processors.types import Word
|
from reflector.processors.types import Word
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
@@ -206,8 +273,7 @@ async def fake_transcript_with_topics(tmpdir):
|
|||||||
settings.DATA_DIR = Path(tmpdir)
|
settings.DATA_DIR = Path(tmpdir)
|
||||||
|
|
||||||
# create a transcript
|
# create a transcript
|
||||||
ac = AsyncClient(app=app, base_url="http://test/v1")
|
response = await client.post("/transcripts", json={"name": "Test audio download"})
|
||||||
response = await ac.post("/transcripts", json={"name": "Test audio download"})
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
tid = response.json()["id"]
|
tid = response.json()["id"]
|
||||||
|
|
||||||
|
|||||||
13
server/tests/docker-compose.test.yml
Normal file
13
server/tests/docker-compose.test.yml
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
version: '3.8'
|
||||||
|
services:
|
||||||
|
postgres_test:
|
||||||
|
image: postgres:15
|
||||||
|
environment:
|
||||||
|
POSTGRES_DB: reflector_test
|
||||||
|
POSTGRES_USER: test_user
|
||||||
|
POSTGRES_PASSWORD: test_password
|
||||||
|
ports:
|
||||||
|
- "15432:5432"
|
||||||
|
command: postgres -c fsync=off -c synchronous_commit=off -c full_page_writes=off
|
||||||
|
tmpfs:
|
||||||
|
- /var/lib/postgresql/data:rw,noexec,nosuid,size=1g
|
||||||
@@ -1,82 +1,62 @@
|
|||||||
"""Tests for full-text search functionality."""
|
"""Tests for full-text search functionality."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from reflector.db import database
|
from reflector.db import get_database
|
||||||
from reflector.db.search import SearchParameters, search_controller
|
from reflector.db.search import SearchParameters, search_controller
|
||||||
from reflector.db.transcripts import transcripts
|
from reflector.db.transcripts import transcripts
|
||||||
from reflector.db.utils import is_postgresql
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_search_postgresql_only():
|
async def test_search_postgresql_only():
|
||||||
await database.connect()
|
params = SearchParameters(query_text="any query here")
|
||||||
|
results, total = await search_controller.search_transcripts(params)
|
||||||
|
assert results == []
|
||||||
|
assert total == 0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
params = SearchParameters(query_text="any query here")
|
SearchParameters(query_text="")
|
||||||
results, total = await search_controller.search_transcripts(params)
|
assert False, "Should have raised validation error"
|
||||||
assert results == []
|
except ValidationError:
|
||||||
assert total == 0
|
pass # Expected
|
||||||
|
|
||||||
try:
|
# Test that whitespace query raises validation error
|
||||||
SearchParameters(query_text="")
|
try:
|
||||||
assert False, "Should have raised validation error"
|
SearchParameters(query_text=" ")
|
||||||
except ValidationError:
|
assert False, "Should have raised validation error"
|
||||||
pass # Expected
|
except ValidationError:
|
||||||
|
pass # Expected
|
||||||
# Test that whitespace query raises validation error
|
|
||||||
try:
|
|
||||||
SearchParameters(query_text=" ")
|
|
||||||
assert False, "Should have raised validation error"
|
|
||||||
except ValidationError:
|
|
||||||
pass # Expected
|
|
||||||
|
|
||||||
finally:
|
|
||||||
await database.disconnect()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_search_input_validation():
|
async def test_search_input_validation():
|
||||||
await database.connect()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
try:
|
SearchParameters(query_text="")
|
||||||
SearchParameters(query_text="")
|
assert False, "Should have raised ValidationError"
|
||||||
assert False, "Should have raised ValidationError"
|
except ValidationError:
|
||||||
except ValidationError:
|
pass # Expected
|
||||||
pass # Expected
|
|
||||||
|
|
||||||
# Test that whitespace query raises validation error
|
# Test that whitespace query raises validation error
|
||||||
try:
|
try:
|
||||||
SearchParameters(query_text=" \t\n ")
|
SearchParameters(query_text=" \t\n ")
|
||||||
assert False, "Should have raised ValidationError"
|
assert False, "Should have raised ValidationError"
|
||||||
except ValidationError:
|
except ValidationError:
|
||||||
pass # Expected
|
pass # Expected
|
||||||
finally:
|
|
||||||
await database.disconnect()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_postgresql_search_with_data():
|
async def test_postgresql_search_with_data():
|
||||||
"""Test full-text search with actual data in PostgreSQL.
|
|
||||||
|
|
||||||
Example how to run: DATABASE_URL=postgresql://reflector:reflector@localhost:5432/reflector_test uv run pytest tests/test_search.py::test_postgresql_search_with_data -v -p no:env
|
|
||||||
"""
|
|
||||||
# Skip if not PostgreSQL
|
|
||||||
if not is_postgresql():
|
|
||||||
pytest.skip("Test requires PostgreSQL. Set DATABASE_URL=postgresql://...")
|
|
||||||
|
|
||||||
await database.connect()
|
|
||||||
|
|
||||||
# collision is improbable
|
# collision is improbable
|
||||||
test_id = "test-search-e2e-7f3a9b2c"
|
test_id = "test-search-e2e-7f3a9b2c"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await database.execute(transcripts.delete().where(transcripts.c.id == test_id))
|
await get_database().execute(
|
||||||
|
transcripts.delete().where(transcripts.c.id == test_id)
|
||||||
|
)
|
||||||
|
|
||||||
test_data = {
|
test_data = {
|
||||||
"id": test_id,
|
"id": test_id,
|
||||||
@@ -85,7 +65,7 @@ async def test_postgresql_search_with_data():
|
|||||||
"status": "completed",
|
"status": "completed",
|
||||||
"locked": False,
|
"locked": False,
|
||||||
"duration": 1800.0,
|
"duration": 1800.0,
|
||||||
"created_at": datetime.now(),
|
"created_at": datetime.now(timezone.utc),
|
||||||
"short_summary": "Team discussed search implementation",
|
"short_summary": "Team discussed search implementation",
|
||||||
"long_summary": "The engineering team met to plan the search feature",
|
"long_summary": "The engineering team met to plan the search feature",
|
||||||
"topics": json.dumps([]),
|
"topics": json.dumps([]),
|
||||||
@@ -112,7 +92,7 @@ The search feature should support complex queries with ranking.
|
|||||||
We need to implement PostgreSQL tsvector for better performance.""",
|
We need to implement PostgreSQL tsvector for better performance.""",
|
||||||
}
|
}
|
||||||
|
|
||||||
await database.execute(transcripts.insert().values(**test_data))
|
await get_database().execute(transcripts.insert().values(**test_data))
|
||||||
|
|
||||||
# Test 1: Search for a word in title
|
# Test 1: Search for a word in title
|
||||||
params = SearchParameters(query_text="planning")
|
params = SearchParameters(query_text="planning")
|
||||||
@@ -141,7 +121,6 @@ We need to implement PostgreSQL tsvector for better performance.""",
|
|||||||
assert test_result.title == "Engineering Planning Meeting Q4 2024"
|
assert test_result.title == "Engineering Planning Meeting Q4 2024"
|
||||||
assert test_result.status == "completed"
|
assert test_result.status == "completed"
|
||||||
assert test_result.duration == 1800.0
|
assert test_result.duration == 1800.0
|
||||||
assert test_result.source_kind == "room"
|
|
||||||
assert 0 <= test_result.rank <= 1, "Rank should be normalized to 0-1"
|
assert 0 <= test_result.rank <= 1, "Rank should be normalized to 0-1"
|
||||||
|
|
||||||
# Test 5: Search with OR operator
|
# Test 5: Search with OR operator
|
||||||
@@ -159,5 +138,7 @@ We need to implement PostgreSQL tsvector for better performance.""",
|
|||||||
assert found, "Should find test transcript by exact phrase"
|
assert found, "Should find test transcript by exact phrase"
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
await database.execute(transcripts.delete().where(transcripts.c.id == test_id))
|
await get_database().execute(
|
||||||
await database.disconnect()
|
transcripts.delete().where(transcripts.c.id == test_id)
|
||||||
|
)
|
||||||
|
await get_database().disconnect()
|
||||||
|
|||||||
@@ -1,147 +1,128 @@
|
|||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from httpx import AsyncClient
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_create():
|
async def test_transcript_create(client):
|
||||||
from reflector.app import app
|
response = await client.post("/transcripts", json={"name": "test"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["name"] == "test"
|
||||||
|
assert response.json()["status"] == "idle"
|
||||||
|
assert response.json()["locked"] is False
|
||||||
|
assert response.json()["id"] is not None
|
||||||
|
assert response.json()["created_at"] is not None
|
||||||
|
|
||||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
# ensure some fields are not returned
|
||||||
response = await ac.post("/transcripts", json={"name": "test"})
|
assert "topics" not in response.json()
|
||||||
assert response.status_code == 200
|
assert "events" not in response.json()
|
||||||
assert response.json()["name"] == "test"
|
|
||||||
assert response.json()["status"] == "idle"
|
|
||||||
assert response.json()["locked"] is False
|
|
||||||
assert response.json()["id"] is not None
|
|
||||||
assert response.json()["created_at"] is not None
|
|
||||||
|
|
||||||
# ensure some fields are not returned
|
|
||||||
assert "topics" not in response.json()
|
|
||||||
assert "events" not in response.json()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_get_update_name():
|
async def test_transcript_get_update_name(client):
|
||||||
from reflector.app import app
|
response = await client.post("/transcripts", json={"name": "test"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["name"] == "test"
|
||||||
|
|
||||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
tid = response.json()["id"]
|
||||||
response = await ac.post("/transcripts", json={"name": "test"})
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["name"] == "test"
|
|
||||||
|
|
||||||
tid = response.json()["id"]
|
response = await client.get(f"/transcripts/{tid}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["name"] == "test"
|
||||||
|
|
||||||
response = await ac.get(f"/transcripts/{tid}")
|
response = await client.patch(f"/transcripts/{tid}", json={"name": "test2"})
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["name"] == "test"
|
assert response.json()["name"] == "test2"
|
||||||
|
|
||||||
response = await ac.patch(f"/transcripts/{tid}", json={"name": "test2"})
|
response = await client.get(f"/transcripts/{tid}")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["name"] == "test2"
|
assert response.json()["name"] == "test2"
|
||||||
|
|
||||||
response = await ac.get(f"/transcripts/{tid}")
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["name"] == "test2"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_get_update_locked():
|
async def test_transcript_get_update_locked(client):
|
||||||
from reflector.app import app
|
response = await client.post("/transcripts", json={"name": "test"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["locked"] is False
|
||||||
|
|
||||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
tid = response.json()["id"]
|
||||||
response = await ac.post("/transcripts", json={"name": "test"})
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["locked"] is False
|
|
||||||
|
|
||||||
tid = response.json()["id"]
|
response = await client.get(f"/transcripts/{tid}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["locked"] is False
|
||||||
|
|
||||||
response = await ac.get(f"/transcripts/{tid}")
|
response = await client.patch(f"/transcripts/{tid}", json={"locked": True})
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["locked"] is False
|
assert response.json()["locked"] is True
|
||||||
|
|
||||||
response = await ac.patch(f"/transcripts/{tid}", json={"locked": True})
|
response = await client.get(f"/transcripts/{tid}")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["locked"] is True
|
assert response.json()["locked"] is True
|
||||||
|
|
||||||
response = await ac.get(f"/transcripts/{tid}")
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["locked"] is True
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_get_update_summary():
|
async def test_transcript_get_update_summary(client):
|
||||||
from reflector.app import app
|
response = await client.post("/transcripts", json={"name": "test"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["long_summary"] is None
|
||||||
|
assert response.json()["short_summary"] is None
|
||||||
|
|
||||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
tid = response.json()["id"]
|
||||||
response = await ac.post("/transcripts", json={"name": "test"})
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["long_summary"] is None
|
|
||||||
assert response.json()["short_summary"] is None
|
|
||||||
|
|
||||||
tid = response.json()["id"]
|
response = await client.get(f"/transcripts/{tid}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["long_summary"] is None
|
||||||
|
assert response.json()["short_summary"] is None
|
||||||
|
|
||||||
response = await ac.get(f"/transcripts/{tid}")
|
response = await client.patch(
|
||||||
assert response.status_code == 200
|
f"/transcripts/{tid}",
|
||||||
assert response.json()["long_summary"] is None
|
json={"long_summary": "test_long", "short_summary": "test_short"},
|
||||||
assert response.json()["short_summary"] is None
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["long_summary"] == "test_long"
|
||||||
|
assert response.json()["short_summary"] == "test_short"
|
||||||
|
|
||||||
response = await ac.patch(
|
response = await client.get(f"/transcripts/{tid}")
|
||||||
f"/transcripts/{tid}",
|
assert response.status_code == 200
|
||||||
json={"long_summary": "test_long", "short_summary": "test_short"},
|
assert response.json()["long_summary"] == "test_long"
|
||||||
)
|
assert response.json()["short_summary"] == "test_short"
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["long_summary"] == "test_long"
|
|
||||||
assert response.json()["short_summary"] == "test_short"
|
|
||||||
|
|
||||||
response = await ac.get(f"/transcripts/{tid}")
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["long_summary"] == "test_long"
|
|
||||||
assert response.json()["short_summary"] == "test_short"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_get_update_title():
|
async def test_transcript_get_update_title(client):
|
||||||
from reflector.app import app
|
response = await client.post("/transcripts", json={"name": "test"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["title"] is None
|
||||||
|
|
||||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
tid = response.json()["id"]
|
||||||
response = await ac.post("/transcripts", json={"name": "test"})
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["title"] is None
|
|
||||||
|
|
||||||
tid = response.json()["id"]
|
response = await client.get(f"/transcripts/{tid}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["title"] is None
|
||||||
|
|
||||||
response = await ac.get(f"/transcripts/{tid}")
|
response = await client.patch(f"/transcripts/{tid}", json={"title": "test_title"})
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["title"] is None
|
assert response.json()["title"] == "test_title"
|
||||||
|
|
||||||
response = await ac.patch(f"/transcripts/{tid}", json={"title": "test_title"})
|
response = await client.get(f"/transcripts/{tid}")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["title"] == "test_title"
|
assert response.json()["title"] == "test_title"
|
||||||
|
|
||||||
response = await ac.get(f"/transcripts/{tid}")
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["title"] == "test_title"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcripts_list_anonymous():
|
async def test_transcripts_list_anonymous(client):
|
||||||
# XXX this test is a bit fragile, as it depends on the storage which
|
# XXX this test is a bit fragile, as it depends on the storage which
|
||||||
# is shared between tests
|
# is shared between tests
|
||||||
from reflector.app import app
|
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
|
|
||||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
response = await client.get("/transcripts")
|
||||||
response = await ac.get("/transcripts")
|
assert response.status_code == 401
|
||||||
assert response.status_code == 401
|
|
||||||
|
|
||||||
# if public mode, it should be allowed
|
# if public mode, it should be allowed
|
||||||
try:
|
try:
|
||||||
settings.PUBLIC_MODE = True
|
settings.PUBLIC_MODE = True
|
||||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
response = await client.get("/transcripts")
|
||||||
response = await ac.get("/transcripts")
|
assert response.status_code == 200
|
||||||
assert response.status_code == 200
|
|
||||||
finally:
|
finally:
|
||||||
settings.PUBLIC_MODE = False
|
settings.PUBLIC_MODE = False
|
||||||
|
|
||||||
@@ -197,67 +178,59 @@ async def authenticated_client2():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcripts_list_authenticated(authenticated_client):
|
async def test_transcripts_list_authenticated(authenticated_client, client):
|
||||||
# XXX this test is a bit fragile, as it depends on the storage which
|
# XXX this test is a bit fragile, as it depends on the storage which
|
||||||
# is shared between tests
|
# is shared between tests
|
||||||
from reflector.app import app
|
|
||||||
|
|
||||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
response = await client.post("/transcripts", json={"name": "testxx1"})
|
||||||
response = await ac.post("/transcripts", json={"name": "testxx1"})
|
assert response.status_code == 200
|
||||||
assert response.status_code == 200
|
assert response.json()["name"] == "testxx1"
|
||||||
assert response.json()["name"] == "testxx1"
|
|
||||||
|
|
||||||
response = await ac.post("/transcripts", json={"name": "testxx2"})
|
response = await client.post("/transcripts", json={"name": "testxx2"})
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["name"] == "testxx2"
|
assert response.json()["name"] == "testxx2"
|
||||||
|
|
||||||
response = await ac.get("/transcripts")
|
response = await client.get("/transcripts")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert len(response.json()["items"]) >= 2
|
assert len(response.json()["items"]) >= 2
|
||||||
names = [t["name"] for t in response.json()["items"]]
|
names = [t["name"] for t in response.json()["items"]]
|
||||||
assert "testxx1" in names
|
assert "testxx1" in names
|
||||||
assert "testxx2" in names
|
assert "testxx2" in names
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_delete():
|
async def test_transcript_delete(client):
|
||||||
from reflector.app import app
|
response = await client.post("/transcripts", json={"name": "testdel1"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["name"] == "testdel1"
|
||||||
|
|
||||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
tid = response.json()["id"]
|
||||||
response = await ac.post("/transcripts", json={"name": "testdel1"})
|
response = await client.delete(f"/transcripts/{tid}")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["name"] == "testdel1"
|
assert response.json()["status"] == "ok"
|
||||||
|
|
||||||
tid = response.json()["id"]
|
response = await client.get(f"/transcripts/{tid}")
|
||||||
response = await ac.delete(f"/transcripts/{tid}")
|
assert response.status_code == 404
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["status"] == "ok"
|
|
||||||
|
|
||||||
response = await ac.get(f"/transcripts/{tid}")
|
|
||||||
assert response.status_code == 404
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_mark_reviewed():
|
async def test_transcript_mark_reviewed(client):
|
||||||
from reflector.app import app
|
response = await client.post("/transcripts", json={"name": "test"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["name"] == "test"
|
||||||
|
assert response.json()["reviewed"] is False
|
||||||
|
|
||||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
tid = response.json()["id"]
|
||||||
response = await ac.post("/transcripts", json={"name": "test"})
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["name"] == "test"
|
|
||||||
assert response.json()["reviewed"] is False
|
|
||||||
|
|
||||||
tid = response.json()["id"]
|
response = await client.get(f"/transcripts/{tid}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["name"] == "test"
|
||||||
|
assert response.json()["reviewed"] is False
|
||||||
|
|
||||||
response = await ac.get(f"/transcripts/{tid}")
|
response = await client.patch(f"/transcripts/{tid}", json={"reviewed": True})
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["name"] == "test"
|
assert response.json()["reviewed"] is True
|
||||||
assert response.json()["reviewed"] is False
|
|
||||||
|
|
||||||
response = await ac.patch(f"/transcripts/{tid}", json={"reviewed": True})
|
response = await client.get(f"/transcripts/{tid}")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["reviewed"] is True
|
assert response.json()["reviewed"] is True
|
||||||
|
|
||||||
response = await ac.get(f"/transcripts/{tid}")
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["reviewed"] is True
|
|
||||||
|
|||||||
@@ -2,20 +2,17 @@ import shutil
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from httpx import AsyncClient
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def fake_transcript(tmpdir):
|
async def fake_transcript(tmpdir, client):
|
||||||
from reflector.app import app
|
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
from reflector.views.transcripts import transcripts_controller
|
from reflector.views.transcripts import transcripts_controller
|
||||||
|
|
||||||
settings.DATA_DIR = Path(tmpdir)
|
settings.DATA_DIR = Path(tmpdir)
|
||||||
|
|
||||||
# create a transcript
|
# create a transcript
|
||||||
ac = AsyncClient(app=app, base_url="http://test/v1")
|
response = await client.post("/transcripts", json={"name": "Test audio download"})
|
||||||
response = await ac.post("/transcripts", json={"name": "Test audio download"})
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
tid = response.json()["id"]
|
tid = response.json()["id"]
|
||||||
|
|
||||||
@@ -39,17 +36,17 @@ async def fake_transcript(tmpdir):
|
|||||||
["/mp3", "audio/mpeg"],
|
["/mp3", "audio/mpeg"],
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
async def test_transcript_audio_download(fake_transcript, url_suffix, content_type):
|
async def test_transcript_audio_download(
|
||||||
from reflector.app import app
|
fake_transcript, url_suffix, content_type, client
|
||||||
|
):
|
||||||
ac = AsyncClient(app=app, base_url="http://test/v1")
|
response = await client.get(f"/transcripts/{fake_transcript.id}/audio{url_suffix}")
|
||||||
response = await ac.get(f"/transcripts/{fake_transcript.id}/audio{url_suffix}")
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.headers["content-type"] == content_type
|
assert response.headers["content-type"] == content_type
|
||||||
|
|
||||||
# test get 404
|
# test get 404
|
||||||
ac = AsyncClient(app=app, base_url="http://test/v1")
|
response = await client.get(
|
||||||
response = await ac.get(f"/transcripts/{fake_transcript.id}XXX/audio{url_suffix}")
|
f"/transcripts/{fake_transcript.id}XXX/audio{url_suffix}"
|
||||||
|
)
|
||||||
assert response.status_code == 404
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
@@ -61,18 +58,16 @@ async def test_transcript_audio_download(fake_transcript, url_suffix, content_ty
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
async def test_transcript_audio_download_head(
|
async def test_transcript_audio_download_head(
|
||||||
fake_transcript, url_suffix, content_type
|
fake_transcript, url_suffix, content_type, client
|
||||||
):
|
):
|
||||||
from reflector.app import app
|
response = await client.head(f"/transcripts/{fake_transcript.id}/audio{url_suffix}")
|
||||||
|
|
||||||
ac = AsyncClient(app=app, base_url="http://test/v1")
|
|
||||||
response = await ac.head(f"/transcripts/{fake_transcript.id}/audio{url_suffix}")
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.headers["content-type"] == content_type
|
assert response.headers["content-type"] == content_type
|
||||||
|
|
||||||
# test head 404
|
# test head 404
|
||||||
ac = AsyncClient(app=app, base_url="http://test/v1")
|
response = await client.head(
|
||||||
response = await ac.head(f"/transcripts/{fake_transcript.id}XXX/audio{url_suffix}")
|
f"/transcripts/{fake_transcript.id}XXX/audio{url_suffix}"
|
||||||
|
)
|
||||||
assert response.status_code == 404
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
@@ -84,12 +79,9 @@ async def test_transcript_audio_download_head(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
async def test_transcript_audio_download_range(
|
async def test_transcript_audio_download_range(
|
||||||
fake_transcript, url_suffix, content_type
|
fake_transcript, url_suffix, content_type, client
|
||||||
):
|
):
|
||||||
from reflector.app import app
|
response = await client.get(
|
||||||
|
|
||||||
ac = AsyncClient(app=app, base_url="http://test/v1")
|
|
||||||
response = await ac.get(
|
|
||||||
f"/transcripts/{fake_transcript.id}/audio{url_suffix}",
|
f"/transcripts/{fake_transcript.id}/audio{url_suffix}",
|
||||||
headers={"range": "bytes=0-100"},
|
headers={"range": "bytes=0-100"},
|
||||||
)
|
)
|
||||||
@@ -107,12 +99,9 @@ async def test_transcript_audio_download_range(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
async def test_transcript_audio_download_range_with_seek(
|
async def test_transcript_audio_download_range_with_seek(
|
||||||
fake_transcript, url_suffix, content_type
|
fake_transcript, url_suffix, content_type, client
|
||||||
):
|
):
|
||||||
from reflector.app import app
|
response = await client.get(
|
||||||
|
|
||||||
ac = AsyncClient(app=app, base_url="http://test/v1")
|
|
||||||
response = await ac.get(
|
|
||||||
f"/transcripts/{fake_transcript.id}/audio{url_suffix}",
|
f"/transcripts/{fake_transcript.id}/audio{url_suffix}",
|
||||||
headers={"range": "bytes=100-"},
|
headers={"range": "bytes=100-"},
|
||||||
)
|
)
|
||||||
@@ -122,13 +111,10 @@ async def test_transcript_audio_download_range_with_seek(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_delete_with_audio(fake_transcript):
|
async def test_transcript_delete_with_audio(fake_transcript, client):
|
||||||
from reflector.app import app
|
response = await client.delete(f"/transcripts/{fake_transcript.id}")
|
||||||
|
|
||||||
ac = AsyncClient(app=app, base_url="http://test/v1")
|
|
||||||
response = await ac.delete(f"/transcripts/{fake_transcript.id}")
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["status"] == "ok"
|
assert response.json()["status"] == "ok"
|
||||||
|
|
||||||
response = await ac.get(f"/transcripts/{fake_transcript.id}")
|
response = await client.get(f"/transcripts/{fake_transcript.id}")
|
||||||
assert response.status_code == 404
|
assert response.status_code == 404
|
||||||
|
|||||||
@@ -1,164 +1,151 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from httpx import AsyncClient
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_participants():
|
async def test_transcript_participants(client):
|
||||||
from reflector.app import app
|
response = await client.post("/transcripts", json={"name": "test"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["participants"] == []
|
||||||
|
|
||||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
# create a participant
|
||||||
response = await ac.post("/transcripts", json={"name": "test"})
|
transcript_id = response.json()["id"]
|
||||||
assert response.status_code == 200
|
response = await client.post(
|
||||||
assert response.json()["participants"] == []
|
f"/transcripts/{transcript_id}/participants", json={"name": "test"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["id"] is not None
|
||||||
|
assert response.json()["speaker"] is None
|
||||||
|
assert response.json()["name"] == "test"
|
||||||
|
|
||||||
# create a participant
|
# create another one with a speaker
|
||||||
transcript_id = response.json()["id"]
|
response = await client.post(
|
||||||
response = await ac.post(
|
f"/transcripts/{transcript_id}/participants",
|
||||||
f"/transcripts/{transcript_id}/participants", json={"name": "test"}
|
json={"name": "test2", "speaker": 1},
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["id"] is not None
|
assert response.json()["id"] is not None
|
||||||
assert response.json()["speaker"] is None
|
assert response.json()["speaker"] == 1
|
||||||
assert response.json()["name"] == "test"
|
assert response.json()["name"] == "test2"
|
||||||
|
|
||||||
# create another one with a speaker
|
# get all participants via transcript
|
||||||
response = await ac.post(
|
response = await client.get(f"/transcripts/{transcript_id}")
|
||||||
f"/transcripts/{transcript_id}/participants",
|
assert response.status_code == 200
|
||||||
json={"name": "test2", "speaker": 1},
|
assert len(response.json()["participants"]) == 2
|
||||||
)
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["id"] is not None
|
|
||||||
assert response.json()["speaker"] == 1
|
|
||||||
assert response.json()["name"] == "test2"
|
|
||||||
|
|
||||||
# get all participants via transcript
|
# get participants via participants endpoint
|
||||||
response = await ac.get(f"/transcripts/{transcript_id}")
|
response = await client.get(f"/transcripts/{transcript_id}/participants")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert len(response.json()["participants"]) == 2
|
assert len(response.json()) == 2
|
||||||
|
|
||||||
# get participants via participants endpoint
|
|
||||||
response = await ac.get(f"/transcripts/{transcript_id}/participants")
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert len(response.json()) == 2
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_participants_same_speaker():
|
async def test_transcript_participants_same_speaker(client):
|
||||||
from reflector.app import app
|
response = await client.post("/transcripts", json={"name": "test"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["participants"] == []
|
||||||
|
transcript_id = response.json()["id"]
|
||||||
|
|
||||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
# create a participant
|
||||||
response = await ac.post("/transcripts", json={"name": "test"})
|
response = await client.post(
|
||||||
assert response.status_code == 200
|
f"/transcripts/{transcript_id}/participants",
|
||||||
assert response.json()["participants"] == []
|
json={"name": "test", "speaker": 1},
|
||||||
transcript_id = response.json()["id"]
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["speaker"] == 1
|
||||||
|
|
||||||
# create a participant
|
# create another one with the same speaker
|
||||||
response = await ac.post(
|
response = await client.post(
|
||||||
f"/transcripts/{transcript_id}/participants",
|
f"/transcripts/{transcript_id}/participants",
|
||||||
json={"name": "test", "speaker": 1},
|
json={"name": "test2", "speaker": 1},
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 400
|
||||||
assert response.json()["speaker"] == 1
|
|
||||||
|
|
||||||
# create another one with the same speaker
|
|
||||||
response = await ac.post(
|
|
||||||
f"/transcripts/{transcript_id}/participants",
|
|
||||||
json={"name": "test2", "speaker": 1},
|
|
||||||
)
|
|
||||||
assert response.status_code == 400
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_participants_update_name():
|
async def test_transcript_participants_update_name(client):
|
||||||
from reflector.app import app
|
response = await client.post("/transcripts", json={"name": "test"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["participants"] == []
|
||||||
|
transcript_id = response.json()["id"]
|
||||||
|
|
||||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
# create a participant
|
||||||
response = await ac.post("/transcripts", json={"name": "test"})
|
response = await client.post(
|
||||||
assert response.status_code == 200
|
f"/transcripts/{transcript_id}/participants",
|
||||||
assert response.json()["participants"] == []
|
json={"name": "test", "speaker": 1},
|
||||||
transcript_id = response.json()["id"]
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["speaker"] == 1
|
||||||
|
|
||||||
# create a participant
|
# update the participant
|
||||||
response = await ac.post(
|
participant_id = response.json()["id"]
|
||||||
f"/transcripts/{transcript_id}/participants",
|
response = await client.patch(
|
||||||
json={"name": "test", "speaker": 1},
|
f"/transcripts/{transcript_id}/participants/{participant_id}",
|
||||||
)
|
json={"name": "test2"},
|
||||||
assert response.status_code == 200
|
)
|
||||||
assert response.json()["speaker"] == 1
|
assert response.status_code == 200
|
||||||
|
assert response.json()["name"] == "test2"
|
||||||
|
|
||||||
# update the participant
|
# verify the participant was updated
|
||||||
participant_id = response.json()["id"]
|
response = await client.get(
|
||||||
response = await ac.patch(
|
f"/transcripts/{transcript_id}/participants/{participant_id}"
|
||||||
f"/transcripts/{transcript_id}/participants/{participant_id}",
|
)
|
||||||
json={"name": "test2"},
|
assert response.status_code == 200
|
||||||
)
|
assert response.json()["name"] == "test2"
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["name"] == "test2"
|
|
||||||
|
|
||||||
# verify the participant was updated
|
# verify the participant was updated in transcript
|
||||||
response = await ac.get(
|
response = await client.get(f"/transcripts/{transcript_id}")
|
||||||
f"/transcripts/{transcript_id}/participants/{participant_id}"
|
assert response.status_code == 200
|
||||||
)
|
assert len(response.json()["participants"]) == 1
|
||||||
assert response.status_code == 200
|
assert response.json()["participants"][0]["name"] == "test2"
|
||||||
assert response.json()["name"] == "test2"
|
|
||||||
|
|
||||||
# verify the participant was updated in transcript
|
|
||||||
response = await ac.get(f"/transcripts/{transcript_id}")
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert len(response.json()["participants"]) == 1
|
|
||||||
assert response.json()["participants"][0]["name"] == "test2"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_participants_update_speaker():
|
async def test_transcript_participants_update_speaker(client):
|
||||||
from reflector.app import app
|
response = await client.post("/transcripts", json={"name": "test"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["participants"] == []
|
||||||
|
transcript_id = response.json()["id"]
|
||||||
|
|
||||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
# create a participant
|
||||||
response = await ac.post("/transcripts", json={"name": "test"})
|
response = await client.post(
|
||||||
assert response.status_code == 200
|
f"/transcripts/{transcript_id}/participants",
|
||||||
assert response.json()["participants"] == []
|
json={"name": "test", "speaker": 1},
|
||||||
transcript_id = response.json()["id"]
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
participant1_id = response.json()["id"]
|
||||||
|
|
||||||
# create a participant
|
# create another participant
|
||||||
response = await ac.post(
|
response = await client.post(
|
||||||
f"/transcripts/{transcript_id}/participants",
|
f"/transcripts/{transcript_id}/participants",
|
||||||
json={"name": "test", "speaker": 1},
|
json={"name": "test2", "speaker": 2},
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
participant1_id = response.json()["id"]
|
participant2_id = response.json()["id"]
|
||||||
|
|
||||||
# create another participant
|
# update the participant, refused as speaker is already taken
|
||||||
response = await ac.post(
|
response = await client.patch(
|
||||||
f"/transcripts/{transcript_id}/participants",
|
f"/transcripts/{transcript_id}/participants/{participant2_id}",
|
||||||
json={"name": "test2", "speaker": 2},
|
json={"speaker": 1},
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 400
|
||||||
participant2_id = response.json()["id"]
|
|
||||||
|
|
||||||
# update the participant, refused as speaker is already taken
|
# delete the participant 1
|
||||||
response = await ac.patch(
|
response = await client.delete(
|
||||||
f"/transcripts/{transcript_id}/participants/{participant2_id}",
|
f"/transcripts/{transcript_id}/participants/{participant1_id}"
|
||||||
json={"speaker": 1},
|
)
|
||||||
)
|
assert response.status_code == 200
|
||||||
assert response.status_code == 400
|
|
||||||
|
|
||||||
# delete the participant 1
|
# update the participant 2 again, should be accepted now
|
||||||
response = await ac.delete(
|
response = await client.patch(
|
||||||
f"/transcripts/{transcript_id}/participants/{participant1_id}"
|
f"/transcripts/{transcript_id}/participants/{participant2_id}",
|
||||||
)
|
json={"speaker": 1},
|
||||||
assert response.status_code == 200
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
# update the participant 2 again, should be accepted now
|
# ensure participant2 name is still there
|
||||||
response = await ac.patch(
|
response = await client.get(
|
||||||
f"/transcripts/{transcript_id}/participants/{participant2_id}",
|
f"/transcripts/{transcript_id}/participants/{participant2_id}"
|
||||||
json={"speaker": 1},
|
)
|
||||||
)
|
assert response.status_code == 200
|
||||||
assert response.status_code == 200
|
assert response.json()["name"] == "test2"
|
||||||
|
assert response.json()["speaker"] == 1
|
||||||
# ensure participant2 name is still there
|
|
||||||
response = await ac.get(
|
|
||||||
f"/transcripts/{transcript_id}/participants/{participant2_id}"
|
|
||||||
)
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["name"] == "test2"
|
|
||||||
assert response.json()["speaker"] == 1
|
|
||||||
|
|||||||
@@ -2,7 +2,25 @@ import asyncio
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from httpx import AsyncClient
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def app_lifespan():
|
||||||
|
from asgi_lifespan import LifespanManager
|
||||||
|
|
||||||
|
from reflector.app import app
|
||||||
|
|
||||||
|
async with LifespanManager(app) as manager:
|
||||||
|
yield manager.app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def client(app_lifespan):
|
||||||
|
yield AsyncClient(
|
||||||
|
transport=ASGITransport(app=app_lifespan),
|
||||||
|
base_url="http://test/v1",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("setup_database")
|
@pytest.mark.usefixtures("setup_database")
|
||||||
@@ -11,23 +29,21 @@ from httpx import AsyncClient
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_process(
|
async def test_transcript_process(
|
||||||
tmpdir,
|
tmpdir,
|
||||||
|
whisper_transcript,
|
||||||
dummy_llm,
|
dummy_llm,
|
||||||
dummy_processors,
|
dummy_processors,
|
||||||
dummy_diarization,
|
dummy_diarization,
|
||||||
dummy_storage,
|
dummy_storage,
|
||||||
|
client,
|
||||||
):
|
):
|
||||||
from reflector.app import app
|
|
||||||
|
|
||||||
ac = AsyncClient(app=app, base_url="http://test/v1")
|
|
||||||
|
|
||||||
# create a transcript
|
# create a transcript
|
||||||
response = await ac.post("/transcripts", json={"name": "test"})
|
response = await client.post("/transcripts", json={"name": "test"})
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["status"] == "idle"
|
assert response.json()["status"] == "idle"
|
||||||
tid = response.json()["id"]
|
tid = response.json()["id"]
|
||||||
|
|
||||||
# upload mp3
|
# upload mp3
|
||||||
response = await ac.post(
|
response = await client.post(
|
||||||
f"/transcripts/{tid}/record/upload?chunk_number=0&total_chunks=1",
|
f"/transcripts/{tid}/record/upload?chunk_number=0&total_chunks=1",
|
||||||
files={
|
files={
|
||||||
"chunk": (
|
"chunk": (
|
||||||
@@ -45,7 +61,7 @@ async def test_transcript_process(
|
|||||||
start_time = time.monotonic()
|
start_time = time.monotonic()
|
||||||
while (time.monotonic() - start_time) < timeout_seconds:
|
while (time.monotonic() - start_time) < timeout_seconds:
|
||||||
# fetch the transcript and check if it is ended
|
# fetch the transcript and check if it is ended
|
||||||
resp = await ac.get(f"/transcripts/{tid}")
|
resp = await client.get(f"/transcripts/{tid}")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
if resp.json()["status"] in ("ended", "error"):
|
if resp.json()["status"] in ("ended", "error"):
|
||||||
break
|
break
|
||||||
@@ -54,7 +70,7 @@ async def test_transcript_process(
|
|||||||
pytest.fail(f"Initial processing timed out after {timeout_seconds} seconds")
|
pytest.fail(f"Initial processing timed out after {timeout_seconds} seconds")
|
||||||
|
|
||||||
# restart the processing
|
# restart the processing
|
||||||
response = await ac.post(
|
response = await client.post(
|
||||||
f"/transcripts/{tid}/process",
|
f"/transcripts/{tid}/process",
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -65,7 +81,7 @@ async def test_transcript_process(
|
|||||||
start_time = time.monotonic()
|
start_time = time.monotonic()
|
||||||
while (time.monotonic() - start_time) < timeout_seconds:
|
while (time.monotonic() - start_time) < timeout_seconds:
|
||||||
# fetch the transcript and check if it is ended
|
# fetch the transcript and check if it is ended
|
||||||
resp = await ac.get(f"/transcripts/{tid}")
|
resp = await client.get(f"/transcripts/{tid}")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
if resp.json()["status"] in ("ended", "error"):
|
if resp.json()["status"] in ("ended", "error"):
|
||||||
break
|
break
|
||||||
@@ -80,7 +96,7 @@ async def test_transcript_process(
|
|||||||
assert transcript["title"] == "Llm Title"
|
assert transcript["title"] == "Llm Title"
|
||||||
|
|
||||||
# check topics and transcript
|
# check topics and transcript
|
||||||
response = await ac.get(f"/transcripts/{tid}/topics")
|
response = await client.get(f"/transcripts/{tid}/topics")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert len(response.json()) == 1
|
assert len(response.json()) == 1
|
||||||
assert "want to share" in response.json()[0]["transcript"]
|
assert "want to share" in response.json()[0]["transcript"]
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import time
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from httpx import AsyncClient
|
|
||||||
from httpx_ws import aconnect_ws
|
from httpx_ws import aconnect_ws
|
||||||
from uvicorn import Config, Server
|
from uvicorn import Config, Server
|
||||||
|
|
||||||
@@ -50,23 +49,69 @@ class ThreadedUvicorn:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def appserver(tmpdir, setup_database, celery_session_app, celery_session_worker):
|
def appserver(tmpdir, setup_database, celery_session_app, celery_session_worker):
|
||||||
|
import threading
|
||||||
|
|
||||||
from reflector.app import app
|
from reflector.app import app
|
||||||
|
from reflector.db import get_database
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
|
|
||||||
DATA_DIR = settings.DATA_DIR
|
DATA_DIR = settings.DATA_DIR
|
||||||
settings.DATA_DIR = Path(tmpdir)
|
settings.DATA_DIR = Path(tmpdir)
|
||||||
|
|
||||||
# start server
|
# start server in a separate thread with its own event loop
|
||||||
host = "127.0.0.1"
|
host = "127.0.0.1"
|
||||||
port = 1255
|
port = 1255
|
||||||
config = Config(app=app, host=host, port=port)
|
server_started = threading.Event()
|
||||||
server = ThreadedUvicorn(config)
|
server_exception = None
|
||||||
await server.start()
|
server_instance = None
|
||||||
|
|
||||||
yield (server, host, port)
|
def run_server():
|
||||||
|
nonlocal server_exception, server_instance
|
||||||
|
try:
|
||||||
|
# Create a new event loop for this thread
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
config = Config(app=app, host=host, port=port, loop=loop)
|
||||||
|
server_instance = Server(config)
|
||||||
|
|
||||||
|
async def start_server():
|
||||||
|
# Initialize database connection in this event loop
|
||||||
|
database = get_database()
|
||||||
|
await database.connect()
|
||||||
|
try:
|
||||||
|
await server_instance.serve()
|
||||||
|
finally:
|
||||||
|
await database.disconnect()
|
||||||
|
|
||||||
|
# Signal that server is starting
|
||||||
|
server_started.set()
|
||||||
|
loop.run_until_complete(start_server())
|
||||||
|
except Exception as e:
|
||||||
|
server_exception = e
|
||||||
|
server_started.set()
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
server_thread = threading.Thread(target=run_server, daemon=True)
|
||||||
|
server_thread.start()
|
||||||
|
|
||||||
|
# Wait for server to start
|
||||||
|
server_started.wait(timeout=30)
|
||||||
|
if server_exception:
|
||||||
|
raise server_exception
|
||||||
|
|
||||||
|
# Wait a bit more for the server to be fully ready
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
yield server_instance, host, port
|
||||||
|
|
||||||
|
# Stop server
|
||||||
|
if server_instance:
|
||||||
|
server_instance.should_exit = True
|
||||||
|
server_thread.join(timeout=30)
|
||||||
|
|
||||||
server.stop()
|
|
||||||
settings.DATA_DIR = DATA_DIR
|
settings.DATA_DIR = DATA_DIR
|
||||||
|
|
||||||
|
|
||||||
@@ -89,6 +134,7 @@ async def test_transcript_rtc_and_websocket(
|
|||||||
dummy_storage,
|
dummy_storage,
|
||||||
fake_mp3_upload,
|
fake_mp3_upload,
|
||||||
appserver,
|
appserver,
|
||||||
|
client,
|
||||||
):
|
):
|
||||||
# goal: start the server, exchange RTC, receive websocket events
|
# goal: start the server, exchange RTC, receive websocket events
|
||||||
# because of that, we need to start the server in a thread
|
# because of that, we need to start the server in a thread
|
||||||
@@ -97,8 +143,7 @@ async def test_transcript_rtc_and_websocket(
|
|||||||
|
|
||||||
# create a transcript
|
# create a transcript
|
||||||
base_url = f"http://{host}:{port}/v1"
|
base_url = f"http://{host}:{port}/v1"
|
||||||
ac = AsyncClient(base_url=base_url)
|
response = await client.post("/transcripts", json={"name": "Test RTC"})
|
||||||
response = await ac.post("/transcripts", json={"name": "Test RTC"})
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
tid = response.json()["id"]
|
tid = response.json()["id"]
|
||||||
|
|
||||||
@@ -143,11 +188,11 @@ async def test_transcript_rtc_and_websocket(
|
|||||||
|
|
||||||
url = f"{base_url}/transcripts/{tid}/record/webrtc"
|
url = f"{base_url}/transcripts/{tid}/record/webrtc"
|
||||||
path = Path(__file__).parent / "records" / "test_short.wav"
|
path = Path(__file__).parent / "records" / "test_short.wav"
|
||||||
client = StreamClient(signaling, url=url, play_from=path.as_posix())
|
stream_client = StreamClient(signaling, url=url, play_from=path.as_posix())
|
||||||
await client.start()
|
await stream_client.start()
|
||||||
|
|
||||||
timeout = 20
|
timeout = 120
|
||||||
while not client.is_ended():
|
while not stream_client.is_ended():
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
timeout -= 1
|
timeout -= 1
|
||||||
if timeout < 0:
|
if timeout < 0:
|
||||||
@@ -155,14 +200,14 @@ async def test_transcript_rtc_and_websocket(
|
|||||||
|
|
||||||
# XXX aiortc is long to close the connection
|
# XXX aiortc is long to close the connection
|
||||||
# instead of waiting a long time, we just send a STOP
|
# instead of waiting a long time, we just send a STOP
|
||||||
client.channel.send(json.dumps({"cmd": "STOP"}))
|
stream_client.channel.send(json.dumps({"cmd": "STOP"}))
|
||||||
await client.stop()
|
await stream_client.stop()
|
||||||
|
|
||||||
# wait the processing to finish
|
# wait the processing to finish
|
||||||
timeout = 20
|
timeout = 120
|
||||||
while True:
|
while True:
|
||||||
# fetch the transcript and check if it is ended
|
# fetch the transcript and check if it is ended
|
||||||
resp = await ac.get(f"/transcripts/{tid}")
|
resp = await client.get(f"/transcripts/{tid}")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
if resp.json()["status"] in ("ended", "error"):
|
if resp.json()["status"] in ("ended", "error"):
|
||||||
break
|
break
|
||||||
@@ -215,7 +260,7 @@ async def test_transcript_rtc_and_websocket(
|
|||||||
ev = events[eventnames.index("WAVEFORM")]
|
ev = events[eventnames.index("WAVEFORM")]
|
||||||
assert isinstance(ev["data"]["waveform"], list)
|
assert isinstance(ev["data"]["waveform"], list)
|
||||||
assert len(ev["data"]["waveform"]) >= 250
|
assert len(ev["data"]["waveform"]) >= 250
|
||||||
waveform_resp = await ac.get(f"/transcripts/{tid}/audio/waveform")
|
waveform_resp = await client.get(f"/transcripts/{tid}/audio/waveform")
|
||||||
assert waveform_resp.status_code == 200
|
assert waveform_resp.status_code == 200
|
||||||
assert waveform_resp.headers["content-type"] == "application/json"
|
assert waveform_resp.headers["content-type"] == "application/json"
|
||||||
assert isinstance(waveform_resp.json()["data"], list)
|
assert isinstance(waveform_resp.json()["data"], list)
|
||||||
@@ -235,7 +280,7 @@ async def test_transcript_rtc_and_websocket(
|
|||||||
assert "DURATION" in eventnames
|
assert "DURATION" in eventnames
|
||||||
|
|
||||||
# check that audio/mp3 is available
|
# check that audio/mp3 is available
|
||||||
audio_resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
|
audio_resp = await client.get(f"/transcripts/{tid}/audio/mp3")
|
||||||
assert audio_resp.status_code == 200
|
assert audio_resp.status_code == 200
|
||||||
assert audio_resp.headers["Content-Type"] == "audio/mpeg"
|
assert audio_resp.headers["Content-Type"] == "audio/mpeg"
|
||||||
|
|
||||||
@@ -254,6 +299,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
|
|||||||
dummy_storage,
|
dummy_storage,
|
||||||
fake_mp3_upload,
|
fake_mp3_upload,
|
||||||
appserver,
|
appserver,
|
||||||
|
client,
|
||||||
):
|
):
|
||||||
# goal: start the server, exchange RTC, receive websocket events
|
# goal: start the server, exchange RTC, receive websocket events
|
||||||
# because of that, we need to start the server in a thread
|
# because of that, we need to start the server in a thread
|
||||||
@@ -263,8 +309,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
|
|||||||
|
|
||||||
# create a transcript
|
# create a transcript
|
||||||
base_url = f"http://{host}:{port}/v1"
|
base_url = f"http://{host}:{port}/v1"
|
||||||
ac = AsyncClient(base_url=base_url)
|
response = await client.post(
|
||||||
response = await ac.post(
|
|
||||||
"/transcripts", json={"name": "Test RTC", "target_language": "fr"}
|
"/transcripts", json={"name": "Test RTC", "target_language": "fr"}
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -311,11 +356,11 @@ async def test_transcript_rtc_and_websocket_and_fr(
|
|||||||
|
|
||||||
url = f"{base_url}/transcripts/{tid}/record/webrtc"
|
url = f"{base_url}/transcripts/{tid}/record/webrtc"
|
||||||
path = Path(__file__).parent / "records" / "test_short.wav"
|
path = Path(__file__).parent / "records" / "test_short.wav"
|
||||||
client = StreamClient(signaling, url=url, play_from=path.as_posix())
|
stream_client = StreamClient(signaling, url=url, play_from=path.as_posix())
|
||||||
await client.start()
|
await stream_client.start()
|
||||||
|
|
||||||
timeout = 20
|
timeout = 120
|
||||||
while not client.is_ended():
|
while not stream_client.is_ended():
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
timeout -= 1
|
timeout -= 1
|
||||||
if timeout < 0:
|
if timeout < 0:
|
||||||
@@ -323,18 +368,18 @@ async def test_transcript_rtc_and_websocket_and_fr(
|
|||||||
|
|
||||||
# XXX aiortc is long to close the connection
|
# XXX aiortc is long to close the connection
|
||||||
# instead of waiting a long time, we just send a STOP
|
# instead of waiting a long time, we just send a STOP
|
||||||
client.channel.send(json.dumps({"cmd": "STOP"}))
|
stream_client.channel.send(json.dumps({"cmd": "STOP"}))
|
||||||
|
|
||||||
# wait the processing to finish
|
# wait the processing to finish
|
||||||
await asyncio.sleep(2)
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
await client.stop()
|
await stream_client.stop()
|
||||||
|
|
||||||
# wait the processing to finish
|
# wait the processing to finish
|
||||||
timeout = 20
|
timeout = 120
|
||||||
while True:
|
while True:
|
||||||
# fetch the transcript and check if it is ended
|
# fetch the transcript and check if it is ended
|
||||||
resp = await ac.get(f"/transcripts/{tid}")
|
resp = await client.get(f"/transcripts/{tid}")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
if resp.json()["status"] == "ended":
|
if resp.json()["status"] == "ended":
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -1,401 +1,390 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from httpx import AsyncClient
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_reassign_speaker(fake_transcript_with_topics):
|
async def test_transcript_reassign_speaker(fake_transcript_with_topics, client):
|
||||||
from reflector.app import app
|
|
||||||
|
|
||||||
transcript_id = fake_transcript_with_topics.id
|
transcript_id = fake_transcript_with_topics.id
|
||||||
|
|
||||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
# check the transcript exists
|
||||||
# check the transcript exists
|
response = await client.get(f"/transcripts/{transcript_id}")
|
||||||
response = await ac.get(f"/transcripts/{transcript_id}")
|
assert response.status_code == 200
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
# check initial topics of the transcript
|
# check initial topics of the transcript
|
||||||
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
|
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
topics = response.json()
|
topics = response.json()
|
||||||
assert len(topics) == 2
|
assert len(topics) == 2
|
||||||
|
|
||||||
# check through words
|
# check through words
|
||||||
assert topics[0]["words"][0]["speaker"] == 0
|
assert topics[0]["words"][0]["speaker"] == 0
|
||||||
assert topics[0]["words"][1]["speaker"] == 0
|
assert topics[0]["words"][1]["speaker"] == 0
|
||||||
assert topics[1]["words"][0]["speaker"] == 0
|
assert topics[1]["words"][0]["speaker"] == 0
|
||||||
assert topics[1]["words"][1]["speaker"] == 0
|
assert topics[1]["words"][1]["speaker"] == 0
|
||||||
# check through segments
|
# check through segments
|
||||||
assert len(topics[0]["segments"]) == 1
|
assert len(topics[0]["segments"]) == 1
|
||||||
assert topics[0]["segments"][0]["speaker"] == 0
|
assert topics[0]["segments"][0]["speaker"] == 0
|
||||||
assert len(topics[1]["segments"]) == 1
|
assert len(topics[1]["segments"]) == 1
|
||||||
assert topics[1]["segments"][0]["speaker"] == 0
|
assert topics[1]["segments"][0]["speaker"] == 0
|
||||||
|
|
||||||
# reassign speaker
|
# reassign speaker
|
||||||
response = await ac.patch(
|
response = await client.patch(
|
||||||
f"/transcripts/{transcript_id}/speaker/assign",
|
f"/transcripts/{transcript_id}/speaker/assign",
|
||||||
json={
|
json={
|
||||||
"speaker": 1,
|
"speaker": 1,
|
||||||
"timestamp_from": 0,
|
"timestamp_from": 0,
|
||||||
"timestamp_to": 1,
|
"timestamp_to": 1,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
# check topics again
|
# check topics again
|
||||||
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
|
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
topics = response.json()
|
topics = response.json()
|
||||||
assert len(topics) == 2
|
assert len(topics) == 2
|
||||||
|
|
||||||
# check through words
|
# check through words
|
||||||
assert topics[0]["words"][0]["speaker"] == 1
|
assert topics[0]["words"][0]["speaker"] == 1
|
||||||
assert topics[0]["words"][1]["speaker"] == 1
|
assert topics[0]["words"][1]["speaker"] == 1
|
||||||
assert topics[1]["words"][0]["speaker"] == 0
|
assert topics[1]["words"][0]["speaker"] == 0
|
||||||
assert topics[1]["words"][1]["speaker"] == 0
|
assert topics[1]["words"][1]["speaker"] == 0
|
||||||
# check segments
|
# check segments
|
||||||
assert len(topics[0]["segments"]) == 1
|
assert len(topics[0]["segments"]) == 1
|
||||||
assert topics[0]["segments"][0]["speaker"] == 1
|
assert topics[0]["segments"][0]["speaker"] == 1
|
||||||
assert len(topics[1]["segments"]) == 1
|
assert len(topics[1]["segments"]) == 1
|
||||||
assert topics[1]["segments"][0]["speaker"] == 0
|
assert topics[1]["segments"][0]["speaker"] == 0
|
||||||
|
|
||||||
# reassign speaker, middle of 2 topics
|
# reassign speaker, middle of 2 topics
|
||||||
response = await ac.patch(
|
response = await client.patch(
|
||||||
f"/transcripts/{transcript_id}/speaker/assign",
|
f"/transcripts/{transcript_id}/speaker/assign",
|
||||||
json={
|
json={
|
||||||
"speaker": 2,
|
"speaker": 2,
|
||||||
"timestamp_from": 1,
|
"timestamp_from": 1,
|
||||||
"timestamp_to": 2.5,
|
"timestamp_to": 2.5,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
# check topics again
|
# check topics again
|
||||||
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
|
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
topics = response.json()
|
topics = response.json()
|
||||||
assert len(topics) == 2
|
assert len(topics) == 2
|
||||||
|
|
||||||
# check through words
|
# check through words
|
||||||
assert topics[0]["words"][0]["speaker"] == 1
|
assert topics[0]["words"][0]["speaker"] == 1
|
||||||
assert topics[0]["words"][1]["speaker"] == 2
|
assert topics[0]["words"][1]["speaker"] == 2
|
||||||
assert topics[1]["words"][0]["speaker"] == 2
|
assert topics[1]["words"][0]["speaker"] == 2
|
||||||
assert topics[1]["words"][1]["speaker"] == 0
|
assert topics[1]["words"][1]["speaker"] == 0
|
||||||
# check segments
|
# check segments
|
||||||
assert len(topics[0]["segments"]) == 2
|
assert len(topics[0]["segments"]) == 2
|
||||||
assert topics[0]["segments"][0]["speaker"] == 1
|
assert topics[0]["segments"][0]["speaker"] == 1
|
||||||
assert topics[0]["segments"][1]["speaker"] == 2
|
assert topics[0]["segments"][1]["speaker"] == 2
|
||||||
assert len(topics[1]["segments"]) == 2
|
assert len(topics[1]["segments"]) == 2
|
||||||
assert topics[1]["segments"][0]["speaker"] == 2
|
assert topics[1]["segments"][0]["speaker"] == 2
|
||||||
assert topics[1]["segments"][1]["speaker"] == 0
|
assert topics[1]["segments"][1]["speaker"] == 0
|
||||||
|
|
||||||
# reassign speaker, everything
|
# reassign speaker, everything
|
||||||
response = await ac.patch(
|
response = await client.patch(
|
||||||
f"/transcripts/{transcript_id}/speaker/assign",
|
f"/transcripts/{transcript_id}/speaker/assign",
|
||||||
json={
|
json={
|
||||||
"speaker": 4,
|
"speaker": 4,
|
||||||
"timestamp_from": 0,
|
"timestamp_from": 0,
|
||||||
"timestamp_to": 100,
|
"timestamp_to": 100,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
# check topics again
|
# check topics again
|
||||||
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
|
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
topics = response.json()
|
topics = response.json()
|
||||||
assert len(topics) == 2
|
assert len(topics) == 2
|
||||||
|
|
||||||
# check through words
|
# check through words
|
||||||
assert topics[0]["words"][0]["speaker"] == 4
|
assert topics[0]["words"][0]["speaker"] == 4
|
||||||
assert topics[0]["words"][1]["speaker"] == 4
|
assert topics[0]["words"][1]["speaker"] == 4
|
||||||
assert topics[1]["words"][0]["speaker"] == 4
|
assert topics[1]["words"][0]["speaker"] == 4
|
||||||
assert topics[1]["words"][1]["speaker"] == 4
|
assert topics[1]["words"][1]["speaker"] == 4
|
||||||
# check segments
|
# check segments
|
||||||
assert len(topics[0]["segments"]) == 1
|
assert len(topics[0]["segments"]) == 1
|
||||||
assert topics[0]["segments"][0]["speaker"] == 4
|
assert topics[0]["segments"][0]["speaker"] == 4
|
||||||
assert len(topics[1]["segments"]) == 1
|
assert len(topics[1]["segments"]) == 1
|
||||||
assert topics[1]["segments"][0]["speaker"] == 4
|
assert topics[1]["segments"][0]["speaker"] == 4
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_merge_speaker(fake_transcript_with_topics):
|
async def test_transcript_merge_speaker(fake_transcript_with_topics, client):
|
||||||
from reflector.app import app
|
|
||||||
|
|
||||||
transcript_id = fake_transcript_with_topics.id
|
transcript_id = fake_transcript_with_topics.id
|
||||||
|
|
||||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
# check the transcript exists
|
||||||
# check the transcript exists
|
response = await client.get(f"/transcripts/{transcript_id}")
|
||||||
response = await ac.get(f"/transcripts/{transcript_id}")
|
assert response.status_code == 200
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
# check initial topics of the transcript
|
# check initial topics of the transcript
|
||||||
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
|
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
topics = response.json()
|
topics = response.json()
|
||||||
assert len(topics) == 2
|
assert len(topics) == 2
|
||||||
|
|
||||||
# check through words
|
# check through words
|
||||||
assert topics[0]["words"][0]["speaker"] == 0
|
assert topics[0]["words"][0]["speaker"] == 0
|
||||||
assert topics[0]["words"][1]["speaker"] == 0
|
assert topics[0]["words"][1]["speaker"] == 0
|
||||||
assert topics[1]["words"][0]["speaker"] == 0
|
assert topics[1]["words"][0]["speaker"] == 0
|
||||||
assert topics[1]["words"][1]["speaker"] == 0
|
assert topics[1]["words"][1]["speaker"] == 0
|
||||||
|
|
||||||
# reassign speaker
|
# reassign speaker
|
||||||
response = await ac.patch(
|
response = await client.patch(
|
||||||
f"/transcripts/{transcript_id}/speaker/assign",
|
f"/transcripts/{transcript_id}/speaker/assign",
|
||||||
json={
|
json={
|
||||||
"speaker": 1,
|
"speaker": 1,
|
||||||
"timestamp_from": 0,
|
"timestamp_from": 0,
|
||||||
"timestamp_to": 1,
|
"timestamp_to": 1,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
# check topics again
|
# check topics again
|
||||||
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
|
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
topics = response.json()
|
topics = response.json()
|
||||||
assert len(topics) == 2
|
assert len(topics) == 2
|
||||||
|
|
||||||
# check through words
|
# check through words
|
||||||
assert topics[0]["words"][0]["speaker"] == 1
|
assert topics[0]["words"][0]["speaker"] == 1
|
||||||
assert topics[0]["words"][1]["speaker"] == 1
|
assert topics[0]["words"][1]["speaker"] == 1
|
||||||
assert topics[1]["words"][0]["speaker"] == 0
|
assert topics[1]["words"][0]["speaker"] == 0
|
||||||
assert topics[1]["words"][1]["speaker"] == 0
|
assert topics[1]["words"][1]["speaker"] == 0
|
||||||
|
|
||||||
# merge speakers
|
# merge speakers
|
||||||
response = await ac.patch(
|
response = await client.patch(
|
||||||
f"/transcripts/{transcript_id}/speaker/merge",
|
f"/transcripts/{transcript_id}/speaker/merge",
|
||||||
json={
|
json={
|
||||||
"speaker_from": 1,
|
"speaker_from": 1,
|
||||||
"speaker_to": 0,
|
"speaker_to": 0,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
# check topics again
|
# check topics again
|
||||||
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
|
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
topics = response.json()
|
topics = response.json()
|
||||||
assert len(topics) == 2
|
assert len(topics) == 2
|
||||||
|
|
||||||
# check through words
|
# check through words
|
||||||
assert topics[0]["words"][0]["speaker"] == 0
|
assert topics[0]["words"][0]["speaker"] == 0
|
||||||
assert topics[0]["words"][1]["speaker"] == 0
|
assert topics[0]["words"][1]["speaker"] == 0
|
||||||
assert topics[1]["words"][0]["speaker"] == 0
|
assert topics[1]["words"][0]["speaker"] == 0
|
||||||
assert topics[1]["words"][1]["speaker"] == 0
|
assert topics[1]["words"][1]["speaker"] == 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_reassign_with_participant(fake_transcript_with_topics):
|
async def test_transcript_reassign_with_participant(
|
||||||
from reflector.app import app
|
fake_transcript_with_topics, client
|
||||||
|
):
|
||||||
transcript_id = fake_transcript_with_topics.id
|
transcript_id = fake_transcript_with_topics.id
|
||||||
|
|
||||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
# check the transcript exists
|
||||||
# check the transcript exists
|
response = await client.get(f"/transcripts/{transcript_id}")
|
||||||
response = await ac.get(f"/transcripts/{transcript_id}")
|
assert response.status_code == 200
|
||||||
assert response.status_code == 200
|
transcript = response.json()
|
||||||
transcript = response.json()
|
assert len(transcript["participants"]) == 0
|
||||||
assert len(transcript["participants"]) == 0
|
|
||||||
|
|
||||||
# create 2 participants
|
# create 2 participants
|
||||||
response = await ac.post(
|
response = await client.post(
|
||||||
f"/transcripts/{transcript_id}/participants",
|
f"/transcripts/{transcript_id}/participants",
|
||||||
json={
|
json={
|
||||||
"name": "Participant 1",
|
"name": "Participant 1",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
participant1_id = response.json()["id"]
|
participant1_id = response.json()["id"]
|
||||||
|
|
||||||
response = await ac.post(
|
response = await client.post(
|
||||||
f"/transcripts/{transcript_id}/participants",
|
f"/transcripts/{transcript_id}/participants",
|
||||||
json={
|
json={
|
||||||
"name": "Participant 2",
|
"name": "Participant 2",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
participant2_id = response.json()["id"]
|
participant2_id = response.json()["id"]
|
||||||
|
|
||||||
# check participants speakers
|
# check participants speakers
|
||||||
response = await ac.get(f"/transcripts/{transcript_id}/participants")
|
response = await client.get(f"/transcripts/{transcript_id}/participants")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
participants = response.json()
|
participants = response.json()
|
||||||
assert len(participants) == 2
|
assert len(participants) == 2
|
||||||
assert participants[0]["name"] == "Participant 1"
|
assert participants[0]["name"] == "Participant 1"
|
||||||
assert participants[0]["speaker"] is None
|
assert participants[0]["speaker"] is None
|
||||||
assert participants[1]["name"] == "Participant 2"
|
assert participants[1]["name"] == "Participant 2"
|
||||||
assert participants[1]["speaker"] is None
|
assert participants[1]["speaker"] is None
|
||||||
|
|
||||||
# check initial topics of the transcript
|
# check initial topics of the transcript
|
||||||
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
|
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
topics = response.json()
|
topics = response.json()
|
||||||
assert len(topics) == 2
|
assert len(topics) == 2
|
||||||
|
|
||||||
# check through words
|
# check through words
|
||||||
assert topics[0]["words"][0]["speaker"] == 0
|
assert topics[0]["words"][0]["speaker"] == 0
|
||||||
assert topics[0]["words"][1]["speaker"] == 0
|
assert topics[0]["words"][1]["speaker"] == 0
|
||||||
assert topics[1]["words"][0]["speaker"] == 0
|
assert topics[1]["words"][0]["speaker"] == 0
|
||||||
assert topics[1]["words"][1]["speaker"] == 0
|
assert topics[1]["words"][1]["speaker"] == 0
|
||||||
# check through segments
|
# check through segments
|
||||||
assert len(topics[0]["segments"]) == 1
|
assert len(topics[0]["segments"]) == 1
|
||||||
assert topics[0]["segments"][0]["speaker"] == 0
|
assert topics[0]["segments"][0]["speaker"] == 0
|
||||||
assert len(topics[1]["segments"]) == 1
|
assert len(topics[1]["segments"]) == 1
|
||||||
assert topics[1]["segments"][0]["speaker"] == 0
|
assert topics[1]["segments"][0]["speaker"] == 0
|
||||||
|
|
||||||
# reassign speaker from a participant
|
# reassign speaker from a participant
|
||||||
response = await ac.patch(
|
response = await client.patch(
|
||||||
f"/transcripts/{transcript_id}/speaker/assign",
|
f"/transcripts/{transcript_id}/speaker/assign",
|
||||||
json={
|
json={
|
||||||
"participant": participant1_id,
|
"participant": participant1_id,
|
||||||
"timestamp_from": 0,
|
"timestamp_from": 0,
|
||||||
"timestamp_to": 1,
|
"timestamp_to": 1,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
# check participants if speaker has been assigned
|
# check participants if speaker has been assigned
|
||||||
# first participant should have 1, because it's not used yet.
|
# first participant should have 1, because it's not used yet.
|
||||||
response = await ac.get(f"/transcripts/{transcript_id}/participants")
|
response = await client.get(f"/transcripts/{transcript_id}/participants")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
participants = response.json()
|
participants = response.json()
|
||||||
assert len(participants) == 2
|
assert len(participants) == 2
|
||||||
assert participants[0]["name"] == "Participant 1"
|
assert participants[0]["name"] == "Participant 1"
|
||||||
assert participants[0]["speaker"] == 1
|
assert participants[0]["speaker"] == 1
|
||||||
assert participants[1]["name"] == "Participant 2"
|
assert participants[1]["name"] == "Participant 2"
|
||||||
assert participants[1]["speaker"] is None
|
assert participants[1]["speaker"] is None
|
||||||
|
|
||||||
# check topics again
|
# check topics again
|
||||||
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
|
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
topics = response.json()
|
topics = response.json()
|
||||||
assert len(topics) == 2
|
assert len(topics) == 2
|
||||||
|
|
||||||
# check through words
|
# check through words
|
||||||
assert topics[0]["words"][0]["speaker"] == 1
|
assert topics[0]["words"][0]["speaker"] == 1
|
||||||
assert topics[0]["words"][1]["speaker"] == 1
|
assert topics[0]["words"][1]["speaker"] == 1
|
||||||
assert topics[1]["words"][0]["speaker"] == 0
|
assert topics[1]["words"][0]["speaker"] == 0
|
||||||
assert topics[1]["words"][1]["speaker"] == 0
|
assert topics[1]["words"][1]["speaker"] == 0
|
||||||
# check segments
|
# check segments
|
||||||
assert len(topics[0]["segments"]) == 1
|
assert len(topics[0]["segments"]) == 1
|
||||||
assert topics[0]["segments"][0]["speaker"] == 1
|
assert topics[0]["segments"][0]["speaker"] == 1
|
||||||
assert len(topics[1]["segments"]) == 1
|
assert len(topics[1]["segments"]) == 1
|
||||||
assert topics[1]["segments"][0]["speaker"] == 0
|
assert topics[1]["segments"][0]["speaker"] == 0
|
||||||
|
|
||||||
# reassign participant, middle of 2 topics
|
# reassign participant, middle of 2 topics
|
||||||
response = await ac.patch(
|
response = await client.patch(
|
||||||
f"/transcripts/{transcript_id}/speaker/assign",
|
f"/transcripts/{transcript_id}/speaker/assign",
|
||||||
json={
|
json={
|
||||||
"participant": participant2_id,
|
"participant": participant2_id,
|
||||||
"timestamp_from": 1,
|
"timestamp_from": 1,
|
||||||
"timestamp_to": 2.5,
|
"timestamp_to": 2.5,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
# check participants if speaker has been assigned
|
# check participants if speaker has been assigned
|
||||||
# first participant should have 1, because it's not used yet.
|
# first participant should have 1, because it's not used yet.
|
||||||
response = await ac.get(f"/transcripts/{transcript_id}/participants")
|
response = await client.get(f"/transcripts/{transcript_id}/participants")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
participants = response.json()
|
participants = response.json()
|
||||||
assert len(participants) == 2
|
assert len(participants) == 2
|
||||||
assert participants[0]["name"] == "Participant 1"
|
assert participants[0]["name"] == "Participant 1"
|
||||||
assert participants[0]["speaker"] == 1
|
assert participants[0]["speaker"] == 1
|
||||||
assert participants[1]["name"] == "Participant 2"
|
assert participants[1]["name"] == "Participant 2"
|
||||||
assert participants[1]["speaker"] == 2
|
assert participants[1]["speaker"] == 2
|
||||||
|
|
||||||
# check topics again
|
# check topics again
|
||||||
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
|
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
topics = response.json()
|
topics = response.json()
|
||||||
assert len(topics) == 2
|
assert len(topics) == 2
|
||||||
|
|
||||||
# check through words
|
# check through words
|
||||||
assert topics[0]["words"][0]["speaker"] == 1
|
assert topics[0]["words"][0]["speaker"] == 1
|
||||||
assert topics[0]["words"][1]["speaker"] == 2
|
assert topics[0]["words"][1]["speaker"] == 2
|
||||||
assert topics[1]["words"][0]["speaker"] == 2
|
assert topics[1]["words"][0]["speaker"] == 2
|
||||||
assert topics[1]["words"][1]["speaker"] == 0
|
assert topics[1]["words"][1]["speaker"] == 0
|
||||||
# check segments
|
# check segments
|
||||||
assert len(topics[0]["segments"]) == 2
|
assert len(topics[0]["segments"]) == 2
|
||||||
assert topics[0]["segments"][0]["speaker"] == 1
|
assert topics[0]["segments"][0]["speaker"] == 1
|
||||||
assert topics[0]["segments"][1]["speaker"] == 2
|
assert topics[0]["segments"][1]["speaker"] == 2
|
||||||
assert len(topics[1]["segments"]) == 2
|
assert len(topics[1]["segments"]) == 2
|
||||||
assert topics[1]["segments"][0]["speaker"] == 2
|
assert topics[1]["segments"][0]["speaker"] == 2
|
||||||
assert topics[1]["segments"][1]["speaker"] == 0
|
assert topics[1]["segments"][1]["speaker"] == 0
|
||||||
|
|
||||||
# reassign speaker, everything
|
# reassign speaker, everything
|
||||||
response = await ac.patch(
|
response = await client.patch(
|
||||||
f"/transcripts/{transcript_id}/speaker/assign",
|
f"/transcripts/{transcript_id}/speaker/assign",
|
||||||
json={
|
json={
|
||||||
"participant": participant1_id,
|
"participant": participant1_id,
|
||||||
"timestamp_from": 0,
|
"timestamp_from": 0,
|
||||||
"timestamp_to": 100,
|
"timestamp_to": 100,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
# check topics again
|
# check topics again
|
||||||
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
|
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
topics = response.json()
|
topics = response.json()
|
||||||
assert len(topics) == 2
|
assert len(topics) == 2
|
||||||
|
|
||||||
# check through words
|
# check through words
|
||||||
assert topics[0]["words"][0]["speaker"] == 1
|
assert topics[0]["words"][0]["speaker"] == 1
|
||||||
assert topics[0]["words"][1]["speaker"] == 1
|
assert topics[0]["words"][1]["speaker"] == 1
|
||||||
assert topics[1]["words"][0]["speaker"] == 1
|
assert topics[1]["words"][0]["speaker"] == 1
|
||||||
assert topics[1]["words"][1]["speaker"] == 1
|
assert topics[1]["words"][1]["speaker"] == 1
|
||||||
# check segments
|
# check segments
|
||||||
assert len(topics[0]["segments"]) == 1
|
assert len(topics[0]["segments"]) == 1
|
||||||
assert topics[0]["segments"][0]["speaker"] == 1
|
assert topics[0]["segments"][0]["speaker"] == 1
|
||||||
assert len(topics[1]["segments"]) == 1
|
assert len(topics[1]["segments"]) == 1
|
||||||
assert topics[1]["segments"][0]["speaker"] == 1
|
assert topics[1]["segments"][0]["speaker"] == 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_reassign_edge_cases(fake_transcript_with_topics):
|
async def test_transcript_reassign_edge_cases(fake_transcript_with_topics, client):
|
||||||
from reflector.app import app
|
|
||||||
|
|
||||||
transcript_id = fake_transcript_with_topics.id
|
transcript_id = fake_transcript_with_topics.id
|
||||||
|
|
||||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
# check the transcript exists
|
||||||
# check the transcript exists
|
response = await client.get(f"/transcripts/{transcript_id}")
|
||||||
response = await ac.get(f"/transcripts/{transcript_id}")
|
assert response.status_code == 200
|
||||||
assert response.status_code == 200
|
transcript = response.json()
|
||||||
transcript = response.json()
|
assert len(transcript["participants"]) == 0
|
||||||
assert len(transcript["participants"]) == 0
|
|
||||||
|
|
||||||
# try reassign without any participant_id or speaker
|
# try reassign without any participant_id or speaker
|
||||||
response = await ac.patch(
|
response = await client.patch(
|
||||||
f"/transcripts/{transcript_id}/speaker/assign",
|
f"/transcripts/{transcript_id}/speaker/assign",
|
||||||
json={
|
json={
|
||||||
"timestamp_from": 0,
|
"timestamp_from": 0,
|
||||||
"timestamp_to": 1,
|
"timestamp_to": 1,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
|
|
||||||
# try reassing with both participant_id and speaker
|
# try reassing with both participant_id and speaker
|
||||||
response = await ac.patch(
|
response = await client.patch(
|
||||||
f"/transcripts/{transcript_id}/speaker/assign",
|
f"/transcripts/{transcript_id}/speaker/assign",
|
||||||
json={
|
json={
|
||||||
"participant": "123",
|
"participant": "123",
|
||||||
"speaker": 1,
|
"speaker": 1,
|
||||||
"timestamp_from": 0,
|
"timestamp_from": 0,
|
||||||
"timestamp_to": 1,
|
"timestamp_to": 1,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
|
|
||||||
# try reassing with non-existing participant_id
|
# try reassing with non-existing participant_id
|
||||||
response = await ac.patch(
|
response = await client.patch(
|
||||||
f"/transcripts/{transcript_id}/speaker/assign",
|
f"/transcripts/{transcript_id}/speaker/assign",
|
||||||
json={
|
json={
|
||||||
"participant": "123",
|
"participant": "123",
|
||||||
"timestamp_from": 0,
|
"timestamp_from": 0,
|
||||||
"timestamp_to": 1,
|
"timestamp_to": 1,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert response.status_code == 404
|
assert response.status_code == 404
|
||||||
|
|||||||
@@ -1,26 +1,22 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from httpx import AsyncClient
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_topics(fake_transcript_with_topics):
|
async def test_transcript_topics(fake_transcript_with_topics, client):
|
||||||
from reflector.app import app
|
|
||||||
|
|
||||||
transcript_id = fake_transcript_with_topics.id
|
transcript_id = fake_transcript_with_topics.id
|
||||||
|
|
||||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
# check the transcript exists
|
||||||
# check the transcript exists
|
response = await client.get(f"/transcripts/{transcript_id}/topics")
|
||||||
response = await ac.get(f"/transcripts/{transcript_id}/topics")
|
assert response.status_code == 200
|
||||||
assert response.status_code == 200
|
assert len(response.json()) == 2
|
||||||
assert len(response.json()) == 2
|
topic_id = response.json()[0]["id"]
|
||||||
topic_id = response.json()[0]["id"]
|
|
||||||
|
|
||||||
# get words per speakers
|
# get words per speakers
|
||||||
response = await ac.get(
|
response = await client.get(
|
||||||
f"/transcripts/{transcript_id}/topics/{topic_id}/words-per-speaker"
|
f"/transcripts/{transcript_id}/topics/{topic_id}/words-per-speaker"
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert len(data["words_per_speaker"]) == 1
|
assert len(data["words_per_speaker"]) == 1
|
||||||
assert data["words_per_speaker"][0]["speaker"] == 0
|
assert data["words_per_speaker"][0]["speaker"] == 0
|
||||||
assert len(data["words_per_speaker"][0]["words"]) == 2
|
assert len(data["words_per_speaker"][0]["words"]) == 2
|
||||||
|
|||||||
@@ -1,63 +1,53 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from httpx import AsyncClient
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_create_default_translation():
|
async def test_transcript_create_default_translation(client):
|
||||||
from reflector.app import app
|
response = await client.post("/transcripts", json={"name": "test en"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["name"] == "test en"
|
||||||
|
assert response.json()["source_language"] == "en"
|
||||||
|
assert response.json()["target_language"] == "en"
|
||||||
|
tid = response.json()["id"]
|
||||||
|
|
||||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
response = await client.get(f"/transcripts/{tid}")
|
||||||
response = await ac.post("/transcripts", json={"name": "test en"})
|
assert response.status_code == 200
|
||||||
assert response.status_code == 200
|
assert response.json()["name"] == "test en"
|
||||||
assert response.json()["name"] == "test en"
|
assert response.json()["source_language"] == "en"
|
||||||
assert response.json()["source_language"] == "en"
|
assert response.json()["target_language"] == "en"
|
||||||
assert response.json()["target_language"] == "en"
|
|
||||||
tid = response.json()["id"]
|
|
||||||
|
|
||||||
response = await ac.get(f"/transcripts/{tid}")
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["name"] == "test en"
|
|
||||||
assert response.json()["source_language"] == "en"
|
|
||||||
assert response.json()["target_language"] == "en"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_create_en_fr_translation():
|
async def test_transcript_create_en_fr_translation(client):
|
||||||
from reflector.app import app
|
response = await client.post(
|
||||||
|
"/transcripts", json={"name": "test en/fr", "target_language": "fr"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["name"] == "test en/fr"
|
||||||
|
assert response.json()["source_language"] == "en"
|
||||||
|
assert response.json()["target_language"] == "fr"
|
||||||
|
tid = response.json()["id"]
|
||||||
|
|
||||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
response = await client.get(f"/transcripts/{tid}")
|
||||||
response = await ac.post(
|
assert response.status_code == 200
|
||||||
"/transcripts", json={"name": "test en/fr", "target_language": "fr"}
|
assert response.json()["name"] == "test en/fr"
|
||||||
)
|
assert response.json()["source_language"] == "en"
|
||||||
assert response.status_code == 200
|
assert response.json()["target_language"] == "fr"
|
||||||
assert response.json()["name"] == "test en/fr"
|
|
||||||
assert response.json()["source_language"] == "en"
|
|
||||||
assert response.json()["target_language"] == "fr"
|
|
||||||
tid = response.json()["id"]
|
|
||||||
|
|
||||||
response = await ac.get(f"/transcripts/{tid}")
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["name"] == "test en/fr"
|
|
||||||
assert response.json()["source_language"] == "en"
|
|
||||||
assert response.json()["target_language"] == "fr"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_create_fr_en_translation():
|
async def test_transcript_create_fr_en_translation(client):
|
||||||
from reflector.app import app
|
response = await client.post(
|
||||||
|
"/transcripts", json={"name": "test fr/en", "source_language": "fr"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["name"] == "test fr/en"
|
||||||
|
assert response.json()["source_language"] == "fr"
|
||||||
|
assert response.json()["target_language"] == "en"
|
||||||
|
tid = response.json()["id"]
|
||||||
|
|
||||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
response = await client.get(f"/transcripts/{tid}")
|
||||||
response = await ac.post(
|
assert response.status_code == 200
|
||||||
"/transcripts", json={"name": "test fr/en", "source_language": "fr"}
|
assert response.json()["name"] == "test fr/en"
|
||||||
)
|
assert response.json()["source_language"] == "fr"
|
||||||
assert response.status_code == 200
|
assert response.json()["target_language"] == "en"
|
||||||
assert response.json()["name"] == "test fr/en"
|
|
||||||
assert response.json()["source_language"] == "fr"
|
|
||||||
assert response.json()["target_language"] == "en"
|
|
||||||
tid = response.json()["id"]
|
|
||||||
|
|
||||||
response = await ac.get(f"/transcripts/{tid}")
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["name"] == "test fr/en"
|
|
||||||
assert response.json()["source_language"] == "fr"
|
|
||||||
assert response.json()["target_language"] == "en"
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ import asyncio
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from httpx import AsyncClient
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("setup_database")
|
@pytest.mark.usefixtures("setup_database")
|
||||||
@@ -15,19 +14,16 @@ async def test_transcript_upload_file(
|
|||||||
dummy_processors,
|
dummy_processors,
|
||||||
dummy_diarization,
|
dummy_diarization,
|
||||||
dummy_storage,
|
dummy_storage,
|
||||||
|
client,
|
||||||
):
|
):
|
||||||
from reflector.app import app
|
|
||||||
|
|
||||||
ac = AsyncClient(app=app, base_url="http://test/v1")
|
|
||||||
|
|
||||||
# create a transcript
|
# create a transcript
|
||||||
response = await ac.post("/transcripts", json={"name": "test"})
|
response = await client.post("/transcripts", json={"name": "test"})
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["status"] == "idle"
|
assert response.json()["status"] == "idle"
|
||||||
tid = response.json()["id"]
|
tid = response.json()["id"]
|
||||||
|
|
||||||
# upload mp3
|
# upload mp3
|
||||||
response = await ac.post(
|
response = await client.post(
|
||||||
f"/transcripts/{tid}/record/upload?chunk_number=0&total_chunks=1",
|
f"/transcripts/{tid}/record/upload?chunk_number=0&total_chunks=1",
|
||||||
files={
|
files={
|
||||||
"chunk": (
|
"chunk": (
|
||||||
@@ -45,7 +41,7 @@ async def test_transcript_upload_file(
|
|||||||
start_time = time.monotonic()
|
start_time = time.monotonic()
|
||||||
while (time.monotonic() - start_time) < timeout_seconds:
|
while (time.monotonic() - start_time) < timeout_seconds:
|
||||||
# fetch the transcript and check if it is ended
|
# fetch the transcript and check if it is ended
|
||||||
resp = await ac.get(f"/transcripts/{tid}")
|
resp = await client.get(f"/transcripts/{tid}")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
if resp.json()["status"] in ("ended", "error"):
|
if resp.json()["status"] in ("ended", "error"):
|
||||||
break
|
break
|
||||||
@@ -60,7 +56,7 @@ async def test_transcript_upload_file(
|
|||||||
assert transcript["title"] == "Llm Title"
|
assert transcript["title"] == "Llm Title"
|
||||||
|
|
||||||
# check topics and transcript
|
# check topics and transcript
|
||||||
response = await ac.get(f"/transcripts/{tid}/topics")
|
response = await client.get(f"/transcripts/{tid}/topics")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert len(response.json()) == 1
|
assert len(response.json()) == 1
|
||||||
assert "want to share" in response.json()[0]["transcript"]
|
assert "want to share" in response.json()[0]["transcript"]
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reflector.db import database
|
from reflector.db import get_database
|
||||||
from reflector.db.transcripts import (
|
from reflector.db.transcripts import (
|
||||||
SourceKind,
|
SourceKind,
|
||||||
TranscriptController,
|
TranscriptController,
|
||||||
@@ -26,7 +26,7 @@ class TestWebVTTAutoUpdate:
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await database.fetch_one(
|
result = await get_database().fetch_one(
|
||||||
transcripts.select().where(transcripts.c.id == transcript.id)
|
transcripts.select().where(transcripts.c.id == transcript.id)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -58,7 +58,7 @@ class TestWebVTTAutoUpdate:
|
|||||||
|
|
||||||
await controller.upsert_topic(transcript, topic)
|
await controller.upsert_topic(transcript, topic)
|
||||||
|
|
||||||
result = await database.fetch_one(
|
result = await get_database().fetch_one(
|
||||||
transcripts.select().where(transcripts.c.id == transcript.id)
|
transcripts.select().where(transcripts.c.id == transcript.id)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -99,7 +99,7 @@ class TestWebVTTAutoUpdate:
|
|||||||
await controller.update(transcript, {"topics": topics_data})
|
await controller.update(transcript, {"topics": topics_data})
|
||||||
|
|
||||||
# Fetch from DB
|
# Fetch from DB
|
||||||
result = await database.fetch_one(
|
result = await get_database().fetch_one(
|
||||||
transcripts.select().where(transcripts.c.id == transcript.id)
|
transcripts.select().where(transcripts.c.id == transcript.id)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -141,7 +141,7 @@ class TestWebVTTAutoUpdate:
|
|||||||
await controller.update(transcript, values)
|
await controller.update(transcript, values)
|
||||||
|
|
||||||
# Fetch from DB
|
# Fetch from DB
|
||||||
result = await database.fetch_one(
|
result = await get_database().fetch_one(
|
||||||
transcripts.select().where(transcripts.c.id == transcript.id)
|
transcripts.select().where(transcripts.c.id == transcript.id)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -216,7 +216,7 @@ class TestWebVTTAutoUpdate:
|
|||||||
await controller.update(transcript, values)
|
await controller.update(transcript, values)
|
||||||
|
|
||||||
# Fetch from DB
|
# Fetch from DB
|
||||||
result = await database.fetch_one(
|
result = await get_database().fetch_one(
|
||||||
transcripts.select().where(transcripts.c.id == transcript.id)
|
transcripts.select().where(transcripts.c.id == transcript.id)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
2739
server/uv.lock
generated
2739
server/uv.lock
generated
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user