feat: postgresql migration and removal of sqlite in pytest (#546)

* feat: remove support of sqlite, 100% postgres

* fix: more migration and make datetime timezone aware in postgres

* fix: change how database is get, and use contextvar to have difference instance between different loops

* test: properly use client fixture that handle lifetime/database connection

* fix: add missing client fixture parameters to test functions

This commit fixes NameError issues where test functions were trying to use
the 'client' fixture but didn't have it as a parameter. The changes include:

1. Added 'client' parameter to test functions in:
   - test_transcripts_audio_download.py (6 functions including fixture)
   - test_transcripts_speaker.py (3 functions)
   - test_transcripts_upload.py (1 function)
   - test_transcripts_rtc_ws.py (2 functions + appserver fixture)

2. Resolved naming conflicts in test_transcripts_rtc_ws.py where both HTTP
   client and StreamClient were using variable name 'client'. StreamClient
   instances are now named 'stream_client' to avoid conflicts.

3. Added missing 'from reflector.app import app' import in rtc_ws tests.

Background: Previously implemented contextvars solution with get_database()
function resolves asyncio event loop conflicts in Celery tasks. The global
client fixture was also created to replace manual AsyncClient instances,
ensuring proper FastAPI application lifecycle management and database
connections during tests.

All tests now pass except for 2 pre-existing RTC WebSocket test failures
related to asyncpg connection issues unrelated to these fixes.

* fix: ensure task are correctly closed

* fix: make separate event loop for the live server

* fix: make default settings pointing at postgres

* build: remove pytest-docker deps out of dev, just tests group
This commit is contained in:
2025-08-14 11:40:52 -06:00
committed by GitHub
parent 6fb5cb21c2
commit 9eab952c63
41 changed files with 2570 additions and 2287 deletions

View File

@@ -17,10 +17,40 @@ on:
jobs:
test-migrations:
runs-on: ubuntu-latest
services:
postgres:
image: postgres:17
env:
POSTGRES_USER: reflector
POSTGRES_PASSWORD: reflector
POSTGRES_DB: reflector
ports:
- 5432:5432
options: >-
--health-cmd pg_isready -h 127.0.0.1 -p 5432
--health-interval 10s
--health-timeout 5s
--health-retries 5
env:
DATABASE_URL: postgresql://reflector:reflector@localhost:5432/reflector
steps:
- uses: actions/checkout@v4
- name: Install PostgreSQL client
run: sudo apt-get update && sudo apt-get install -y postgresql-client | cat
- name: Wait for Postgres
run: |
for i in {1..30}; do
if pg_isready -h localhost -p 5432; then
echo "Postgres is ready"
break
fi
echo "Waiting for Postgres... ($i)" && sleep 1
done
- name: Install uv
uses: astral-sh/setup-uv@v3
with:

View File

@@ -32,7 +32,7 @@ def upgrade() -> None:
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("room_id", sa.String(), nullable=True),
sa.Column(
"is_locked", sa.Boolean(), server_default=sa.text("0"), nullable=False
"is_locked", sa.Boolean(), server_default=sa.text("false"), nullable=False
),
sa.Column("room_mode", sa.String(), server_default="normal", nullable=False),
sa.Column(
@@ -53,12 +53,15 @@ def upgrade() -> None:
sa.Column("user_id", sa.String(), nullable=False),
sa.Column("created_at", sa.DateTime(), nullable=False),
sa.Column(
"zulip_auto_post", sa.Boolean(), server_default=sa.text("0"), nullable=False
"zulip_auto_post",
sa.Boolean(),
server_default=sa.text("false"),
nullable=False,
),
sa.Column("zulip_stream", sa.String(), nullable=True),
sa.Column("zulip_topic", sa.String(), nullable=True),
sa.Column(
"is_locked", sa.Boolean(), server_default=sa.text("0"), nullable=False
"is_locked", sa.Boolean(), server_default=sa.text("false"), nullable=False
),
sa.Column("room_mode", sa.String(), server_default="normal", nullable=False),
sa.Column(

View File

@@ -20,11 +20,14 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
sourcekind_enum = sa.Enum("room", "live", "file", name="sourcekind")
sourcekind_enum.create(op.get_bind())
op.add_column(
"transcript",
sa.Column(
"source_kind",
sa.Enum("ROOM", "LIVE", "FILE", name="sourcekind"),
sourcekind_enum,
nullable=True,
),
)
@@ -43,6 +46,8 @@ def upgrade() -> None:
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("transcript", "source_kind")
sourcekind_enum = sa.Enum(name="sourcekind")
sourcekind_enum.drop(op.get_bind())
# ### end Alembic commands ###

View File

@@ -22,7 +22,7 @@ def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.execute(
"UPDATE transcript SET events = "
'REPLACE(events, \'"event": "SUMMARY"\', \'"event": "LONG_SUMMARY"\');'
'REPLACE(events::text, \'"event": "SUMMARY"\', \'"event": "LONG_SUMMARY"\')::json;'
)
op.alter_column("transcript", "summary", new_column_name="long_summary")
op.add_column("transcript", sa.Column("title", sa.String(), nullable=True))
@@ -34,7 +34,7 @@ def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.execute(
"UPDATE transcript SET events = "
'REPLACE(events, \'"event": "LONG_SUMMARY"\', \'"event": "SUMMARY"\');'
'REPLACE(events::text, \'"event": "LONG_SUMMARY"\', \'"event": "SUMMARY"\')::json;'
)
with op.batch_alter_table("transcript", schema=None) as batch_op:
batch_op.alter_column("long_summary", nullable=True, new_column_name="summary")

View File

@@ -0,0 +1,121 @@
"""datetime timezone
Revision ID: 9f5c78d352d6
Revises: 8120ebc75366
Create Date: 2025-08-13 19:18:27.113593
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "9f5c78d352d6"
down_revision: Union[str, None] = "8120ebc75366"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("meeting", schema=None) as batch_op:
batch_op.alter_column(
"start_date",
existing_type=postgresql.TIMESTAMP(),
type_=sa.DateTime(timezone=True),
existing_nullable=True,
)
batch_op.alter_column(
"end_date",
existing_type=postgresql.TIMESTAMP(),
type_=sa.DateTime(timezone=True),
existing_nullable=True,
)
with op.batch_alter_table("meeting_consent", schema=None) as batch_op:
batch_op.alter_column(
"consent_timestamp",
existing_type=postgresql.TIMESTAMP(),
type_=sa.DateTime(timezone=True),
existing_nullable=False,
)
with op.batch_alter_table("recording", schema=None) as batch_op:
batch_op.alter_column(
"recorded_at",
existing_type=postgresql.TIMESTAMP(),
type_=sa.DateTime(timezone=True),
existing_nullable=False,
)
with op.batch_alter_table("room", schema=None) as batch_op:
batch_op.alter_column(
"created_at",
existing_type=postgresql.TIMESTAMP(),
type_=sa.DateTime(timezone=True),
existing_nullable=False,
)
with op.batch_alter_table("transcript", schema=None) as batch_op:
batch_op.alter_column(
"created_at",
existing_type=postgresql.TIMESTAMP(),
type_=sa.DateTime(timezone=True),
existing_nullable=True,
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("transcript", schema=None) as batch_op:
batch_op.alter_column(
"created_at",
existing_type=sa.DateTime(timezone=True),
type_=postgresql.TIMESTAMP(),
existing_nullable=True,
)
with op.batch_alter_table("room", schema=None) as batch_op:
batch_op.alter_column(
"created_at",
existing_type=sa.DateTime(timezone=True),
type_=postgresql.TIMESTAMP(),
existing_nullable=False,
)
with op.batch_alter_table("recording", schema=None) as batch_op:
batch_op.alter_column(
"recorded_at",
existing_type=sa.DateTime(timezone=True),
type_=postgresql.TIMESTAMP(),
existing_nullable=False,
)
with op.batch_alter_table("meeting_consent", schema=None) as batch_op:
batch_op.alter_column(
"consent_timestamp",
existing_type=sa.DateTime(timezone=True),
type_=postgresql.TIMESTAMP(),
existing_nullable=False,
)
with op.batch_alter_table("meeting", schema=None) as batch_op:
batch_op.alter_column(
"end_date",
existing_type=sa.DateTime(timezone=True),
type_=postgresql.TIMESTAMP(),
existing_nullable=True,
)
batch_op.alter_column(
"start_date",
existing_type=sa.DateTime(timezone=True),
type_=postgresql.TIMESTAMP(),
existing_nullable=True,
)
# ### end Alembic commands ###

View File

@@ -25,7 +25,7 @@ def upgrade() -> None:
sa.Column(
"is_shared",
sa.Boolean(),
server_default=sa.text("0"),
server_default=sa.text("false"),
nullable=False,
),
)

View File

@@ -23,7 +23,10 @@ def upgrade() -> None:
with op.batch_alter_table("meeting", schema=None) as batch_op:
batch_op.add_column(
sa.Column(
"is_active", sa.Boolean(), server_default=sa.text("1"), nullable=False
"is_active",
sa.Boolean(),
server_default=sa.text("true"),
nullable=False,
)
)

View File

@@ -23,7 +23,7 @@ def upgrade() -> None:
op.add_column(
"transcript",
sa.Column(
"reviewed", sa.Boolean(), server_default=sa.text("0"), nullable=False
"reviewed", sa.Boolean(), server_default=sa.text("false"), nullable=False
),
)
# ### end Alembic commands ###

View File

@@ -57,6 +57,8 @@ tests = [
"httpx-ws>=0.4.1",
"pytest-httpx>=0.23.1",
"pytest-celery>=0.0.0",
"pytest-docker>=3.2.3",
"asgi-lifespan>=2.1.0",
]
aws = ["aioboto3>=11.2.0"]
evaluation = [
@@ -86,7 +88,7 @@ source = ["reflector"]
[tool.pytest_env]
ENVIRONMENT = "pytest"
DATABASE_URL = "sqlite:///test.sqlite"
DATABASE_URL = "postgresql://test_user:test_password@localhost:15432/reflector_test"
[tool.pytest.ini_options]
addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v"

View File

@@ -1,12 +1,28 @@
import contextvars
from typing import Optional
import databases
import sqlalchemy
from reflector.events import subscribers_shutdown, subscribers_startup
from reflector.settings import settings
database = databases.Database(settings.DATABASE_URL)
metadata = sqlalchemy.MetaData()
_database_context: contextvars.ContextVar[Optional[databases.Database]] = (
contextvars.ContextVar("database", default=None)
)
def get_database() -> databases.Database:
"""Get database instance for current asyncio context"""
db = _database_context.get()
if db is None:
db = databases.Database(settings.DATABASE_URL)
_database_context.set(db)
return db
# import models
import reflector.db.meetings # noqa
import reflector.db.recordings # noqa
@@ -14,16 +30,18 @@ import reflector.db.rooms # noqa
import reflector.db.transcripts # noqa
kwargs = {}
if "sqlite" in settings.DATABASE_URL:
kwargs["connect_args"] = {"check_same_thread": False}
if "postgres" not in settings.DATABASE_URL:
raise Exception("Only postgres database is supported in reflector")
engine = sqlalchemy.create_engine(settings.DATABASE_URL, **kwargs)
@subscribers_startup.append
async def database_connect(_):
database = get_database()
await database.connect()
@subscribers_shutdown.append
async def database_disconnect(_):
database = get_database()
await database.disconnect()

View File

@@ -5,7 +5,7 @@ import sqlalchemy as sa
from fastapi import HTTPException
from pydantic import BaseModel, Field
from reflector.db import database, metadata
from reflector.db import get_database, metadata
from reflector.db.rooms import Room
from reflector.utils import generate_uuid4
@@ -16,8 +16,8 @@ meetings = sa.Table(
sa.Column("room_name", sa.String),
sa.Column("room_url", sa.String),
sa.Column("host_room_url", sa.String),
sa.Column("start_date", sa.DateTime),
sa.Column("end_date", sa.DateTime),
sa.Column("start_date", sa.DateTime(timezone=True)),
sa.Column("end_date", sa.DateTime(timezone=True)),
sa.Column("user_id", sa.String),
sa.Column("room_id", sa.String),
sa.Column("is_locked", sa.Boolean, nullable=False, server_default=sa.false()),
@@ -42,6 +42,12 @@ meetings = sa.Table(
server_default=sa.true(),
),
sa.Index("idx_meeting_room_id", "room_id"),
sa.Index(
"idx_one_active_meeting_per_room",
"room_id",
unique=True,
postgresql_where=sa.text("is_active = true"),
),
)
meeting_consent = sa.Table(
@@ -51,7 +57,7 @@ meeting_consent = sa.Table(
sa.Column("meeting_id", sa.String, sa.ForeignKey("meeting.id"), nullable=False),
sa.Column("user_id", sa.String),
sa.Column("consent_given", sa.Boolean, nullable=False),
sa.Column("consent_timestamp", sa.DateTime, nullable=False),
sa.Column("consent_timestamp", sa.DateTime(timezone=True), nullable=False),
)
@@ -111,7 +117,7 @@ class MeetingController:
recording_trigger=room.recording_trigger,
)
query = meetings.insert().values(**meeting.model_dump())
await database.execute(query)
await get_database().execute(query)
return meeting
async def get_all_active(self) -> list[Meeting]:
@@ -119,7 +125,7 @@ class MeetingController:
Get active meetings.
"""
query = meetings.select().where(meetings.c.is_active)
return await database.fetch_all(query)
return await get_database().fetch_all(query)
async def get_by_room_name(
self,
@@ -129,7 +135,7 @@ class MeetingController:
Get a meeting by room name.
"""
query = meetings.select().where(meetings.c.room_name == room_name)
result = await database.fetch_one(query)
result = await get_database().fetch_one(query)
if not result:
return None
@@ -151,7 +157,7 @@ class MeetingController:
)
.order_by(end_date.desc())
)
result = await database.fetch_one(query)
result = await get_database().fetch_one(query)
if not result:
return None
@@ -162,7 +168,7 @@ class MeetingController:
Get a meeting by id
"""
query = meetings.select().where(meetings.c.id == meeting_id)
result = await database.fetch_one(query)
result = await get_database().fetch_one(query)
if not result:
return None
return Meeting(**result)
@@ -174,7 +180,7 @@ class MeetingController:
If not found, it will raise a 404 error.
"""
query = meetings.select().where(meetings.c.id == meeting_id)
result = await database.fetch_one(query)
result = await get_database().fetch_one(query)
if not result:
raise HTTPException(status_code=404, detail="Meeting not found")
@@ -186,7 +192,7 @@ class MeetingController:
async def update_meeting(self, meeting_id: str, **kwargs):
query = meetings.update().where(meetings.c.id == meeting_id).values(**kwargs)
await database.execute(query)
await get_database().execute(query)
class MeetingConsentController:
@@ -194,7 +200,7 @@ class MeetingConsentController:
query = meeting_consent.select().where(
meeting_consent.c.meeting_id == meeting_id
)
results = await database.fetch_all(query)
results = await get_database().fetch_all(query)
return [MeetingConsent(**result) for result in results]
async def get_by_meeting_and_user(
@@ -205,7 +211,7 @@ class MeetingConsentController:
meeting_consent.c.meeting_id == meeting_id,
meeting_consent.c.user_id == user_id,
)
result = await database.fetch_one(query)
result = await get_database().fetch_one(query)
if result is None:
return None
return MeetingConsent(**result) if result else None
@@ -227,14 +233,14 @@ class MeetingConsentController:
consent_timestamp=consent.consent_timestamp,
)
)
await database.execute(query)
await get_database().execute(query)
existing.consent_given = consent.consent_given
existing.consent_timestamp = consent.consent_timestamp
return existing
query = meeting_consent.insert().values(**consent.model_dump())
await database.execute(query)
await get_database().execute(query)
return consent
async def has_any_denial(self, meeting_id: str) -> bool:
@@ -243,7 +249,7 @@ class MeetingConsentController:
meeting_consent.c.meeting_id == meeting_id,
meeting_consent.c.consent_given.is_(False),
)
result = await database.fetch_one(query)
result = await get_database().fetch_one(query)
return result is not None

View File

@@ -4,7 +4,7 @@ from typing import Literal
import sqlalchemy as sa
from pydantic import BaseModel, Field
from reflector.db import database, metadata
from reflector.db import get_database, metadata
from reflector.utils import generate_uuid4
recordings = sa.Table(
@@ -13,7 +13,7 @@ recordings = sa.Table(
sa.Column("id", sa.String, primary_key=True),
sa.Column("bucket_name", sa.String, nullable=False),
sa.Column("object_key", sa.String, nullable=False),
sa.Column("recorded_at", sa.DateTime, nullable=False),
sa.Column("recorded_at", sa.DateTime(timezone=True), nullable=False),
sa.Column(
"status",
sa.String,
@@ -37,12 +37,12 @@ class Recording(BaseModel):
class RecordingController:
async def create(self, recording: Recording):
query = recordings.insert().values(**recording.model_dump())
await database.execute(query)
await get_database().execute(query)
return recording
async def get_by_id(self, id: str) -> Recording:
query = recordings.select().where(recordings.c.id == id)
result = await database.fetch_one(query)
result = await get_database().fetch_one(query)
return Recording(**result) if result else None
async def get_by_object_key(self, bucket_name: str, object_key: str) -> Recording:
@@ -50,7 +50,7 @@ class RecordingController:
recordings.c.bucket_name == bucket_name,
recordings.c.object_key == object_key,
)
result = await database.fetch_one(query)
result = await get_database().fetch_one(query)
return Recording(**result) if result else None

View File

@@ -1,4 +1,4 @@
from datetime import datetime
from datetime import datetime, timezone
from sqlite3 import IntegrityError
from typing import Literal
@@ -7,7 +7,7 @@ from fastapi import HTTPException
from pydantic import BaseModel, Field
from sqlalchemy.sql import false, or_
from reflector.db import database, metadata
from reflector.db import get_database, metadata
from reflector.utils import generate_uuid4
rooms = sqlalchemy.Table(
@@ -16,7 +16,7 @@ rooms = sqlalchemy.Table(
sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
sqlalchemy.Column("name", sqlalchemy.String, nullable=False, unique=True),
sqlalchemy.Column("user_id", sqlalchemy.String, nullable=False),
sqlalchemy.Column("created_at", sqlalchemy.DateTime, nullable=False),
sqlalchemy.Column("created_at", sqlalchemy.DateTime(timezone=True), nullable=False),
sqlalchemy.Column(
"zulip_auto_post", sqlalchemy.Boolean, nullable=False, server_default=false()
),
@@ -48,7 +48,7 @@ class Room(BaseModel):
id: str = Field(default_factory=generate_uuid4)
name: str
user_id: str
created_at: datetime = Field(default_factory=datetime.utcnow)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
zulip_auto_post: bool = False
zulip_stream: str = ""
zulip_topic: str = ""
@@ -92,7 +92,7 @@ class RoomController:
if return_query:
return query
results = await database.fetch_all(query)
results = await get_database().fetch_all(query)
return results
async def add(
@@ -125,7 +125,7 @@ class RoomController:
)
query = rooms.insert().values(**room.model_dump())
try:
await database.execute(query)
await get_database().execute(query)
except IntegrityError:
raise HTTPException(status_code=400, detail="Room name is not unique")
return room
@@ -136,7 +136,7 @@ class RoomController:
"""
query = rooms.update().where(rooms.c.id == room.id).values(**values)
try:
await database.execute(query)
await get_database().execute(query)
except IntegrityError:
raise HTTPException(status_code=400, detail="Room name is not unique")
@@ -151,7 +151,7 @@ class RoomController:
query = rooms.select().where(rooms.c.id == room_id)
if "user_id" in kwargs:
query = query.where(rooms.c.user_id == kwargs["user_id"])
result = await database.fetch_one(query)
result = await get_database().fetch_one(query)
if not result:
return None
return Room(**result)
@@ -163,7 +163,7 @@ class RoomController:
query = rooms.select().where(rooms.c.name == room_name)
if "user_id" in kwargs:
query = query.where(rooms.c.user_id == kwargs["user_id"])
result = await database.fetch_one(query)
result = await get_database().fetch_one(query)
if not result:
return None
return Room(**result)
@@ -175,7 +175,7 @@ class RoomController:
If not found, it will raise a 404 error.
"""
query = rooms.select().where(rooms.c.id == meeting_id)
result = await database.fetch_one(query)
result = await get_database().fetch_one(query)
if not result:
raise HTTPException(status_code=404, detail="Room not found")
@@ -197,7 +197,7 @@ class RoomController:
if user_id is not None and room.user_id != user_id:
return
query = rooms.delete().where(rooms.c.id == room_id)
await database.execute(query)
await get_database().execute(query)
rooms_controller = RoomController()

View File

@@ -9,7 +9,7 @@ import sqlalchemy
import webvtt
from pydantic import BaseModel, Field, constr, field_serializer
from reflector.db import database
from reflector.db import get_database
from reflector.db.transcripts import SourceKind, transcripts
from reflector.db.utils import is_postgresql
@@ -207,12 +207,12 @@ class SearchController:
.limit(params.limit)
.offset(params.offset)
)
rs = await database.fetch_all(query)
rs = await get_database().fetch_all(query)
count_query = sqlalchemy.select([sqlalchemy.func.count()]).select_from(
base_query.alias("search_results")
)
total = await database.fetch_val(count_query)
total = await get_database().fetch_val(count_query)
def _process_result(r) -> SearchResult:
r_dict: Dict[str, Any] = dict(r)

View File

@@ -15,7 +15,7 @@ from sqlalchemy import Enum
from sqlalchemy.dialects.postgresql import TSVECTOR
from sqlalchemy.sql import false, or_
from reflector.db import database, metadata
from reflector.db import get_database, metadata
from reflector.db.rooms import rooms
from reflector.db.utils import is_postgresql
from reflector.processors.types import Word as ProcessorWord
@@ -41,7 +41,7 @@ transcripts = sqlalchemy.Table(
sqlalchemy.Column("status", sqlalchemy.String),
sqlalchemy.Column("locked", sqlalchemy.Boolean),
sqlalchemy.Column("duration", sqlalchemy.Float),
sqlalchemy.Column("created_at", sqlalchemy.DateTime),
sqlalchemy.Column("created_at", sqlalchemy.DateTime(timezone=True)),
sqlalchemy.Column("title", sqlalchemy.String),
sqlalchemy.Column("short_summary", sqlalchemy.String),
sqlalchemy.Column("long_summary", sqlalchemy.String),
@@ -421,7 +421,7 @@ class TranscriptController:
if return_query:
return query
results = await database.fetch_all(query)
results = await get_database().fetch_all(query)
return results
async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
@@ -431,7 +431,7 @@ class TranscriptController:
query = transcripts.select().where(transcripts.c.id == transcript_id)
if "user_id" in kwargs:
query = query.where(transcripts.c.user_id == kwargs["user_id"])
result = await database.fetch_one(query)
result = await get_database().fetch_one(query)
if not result:
return None
return Transcript(**result)
@@ -445,7 +445,7 @@ class TranscriptController:
query = transcripts.select().where(transcripts.c.recording_id == recording_id)
if "user_id" in kwargs:
query = query.where(transcripts.c.user_id == kwargs["user_id"])
result = await database.fetch_one(query)
result = await get_database().fetch_one(query)
if not result:
return None
return Transcript(**result)
@@ -463,7 +463,7 @@ class TranscriptController:
if order_by.startswith("-"):
field = field.desc()
query = query.order_by(field)
results = await database.fetch_all(query)
results = await get_database().fetch_all(query)
return [Transcript(**result) for result in results]
async def get_by_id_for_http(
@@ -481,7 +481,7 @@ class TranscriptController:
to determine if the user can access the transcript.
"""
query = transcripts.select().where(transcripts.c.id == transcript_id)
result = await database.fetch_one(query)
result = await get_database().fetch_one(query)
if not result:
raise HTTPException(status_code=404, detail="Transcript not found")
@@ -534,7 +534,7 @@ class TranscriptController:
room_id=room_id,
)
query = transcripts.insert().values(**transcript.model_dump())
await database.execute(query)
await get_database().execute(query)
return transcript
# TODO investigate why mutate= is used. it's used in one place currently, maybe because of ORM field updates.
@@ -553,7 +553,7 @@ class TranscriptController:
.where(transcripts.c.id == transcript.id)
.values(**values)
)
await database.execute(query)
await get_database().execute(query)
if mutate:
for key, value in values.items():
setattr(transcript, key, value)
@@ -595,21 +595,21 @@ class TranscriptController:
return
transcript.unlink()
query = transcripts.delete().where(transcripts.c.id == transcript_id)
await database.execute(query)
await get_database().execute(query)
async def remove_by_recording_id(self, recording_id: str):
"""
Remove a transcript by recording_id
"""
query = transcripts.delete().where(transcripts.c.recording_id == recording_id)
await database.execute(query)
await get_database().execute(query)
@asynccontextmanager
async def transaction(self):
"""
A context manager for database transaction
"""
async with database.transaction(isolation="serializable"):
async with get_database().transaction(isolation="serializable"):
yield
async def append_event(

View File

@@ -1,7 +1,9 @@
"""Database utility functions."""
from reflector.db import database
from reflector.db import get_database
def is_postgresql() -> bool:
return database.url.scheme and database.url.scheme.startswith('postgresql')
return get_database().url.scheme and get_database().url.scheme.startswith(
"postgresql"
)

View File

@@ -22,7 +22,7 @@ from celery import chord, current_task, group, shared_task
from pydantic import BaseModel
from structlog import BoundLogger as Logger
from reflector.db import database
from reflector.db import get_database
from reflector.db.meetings import meeting_consent_controller, meetings_controller
from reflector.db.recordings import recordings_controller
from reflector.db.rooms import rooms_controller
@@ -72,6 +72,7 @@ def asynctask(f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
async def run_with_db():
database = get_database()
await database.connect()
try:
return await f(*args, **kwargs)

View File

@@ -6,7 +6,7 @@ This script is used to generate a summary of a meeting notes transcript.
import asyncio
import sys
from datetime import datetime
from datetime import datetime, timezone
from enum import Enum
from textwrap import dedent
from typing import Type, TypeVar
@@ -474,7 +474,7 @@ if __name__ == "__main__":
if args.save:
# write the summary to a file, on the format summary-<iso date>.md
filename = f"summary-{datetime.now().isoformat()}.md"
filename = f"summary-{datetime.now(timezone.utc).isoformat()}.md"
with open(filename, "w", encoding="utf-8") as f:
f.write(sm.as_markdown())

View File

@@ -14,7 +14,9 @@ class Settings(BaseSettings):
CORS_ALLOW_CREDENTIALS: bool = False
# Database
DATABASE_URL: str = "sqlite:///./reflector.sqlite3"
DATABASE_URL: str = (
"postgresql+asyncpg://reflector:reflector@localhost:5432/reflector"
)
# local data directory
DATA_DIR: str = "./data"

View File

@@ -9,8 +9,9 @@ async def export_db(filename: str) -> None:
filename = pathlib.Path(filename).resolve()
settings.DATABASE_URL = f"sqlite:///{filename}"
from reflector.db import database, transcripts
from reflector.db import get_database, transcripts
database = get_database()
await database.connect()
transcripts = await database.fetch_all(transcripts.select())
await database.disconnect()

View File

@@ -8,8 +8,9 @@ async def export_db(filename: str) -> None:
filename = pathlib.Path(filename).resolve()
settings.DATABASE_URL = f"sqlite:///{filename}"
from reflector.db import database, transcripts
from reflector.db import get_database, transcripts
database = get_database()
await database.connect()
transcripts = await database.fetch_all(transcripts.select())
await database.disconnect()

View File

@@ -155,7 +155,7 @@ async def process_audio_file_with_diarization(
# For Modal backend, we need to upload the file to S3 first
if diarization_backend == "modal":
from datetime import datetime
from datetime import datetime, timezone
from reflector.storage import get_transcripts_storage
from reflector.utils.s3_temp_file import S3TemporaryFile
@@ -163,7 +163,7 @@ async def process_audio_file_with_diarization(
storage = get_transcripts_storage()
# Generate a unique filename in evaluation folder
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
audio_filename = f"evaluation/diarization_temp/{timestamp}_{uuid.uuid4().hex}.wav"
# Use context manager for automatic cleanup

View File

@@ -1,4 +1,4 @@
from datetime import datetime
from datetime import datetime, timezone
from typing import Annotated, Optional
from fastapi import APIRouter, Depends, HTTPException, Request
@@ -35,7 +35,7 @@ async def meeting_audio_consent(
meeting_id=meeting_id,
user_id=user_id,
consent_given=request.consent_given,
consent_timestamp=datetime.utcnow(),
consent_timestamp=datetime.now(timezone.utc),
)
updated_consent = await meeting_consent_controller.upsert(consent)

View File

@@ -1,16 +1,16 @@
import logging
import sqlite3
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import Annotated, Literal, Optional
import asyncpg.exceptions
from fastapi import APIRouter, Depends, HTTPException
from fastapi_pagination import Page
from fastapi_pagination.ext.databases import paginate
from fastapi_pagination.ext.databases import apaginate
from pydantic import BaseModel
import reflector.auth as auth
from reflector.db import database
from reflector.db import get_database
from reflector.db.meetings import meetings_controller
from reflector.db.rooms import rooms_controller
from reflector.settings import settings
@@ -21,6 +21,14 @@ logger = logging.getLogger(__name__)
router = APIRouter()
def parse_datetime_with_timezone(iso_string: str) -> datetime:
"""Parse ISO datetime string and ensure timezone awareness (defaults to UTC if naive)."""
dt = datetime.fromisoformat(iso_string)
if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc)
return dt
class Room(BaseModel):
id: str
name: str
@@ -83,8 +91,8 @@ async def rooms_list(
user_id = user["sub"] if user else None
return await paginate(
database,
return await apaginate(
get_database(),
await rooms_controller.get_all(
user_id=user_id, order_by="-created_at", return_query=True
),
@@ -150,7 +158,7 @@ async def rooms_create_meeting(
if not room:
raise HTTPException(status_code=404, detail="Room not found")
current_time = datetime.utcnow()
current_time = datetime.now(timezone.utc)
meeting = await meetings_controller.get_active(room=room, current_time=current_time)
if meeting is None:
@@ -166,8 +174,8 @@ async def rooms_create_meeting(
room_name=whereby_meeting["roomName"],
room_url=whereby_meeting["roomUrl"],
host_room_url=whereby_meeting["hostRoomUrl"],
start_date=datetime.fromisoformat(whereby_meeting["startDate"]),
end_date=datetime.fromisoformat(whereby_meeting["endDate"]),
start_date=parse_datetime_with_timezone(whereby_meeting["startDate"]),
end_date=parse_datetime_with_timezone(whereby_meeting["endDate"]),
user_id=user_id,
room=room,
)

View File

@@ -3,12 +3,12 @@ from typing import Annotated, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi_pagination import Page
from fastapi_pagination.ext.databases import paginate
from fastapi_pagination.ext.databases import apaginate
from jose import jwt
from pydantic import BaseModel, Field, field_serializer
import reflector.auth as auth
from reflector.db import database
from reflector.db import get_database
from reflector.db.meetings import meetings_controller
from reflector.db.rooms import rooms_controller
from reflector.db.search import (
@@ -48,7 +48,7 @@ DOWNLOAD_EXPIRE_MINUTES = 60
def create_access_token(data: dict, expires_delta: timedelta):
to_encode = data.copy()
expire = datetime.utcnow() + expires_delta
expire = datetime.now(timezone.utc) + expires_delta
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
@@ -141,8 +141,8 @@ async def transcripts_list(
user_id = user["sub"] if user else None
return await paginate(
database,
return await apaginate(
get_database(),
await transcripts_controller.get_all(
user_id=user_id,
source_kind=SourceKind(source_kind) if source_kind else None,

View File

@@ -21,6 +21,14 @@ from reflector.whereby import get_room_sessions
logger = structlog.wrap_logger(get_task_logger(__name__))
def parse_datetime_with_timezone(iso_string: str) -> datetime:
"""Parse ISO datetime string and ensure timezone awareness (defaults to UTC if naive)."""
dt = datetime.fromisoformat(iso_string)
if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc)
return dt
@shared_task
def process_messages():
queue_url = settings.AWS_PROCESS_RECORDING_QUEUE_URL
@@ -69,7 +77,7 @@ async def process_recording(bucket_name: str, object_key: str):
# extract a guid and a datetime from the object key
room_name = f"/{object_key[:36]}"
recorded_at = datetime.fromisoformat(object_key[37:57])
recorded_at = parse_datetime_with_timezone(object_key[37:57])
meeting = await meetings_controller.get_by_room_name(room_name)
room = await rooms_controller.get_by_id(meeting.room_id)

View File

@@ -62,6 +62,7 @@ class RedisPubSubManager:
class WebsocketManager:
def __init__(self, pubsub_client: RedisPubSubManager = None):
self.rooms: dict = {}
self.tasks: dict = {}
self.pubsub_client = pubsub_client
async def add_user_to_room(self, room_id: str, websocket: WebSocket) -> None:
@@ -74,13 +75,17 @@ class WebsocketManager:
await self.pubsub_client.connect()
pubsub_subscriber = await self.pubsub_client.subscribe(room_id)
asyncio.create_task(self._pubsub_data_reader(pubsub_subscriber))
task = asyncio.create_task(self._pubsub_data_reader(pubsub_subscriber))
self.tasks[id(websocket)] = task
async def send_json(self, room_id: str, message: dict) -> None:
await self.pubsub_client.send_json(room_id, message)
async def remove_user_from_room(self, room_id: str, websocket: WebSocket) -> None:
self.rooms[room_id].remove(websocket)
task = self.tasks.pop(id(websocket), None)
if task:
task.cancel()
if len(self.rooms[room_id]) == 0:
del self.rooms[room_id]

View File

@@ -1,17 +1,63 @@
import os
from tempfile import NamedTemporaryFile
from unittest.mock import patch
import pytest
# Pytest-docker configuration
@pytest.fixture(scope="session")
def docker_compose_file(pytestconfig):
return os.path.join(str(pytestconfig.rootdir), "tests", "docker-compose.test.yml")
@pytest.fixture(scope="session")
def postgres_service(docker_ip, docker_services):
"""Ensure that PostgreSQL service is up and responsive."""
port = docker_services.port_for("postgres_test", 5432)
def is_responsive():
try:
import psycopg2
conn = psycopg2.connect(
host=docker_ip,
port=port,
dbname="reflector_test",
user="test_user",
password="test_password",
)
conn.close()
return True
except Exception:
return False
docker_services.wait_until_responsive(timeout=30.0, pause=0.1, check=is_responsive)
# Return connection parameters
return {
"host": docker_ip,
"port": port,
"dbname": "reflector_test",
"user": "test_user",
"password": "test_password",
}
@pytest.fixture(scope="function", autouse=True)
@pytest.mark.asyncio
async def setup_database():
from reflector.db import engine, metadata # noqa
async def setup_database(postgres_service):
from reflector.db import engine, metadata, get_database # noqa
metadata.drop_all(bind=engine)
metadata.create_all(bind=engine)
yield
database = get_database()
try:
await database.connect()
yield
finally:
await database.disconnect()
@pytest.fixture
@@ -46,6 +92,20 @@ def dummy_processors():
) # noqa
@pytest.fixture
async def whisper_transcript():
from reflector.processors.audio_transcript_whisper import (
AudioTranscriptWhisperProcessor,
)
with patch(
"reflector.processors.audio_transcript_auto"
".AudioTranscriptAutoProcessor.__new__"
) as mock_audio:
mock_audio.return_value = AudioTranscriptWhisperProcessor()
yield
@pytest.fixture
async def dummy_transcript():
from reflector.processors.audio_transcript import AudioTranscriptProcessor
@@ -181,6 +241,16 @@ def celery_includes():
return ["reflector.pipelines.main_live_pipeline"]
@pytest.fixture
async def client():
from httpx import AsyncClient
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
yield ac
@pytest.fixture(scope="session")
def fake_mp3_upload():
with patch(
@@ -191,13 +261,10 @@ def fake_mp3_upload():
@pytest.fixture
async def fake_transcript_with_topics(tmpdir):
async def fake_transcript_with_topics(tmpdir, client):
import shutil
from pathlib import Path
from httpx import AsyncClient
from reflector.app import app
from reflector.db.transcripts import TranscriptTopic
from reflector.processors.types import Word
from reflector.settings import settings
@@ -206,8 +273,7 @@ async def fake_transcript_with_topics(tmpdir):
settings.DATA_DIR = Path(tmpdir)
# create a transcript
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.post("/transcripts", json={"name": "Test audio download"})
response = await client.post("/transcripts", json={"name": "Test audio download"})
assert response.status_code == 200
tid = response.json()["id"]

View File

@@ -0,0 +1,13 @@
version: '3.8'
services:
postgres_test:
image: postgres:15
environment:
POSTGRES_DB: reflector_test
POSTGRES_USER: test_user
POSTGRES_PASSWORD: test_password
ports:
- "15432:5432"
command: postgres -c fsync=off -c synchronous_commit=off -c full_page_writes=off
tmpfs:
- /var/lib/postgresql/data:rw,noexec,nosuid,size=1g

View File

@@ -1,82 +1,62 @@
"""Tests for full-text search functionality."""
import json
from datetime import datetime
from datetime import datetime, timezone
import pytest
from pydantic import ValidationError
from reflector.db import database
from reflector.db import get_database
from reflector.db.search import SearchParameters, search_controller
from reflector.db.transcripts import transcripts
from reflector.db.utils import is_postgresql
@pytest.mark.asyncio
async def test_search_postgresql_only():
await database.connect()
params = SearchParameters(query_text="any query here")
results, total = await search_controller.search_transcripts(params)
assert results == []
assert total == 0
try:
params = SearchParameters(query_text="any query here")
results, total = await search_controller.search_transcripts(params)
assert results == []
assert total == 0
SearchParameters(query_text="")
assert False, "Should have raised validation error"
except ValidationError:
pass # Expected
try:
SearchParameters(query_text="")
assert False, "Should have raised validation error"
except ValidationError:
pass # Expected
# Test that whitespace query raises validation error
try:
SearchParameters(query_text=" ")
assert False, "Should have raised validation error"
except ValidationError:
pass # Expected
finally:
await database.disconnect()
# Test that whitespace query raises validation error
try:
SearchParameters(query_text=" ")
assert False, "Should have raised validation error"
except ValidationError:
pass # Expected
@pytest.mark.asyncio
async def test_search_input_validation():
await database.connect()
try:
try:
SearchParameters(query_text="")
assert False, "Should have raised ValidationError"
except ValidationError:
pass # Expected
SearchParameters(query_text="")
assert False, "Should have raised ValidationError"
except ValidationError:
pass # Expected
# Test that whitespace query raises validation error
try:
SearchParameters(query_text=" \t\n ")
assert False, "Should have raised ValidationError"
except ValidationError:
pass # Expected
finally:
await database.disconnect()
# Test that whitespace query raises validation error
try:
SearchParameters(query_text=" \t\n ")
assert False, "Should have raised ValidationError"
except ValidationError:
pass # Expected
@pytest.mark.asyncio
async def test_postgresql_search_with_data():
"""Test full-text search with actual data in PostgreSQL.
Example how to run: DATABASE_URL=postgresql://reflector:reflector@localhost:5432/reflector_test uv run pytest tests/test_search.py::test_postgresql_search_with_data -v -p no:env
"""
# Skip if not PostgreSQL
if not is_postgresql():
pytest.skip("Test requires PostgreSQL. Set DATABASE_URL=postgresql://...")
await database.connect()
# collision is improbable
test_id = "test-search-e2e-7f3a9b2c"
try:
await database.execute(transcripts.delete().where(transcripts.c.id == test_id))
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
)
test_data = {
"id": test_id,
@@ -85,7 +65,7 @@ async def test_postgresql_search_with_data():
"status": "completed",
"locked": False,
"duration": 1800.0,
"created_at": datetime.now(),
"created_at": datetime.now(timezone.utc),
"short_summary": "Team discussed search implementation",
"long_summary": "The engineering team met to plan the search feature",
"topics": json.dumps([]),
@@ -112,7 +92,7 @@ The search feature should support complex queries with ranking.
We need to implement PostgreSQL tsvector for better performance.""",
}
await database.execute(transcripts.insert().values(**test_data))
await get_database().execute(transcripts.insert().values(**test_data))
# Test 1: Search for a word in title
params = SearchParameters(query_text="planning")
@@ -141,7 +121,6 @@ We need to implement PostgreSQL tsvector for better performance.""",
assert test_result.title == "Engineering Planning Meeting Q4 2024"
assert test_result.status == "completed"
assert test_result.duration == 1800.0
assert test_result.source_kind == "room"
assert 0 <= test_result.rank <= 1, "Rank should be normalized to 0-1"
# Test 5: Search with OR operator
@@ -159,5 +138,7 @@ We need to implement PostgreSQL tsvector for better performance.""",
assert found, "Should find test transcript by exact phrase"
finally:
await database.execute(transcripts.delete().where(transcripts.c.id == test_id))
await database.disconnect()
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
)
await get_database().disconnect()

View File

@@ -1,147 +1,128 @@
from contextlib import asynccontextmanager
import pytest
from httpx import AsyncClient
@pytest.mark.asyncio
async def test_transcript_create():
from reflector.app import app
async def test_transcript_create(client):
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["name"] == "test"
assert response.json()["status"] == "idle"
assert response.json()["locked"] is False
assert response.json()["id"] is not None
assert response.json()["created_at"] is not None
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["name"] == "test"
assert response.json()["status"] == "idle"
assert response.json()["locked"] is False
assert response.json()["id"] is not None
assert response.json()["created_at"] is not None
# ensure some fields are not returned
assert "topics" not in response.json()
assert "events" not in response.json()
# ensure some fields are not returned
assert "topics" not in response.json()
assert "events" not in response.json()
@pytest.mark.asyncio
async def test_transcript_get_update_name():
from reflector.app import app
async def test_transcript_get_update_name(client):
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["name"] == "test"
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["name"] == "test"
tid = response.json()["id"]
tid = response.json()["id"]
response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["name"] == "test"
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["name"] == "test"
response = await client.patch(f"/transcripts/{tid}", json={"name": "test2"})
assert response.status_code == 200
assert response.json()["name"] == "test2"
response = await ac.patch(f"/transcripts/{tid}", json={"name": "test2"})
assert response.status_code == 200
assert response.json()["name"] == "test2"
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["name"] == "test2"
response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["name"] == "test2"
@pytest.mark.asyncio
async def test_transcript_get_update_locked():
from reflector.app import app
async def test_transcript_get_update_locked(client):
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["locked"] is False
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["locked"] is False
tid = response.json()["id"]
tid = response.json()["id"]
response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["locked"] is False
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["locked"] is False
response = await client.patch(f"/transcripts/{tid}", json={"locked": True})
assert response.status_code == 200
assert response.json()["locked"] is True
response = await ac.patch(f"/transcripts/{tid}", json={"locked": True})
assert response.status_code == 200
assert response.json()["locked"] is True
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["locked"] is True
response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["locked"] is True
@pytest.mark.asyncio
async def test_transcript_get_update_summary():
from reflector.app import app
async def test_transcript_get_update_summary(client):
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["long_summary"] is None
assert response.json()["short_summary"] is None
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["long_summary"] is None
assert response.json()["short_summary"] is None
tid = response.json()["id"]
tid = response.json()["id"]
response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["long_summary"] is None
assert response.json()["short_summary"] is None
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["long_summary"] is None
assert response.json()["short_summary"] is None
response = await client.patch(
f"/transcripts/{tid}",
json={"long_summary": "test_long", "short_summary": "test_short"},
)
assert response.status_code == 200
assert response.json()["long_summary"] == "test_long"
assert response.json()["short_summary"] == "test_short"
response = await ac.patch(
f"/transcripts/{tid}",
json={"long_summary": "test_long", "short_summary": "test_short"},
)
assert response.status_code == 200
assert response.json()["long_summary"] == "test_long"
assert response.json()["short_summary"] == "test_short"
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["long_summary"] == "test_long"
assert response.json()["short_summary"] == "test_short"
response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["long_summary"] == "test_long"
assert response.json()["short_summary"] == "test_short"
@pytest.mark.asyncio
async def test_transcript_get_update_title():
from reflector.app import app
async def test_transcript_get_update_title(client):
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["title"] is None
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["title"] is None
tid = response.json()["id"]
tid = response.json()["id"]
response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["title"] is None
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["title"] is None
response = await client.patch(f"/transcripts/{tid}", json={"title": "test_title"})
assert response.status_code == 200
assert response.json()["title"] == "test_title"
response = await ac.patch(f"/transcripts/{tid}", json={"title": "test_title"})
assert response.status_code == 200
assert response.json()["title"] == "test_title"
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["title"] == "test_title"
response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["title"] == "test_title"
@pytest.mark.asyncio
async def test_transcripts_list_anonymous():
async def test_transcripts_list_anonymous(client):
# XXX this test is a bit fragile, as it depends on the storage which
# is shared between tests
from reflector.app import app
from reflector.settings import settings
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.get("/transcripts")
assert response.status_code == 401
response = await client.get("/transcripts")
assert response.status_code == 401
# if public mode, it should be allowed
try:
settings.PUBLIC_MODE = True
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.get("/transcripts")
assert response.status_code == 200
response = await client.get("/transcripts")
assert response.status_code == 200
finally:
settings.PUBLIC_MODE = False
@@ -197,67 +178,59 @@ async def authenticated_client2():
@pytest.mark.asyncio
async def test_transcripts_list_authenticated(authenticated_client):
async def test_transcripts_list_authenticated(authenticated_client, client):
# XXX this test is a bit fragile, as it depends on the storage which
# is shared between tests
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "testxx1"})
assert response.status_code == 200
assert response.json()["name"] == "testxx1"
response = await client.post("/transcripts", json={"name": "testxx1"})
assert response.status_code == 200
assert response.json()["name"] == "testxx1"
response = await ac.post("/transcripts", json={"name": "testxx2"})
assert response.status_code == 200
assert response.json()["name"] == "testxx2"
response = await client.post("/transcripts", json={"name": "testxx2"})
assert response.status_code == 200
assert response.json()["name"] == "testxx2"
response = await ac.get("/transcripts")
assert response.status_code == 200
assert len(response.json()["items"]) >= 2
names = [t["name"] for t in response.json()["items"]]
assert "testxx1" in names
assert "testxx2" in names
response = await client.get("/transcripts")
assert response.status_code == 200
assert len(response.json()["items"]) >= 2
names = [t["name"] for t in response.json()["items"]]
assert "testxx1" in names
assert "testxx2" in names
@pytest.mark.asyncio
async def test_transcript_delete():
from reflector.app import app
async def test_transcript_delete(client):
response = await client.post("/transcripts", json={"name": "testdel1"})
assert response.status_code == 200
assert response.json()["name"] == "testdel1"
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "testdel1"})
assert response.status_code == 200
assert response.json()["name"] == "testdel1"
tid = response.json()["id"]
response = await client.delete(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["status"] == "ok"
tid = response.json()["id"]
response = await ac.delete(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["status"] == "ok"
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 404
response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 404
@pytest.mark.asyncio
async def test_transcript_mark_reviewed():
from reflector.app import app
async def test_transcript_mark_reviewed(client):
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["name"] == "test"
assert response.json()["reviewed"] is False
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["name"] == "test"
assert response.json()["reviewed"] is False
tid = response.json()["id"]
tid = response.json()["id"]
response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["name"] == "test"
assert response.json()["reviewed"] is False
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["name"] == "test"
assert response.json()["reviewed"] is False
response = await client.patch(f"/transcripts/{tid}", json={"reviewed": True})
assert response.status_code == 200
assert response.json()["reviewed"] is True
response = await ac.patch(f"/transcripts/{tid}", json={"reviewed": True})
assert response.status_code == 200
assert response.json()["reviewed"] is True
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["reviewed"] is True
response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["reviewed"] is True

View File

@@ -2,20 +2,17 @@ import shutil
from pathlib import Path
import pytest
from httpx import AsyncClient
@pytest.fixture
async def fake_transcript(tmpdir):
from reflector.app import app
async def fake_transcript(tmpdir, client):
from reflector.settings import settings
from reflector.views.transcripts import transcripts_controller
settings.DATA_DIR = Path(tmpdir)
# create a transcript
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.post("/transcripts", json={"name": "Test audio download"})
response = await client.post("/transcripts", json={"name": "Test audio download"})
assert response.status_code == 200
tid = response.json()["id"]
@@ -39,17 +36,17 @@ async def fake_transcript(tmpdir):
["/mp3", "audio/mpeg"],
],
)
async def test_transcript_audio_download(fake_transcript, url_suffix, content_type):
from reflector.app import app
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.get(f"/transcripts/{fake_transcript.id}/audio{url_suffix}")
async def test_transcript_audio_download(
fake_transcript, url_suffix, content_type, client
):
response = await client.get(f"/transcripts/{fake_transcript.id}/audio{url_suffix}")
assert response.status_code == 200
assert response.headers["content-type"] == content_type
# test get 404
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.get(f"/transcripts/{fake_transcript.id}XXX/audio{url_suffix}")
response = await client.get(
f"/transcripts/{fake_transcript.id}XXX/audio{url_suffix}"
)
assert response.status_code == 404
@@ -61,18 +58,16 @@ async def test_transcript_audio_download(fake_transcript, url_suffix, content_ty
],
)
async def test_transcript_audio_download_head(
fake_transcript, url_suffix, content_type
fake_transcript, url_suffix, content_type, client
):
from reflector.app import app
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.head(f"/transcripts/{fake_transcript.id}/audio{url_suffix}")
response = await client.head(f"/transcripts/{fake_transcript.id}/audio{url_suffix}")
assert response.status_code == 200
assert response.headers["content-type"] == content_type
# test head 404
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.head(f"/transcripts/{fake_transcript.id}XXX/audio{url_suffix}")
response = await client.head(
f"/transcripts/{fake_transcript.id}XXX/audio{url_suffix}"
)
assert response.status_code == 404
@@ -84,12 +79,9 @@ async def test_transcript_audio_download_head(
],
)
async def test_transcript_audio_download_range(
fake_transcript, url_suffix, content_type
fake_transcript, url_suffix, content_type, client
):
from reflector.app import app
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.get(
response = await client.get(
f"/transcripts/{fake_transcript.id}/audio{url_suffix}",
headers={"range": "bytes=0-100"},
)
@@ -107,12 +99,9 @@ async def test_transcript_audio_download_range(
],
)
async def test_transcript_audio_download_range_with_seek(
fake_transcript, url_suffix, content_type
fake_transcript, url_suffix, content_type, client
):
from reflector.app import app
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.get(
response = await client.get(
f"/transcripts/{fake_transcript.id}/audio{url_suffix}",
headers={"range": "bytes=100-"},
)
@@ -122,13 +111,10 @@ async def test_transcript_audio_download_range_with_seek(
@pytest.mark.asyncio
async def test_transcript_delete_with_audio(fake_transcript):
from reflector.app import app
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.delete(f"/transcripts/{fake_transcript.id}")
async def test_transcript_delete_with_audio(fake_transcript, client):
response = await client.delete(f"/transcripts/{fake_transcript.id}")
assert response.status_code == 200
assert response.json()["status"] == "ok"
response = await ac.get(f"/transcripts/{fake_transcript.id}")
response = await client.get(f"/transcripts/{fake_transcript.id}")
assert response.status_code == 404

View File

@@ -1,164 +1,151 @@
import pytest
from httpx import AsyncClient
@pytest.mark.asyncio
async def test_transcript_participants():
from reflector.app import app
async def test_transcript_participants(client):
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["participants"] == []
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["participants"] == []
# create a participant
transcript_id = response.json()["id"]
response = await client.post(
f"/transcripts/{transcript_id}/participants", json={"name": "test"}
)
assert response.status_code == 200
assert response.json()["id"] is not None
assert response.json()["speaker"] is None
assert response.json()["name"] == "test"
# create a participant
transcript_id = response.json()["id"]
response = await ac.post(
f"/transcripts/{transcript_id}/participants", json={"name": "test"}
)
assert response.status_code == 200
assert response.json()["id"] is not None
assert response.json()["speaker"] is None
assert response.json()["name"] == "test"
# create another one with a speaker
response = await client.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test2", "speaker": 1},
)
assert response.status_code == 200
assert response.json()["id"] is not None
assert response.json()["speaker"] == 1
assert response.json()["name"] == "test2"
# create another one with a speaker
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test2", "speaker": 1},
)
assert response.status_code == 200
assert response.json()["id"] is not None
assert response.json()["speaker"] == 1
assert response.json()["name"] == "test2"
# get all participants via transcript
response = await client.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
assert len(response.json()["participants"]) == 2
# get all participants via transcript
response = await ac.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
assert len(response.json()["participants"]) == 2
# get participants via participants endpoint
response = await ac.get(f"/transcripts/{transcript_id}/participants")
assert response.status_code == 200
assert len(response.json()) == 2
# get participants via participants endpoint
response = await client.get(f"/transcripts/{transcript_id}/participants")
assert response.status_code == 200
assert len(response.json()) == 2
@pytest.mark.asyncio
async def test_transcript_participants_same_speaker():
from reflector.app import app
async def test_transcript_participants_same_speaker(client):
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["participants"] == []
transcript_id = response.json()["id"]
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["participants"] == []
transcript_id = response.json()["id"]
# create a participant
response = await client.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test", "speaker": 1},
)
assert response.status_code == 200
assert response.json()["speaker"] == 1
# create a participant
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test", "speaker": 1},
)
assert response.status_code == 200
assert response.json()["speaker"] == 1
# create another one with the same speaker
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test2", "speaker": 1},
)
assert response.status_code == 400
# create another one with the same speaker
response = await client.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test2", "speaker": 1},
)
assert response.status_code == 400
@pytest.mark.asyncio
async def test_transcript_participants_update_name():
from reflector.app import app
async def test_transcript_participants_update_name(client):
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["participants"] == []
transcript_id = response.json()["id"]
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["participants"] == []
transcript_id = response.json()["id"]
# create a participant
response = await client.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test", "speaker": 1},
)
assert response.status_code == 200
assert response.json()["speaker"] == 1
# create a participant
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test", "speaker": 1},
)
assert response.status_code == 200
assert response.json()["speaker"] == 1
# update the participant
participant_id = response.json()["id"]
response = await client.patch(
f"/transcripts/{transcript_id}/participants/{participant_id}",
json={"name": "test2"},
)
assert response.status_code == 200
assert response.json()["name"] == "test2"
# update the participant
participant_id = response.json()["id"]
response = await ac.patch(
f"/transcripts/{transcript_id}/participants/{participant_id}",
json={"name": "test2"},
)
assert response.status_code == 200
assert response.json()["name"] == "test2"
# verify the participant was updated
response = await client.get(
f"/transcripts/{transcript_id}/participants/{participant_id}"
)
assert response.status_code == 200
assert response.json()["name"] == "test2"
# verify the participant was updated
response = await ac.get(
f"/transcripts/{transcript_id}/participants/{participant_id}"
)
assert response.status_code == 200
assert response.json()["name"] == "test2"
# verify the participant was updated in transcript
response = await ac.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
assert len(response.json()["participants"]) == 1
assert response.json()["participants"][0]["name"] == "test2"
# verify the participant was updated in transcript
response = await client.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
assert len(response.json()["participants"]) == 1
assert response.json()["participants"][0]["name"] == "test2"
@pytest.mark.asyncio
async def test_transcript_participants_update_speaker():
from reflector.app import app
async def test_transcript_participants_update_speaker(client):
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["participants"] == []
transcript_id = response.json()["id"]
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["participants"] == []
transcript_id = response.json()["id"]
# create a participant
response = await client.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test", "speaker": 1},
)
assert response.status_code == 200
participant1_id = response.json()["id"]
# create a participant
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test", "speaker": 1},
)
assert response.status_code == 200
participant1_id = response.json()["id"]
# create another participant
response = await client.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test2", "speaker": 2},
)
assert response.status_code == 200
participant2_id = response.json()["id"]
# create another participant
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test2", "speaker": 2},
)
assert response.status_code == 200
participant2_id = response.json()["id"]
# update the participant, refused as speaker is already taken
response = await client.patch(
f"/transcripts/{transcript_id}/participants/{participant2_id}",
json={"speaker": 1},
)
assert response.status_code == 400
# update the participant, refused as speaker is already taken
response = await ac.patch(
f"/transcripts/{transcript_id}/participants/{participant2_id}",
json={"speaker": 1},
)
assert response.status_code == 400
# delete the participant 1
response = await client.delete(
f"/transcripts/{transcript_id}/participants/{participant1_id}"
)
assert response.status_code == 200
# delete the participant 1
response = await ac.delete(
f"/transcripts/{transcript_id}/participants/{participant1_id}"
)
assert response.status_code == 200
# update the participant 2 again, should be accepted now
response = await client.patch(
f"/transcripts/{transcript_id}/participants/{participant2_id}",
json={"speaker": 1},
)
assert response.status_code == 200
# update the participant 2 again, should be accepted now
response = await ac.patch(
f"/transcripts/{transcript_id}/participants/{participant2_id}",
json={"speaker": 1},
)
assert response.status_code == 200
# ensure participant2 name is still there
response = await ac.get(
f"/transcripts/{transcript_id}/participants/{participant2_id}"
)
assert response.status_code == 200
assert response.json()["name"] == "test2"
assert response.json()["speaker"] == 1
# ensure participant2 name is still there
response = await client.get(
f"/transcripts/{transcript_id}/participants/{participant2_id}"
)
assert response.status_code == 200
assert response.json()["name"] == "test2"
assert response.json()["speaker"] == 1

View File

@@ -2,7 +2,25 @@ import asyncio
import time
import pytest
from httpx import AsyncClient
from httpx import ASGITransport, AsyncClient
@pytest.fixture
async def app_lifespan():
from asgi_lifespan import LifespanManager
from reflector.app import app
async with LifespanManager(app) as manager:
yield manager.app
@pytest.fixture
async def client(app_lifespan):
yield AsyncClient(
transport=ASGITransport(app=app_lifespan),
base_url="http://test/v1",
)
@pytest.mark.usefixtures("setup_database")
@@ -11,23 +29,21 @@ from httpx import AsyncClient
@pytest.mark.asyncio
async def test_transcript_process(
tmpdir,
whisper_transcript,
dummy_llm,
dummy_processors,
dummy_diarization,
dummy_storage,
client,
):
from reflector.app import app
ac = AsyncClient(app=app, base_url="http://test/v1")
# create a transcript
response = await ac.post("/transcripts", json={"name": "test"})
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["status"] == "idle"
tid = response.json()["id"]
# upload mp3
response = await ac.post(
response = await client.post(
f"/transcripts/{tid}/record/upload?chunk_number=0&total_chunks=1",
files={
"chunk": (
@@ -45,7 +61,7 @@ async def test_transcript_process(
start_time = time.monotonic()
while (time.monotonic() - start_time) < timeout_seconds:
# fetch the transcript and check if it is ended
resp = await ac.get(f"/transcripts/{tid}")
resp = await client.get(f"/transcripts/{tid}")
assert resp.status_code == 200
if resp.json()["status"] in ("ended", "error"):
break
@@ -54,7 +70,7 @@ async def test_transcript_process(
pytest.fail(f"Initial processing timed out after {timeout_seconds} seconds")
# restart the processing
response = await ac.post(
response = await client.post(
f"/transcripts/{tid}/process",
)
assert response.status_code == 200
@@ -65,7 +81,7 @@ async def test_transcript_process(
start_time = time.monotonic()
while (time.monotonic() - start_time) < timeout_seconds:
# fetch the transcript and check if it is ended
resp = await ac.get(f"/transcripts/{tid}")
resp = await client.get(f"/transcripts/{tid}")
assert resp.status_code == 200
if resp.json()["status"] in ("ended", "error"):
break
@@ -80,7 +96,7 @@ async def test_transcript_process(
assert transcript["title"] == "Llm Title"
# check topics and transcript
response = await ac.get(f"/transcripts/{tid}/topics")
response = await client.get(f"/transcripts/{tid}/topics")
assert response.status_code == 200
assert len(response.json()) == 1
assert "want to share" in response.json()[0]["transcript"]

View File

@@ -10,7 +10,6 @@ import time
from pathlib import Path
import pytest
from httpx import AsyncClient
from httpx_ws import aconnect_ws
from uvicorn import Config, Server
@@ -50,23 +49,69 @@ class ThreadedUvicorn:
@pytest.fixture
async def appserver(tmpdir, setup_database, celery_session_app, celery_session_worker):
def appserver(tmpdir, setup_database, celery_session_app, celery_session_worker):
import threading
from reflector.app import app
from reflector.db import get_database
from reflector.settings import settings
DATA_DIR = settings.DATA_DIR
settings.DATA_DIR = Path(tmpdir)
# start server
# start server in a separate thread with its own event loop
host = "127.0.0.1"
port = 1255
config = Config(app=app, host=host, port=port)
server = ThreadedUvicorn(config)
await server.start()
server_started = threading.Event()
server_exception = None
server_instance = None
yield (server, host, port)
def run_server():
nonlocal server_exception, server_instance
try:
# Create a new event loop for this thread
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
config = Config(app=app, host=host, port=port, loop=loop)
server_instance = Server(config)
async def start_server():
# Initialize database connection in this event loop
database = get_database()
await database.connect()
try:
await server_instance.serve()
finally:
await database.disconnect()
# Signal that server is starting
server_started.set()
loop.run_until_complete(start_server())
except Exception as e:
server_exception = e
server_started.set()
finally:
loop.close()
server_thread = threading.Thread(target=run_server, daemon=True)
server_thread.start()
# Wait for server to start
server_started.wait(timeout=30)
if server_exception:
raise server_exception
# Wait a bit more for the server to be fully ready
time.sleep(1)
yield server_instance, host, port
# Stop server
if server_instance:
server_instance.should_exit = True
server_thread.join(timeout=30)
server.stop()
settings.DATA_DIR = DATA_DIR
@@ -89,6 +134,7 @@ async def test_transcript_rtc_and_websocket(
dummy_storage,
fake_mp3_upload,
appserver,
client,
):
# goal: start the server, exchange RTC, receive websocket events
# because of that, we need to start the server in a thread
@@ -97,8 +143,7 @@ async def test_transcript_rtc_and_websocket(
# create a transcript
base_url = f"http://{host}:{port}/v1"
ac = AsyncClient(base_url=base_url)
response = await ac.post("/transcripts", json={"name": "Test RTC"})
response = await client.post("/transcripts", json={"name": "Test RTC"})
assert response.status_code == 200
tid = response.json()["id"]
@@ -143,11 +188,11 @@ async def test_transcript_rtc_and_websocket(
url = f"{base_url}/transcripts/{tid}/record/webrtc"
path = Path(__file__).parent / "records" / "test_short.wav"
client = StreamClient(signaling, url=url, play_from=path.as_posix())
await client.start()
stream_client = StreamClient(signaling, url=url, play_from=path.as_posix())
await stream_client.start()
timeout = 20
while not client.is_ended():
timeout = 120
while not stream_client.is_ended():
await asyncio.sleep(1)
timeout -= 1
if timeout < 0:
@@ -155,14 +200,14 @@ async def test_transcript_rtc_and_websocket(
# XXX aiortc is long to close the connection
# instead of waiting a long time, we just send a STOP
client.channel.send(json.dumps({"cmd": "STOP"}))
await client.stop()
stream_client.channel.send(json.dumps({"cmd": "STOP"}))
await stream_client.stop()
# wait the processing to finish
timeout = 20
timeout = 120
while True:
# fetch the transcript and check if it is ended
resp = await ac.get(f"/transcripts/{tid}")
resp = await client.get(f"/transcripts/{tid}")
assert resp.status_code == 200
if resp.json()["status"] in ("ended", "error"):
break
@@ -215,7 +260,7 @@ async def test_transcript_rtc_and_websocket(
ev = events[eventnames.index("WAVEFORM")]
assert isinstance(ev["data"]["waveform"], list)
assert len(ev["data"]["waveform"]) >= 250
waveform_resp = await ac.get(f"/transcripts/{tid}/audio/waveform")
waveform_resp = await client.get(f"/transcripts/{tid}/audio/waveform")
assert waveform_resp.status_code == 200
assert waveform_resp.headers["content-type"] == "application/json"
assert isinstance(waveform_resp.json()["data"], list)
@@ -235,7 +280,7 @@ async def test_transcript_rtc_and_websocket(
assert "DURATION" in eventnames
# check that audio/mp3 is available
audio_resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
audio_resp = await client.get(f"/transcripts/{tid}/audio/mp3")
assert audio_resp.status_code == 200
assert audio_resp.headers["Content-Type"] == "audio/mpeg"
@@ -254,6 +299,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
dummy_storage,
fake_mp3_upload,
appserver,
client,
):
# goal: start the server, exchange RTC, receive websocket events
# because of that, we need to start the server in a thread
@@ -263,8 +309,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
# create a transcript
base_url = f"http://{host}:{port}/v1"
ac = AsyncClient(base_url=base_url)
response = await ac.post(
response = await client.post(
"/transcripts", json={"name": "Test RTC", "target_language": "fr"}
)
assert response.status_code == 200
@@ -311,11 +356,11 @@ async def test_transcript_rtc_and_websocket_and_fr(
url = f"{base_url}/transcripts/{tid}/record/webrtc"
path = Path(__file__).parent / "records" / "test_short.wav"
client = StreamClient(signaling, url=url, play_from=path.as_posix())
await client.start()
stream_client = StreamClient(signaling, url=url, play_from=path.as_posix())
await stream_client.start()
timeout = 20
while not client.is_ended():
timeout = 120
while not stream_client.is_ended():
await asyncio.sleep(1)
timeout -= 1
if timeout < 0:
@@ -323,18 +368,18 @@ async def test_transcript_rtc_and_websocket_and_fr(
# XXX aiortc is long to close the connection
# instead of waiting a long time, we just send a STOP
client.channel.send(json.dumps({"cmd": "STOP"}))
stream_client.channel.send(json.dumps({"cmd": "STOP"}))
# wait the processing to finish
await asyncio.sleep(2)
await client.stop()
await stream_client.stop()
# wait the processing to finish
timeout = 20
timeout = 120
while True:
# fetch the transcript and check if it is ended
resp = await ac.get(f"/transcripts/{tid}")
resp = await client.get(f"/transcripts/{tid}")
assert resp.status_code == 200
if resp.json()["status"] == "ended":
break

View File

@@ -1,401 +1,390 @@
import pytest
from httpx import AsyncClient
@pytest.mark.asyncio
async def test_transcript_reassign_speaker(fake_transcript_with_topics):
from reflector.app import app
async def test_transcript_reassign_speaker(fake_transcript_with_topics, client):
transcript_id = fake_transcript_with_topics.id
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
# check the transcript exists
response = await ac.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
# check the transcript exists
response = await client.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
# check initial topics of the transcript
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check initial topics of the transcript
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 0
assert topics[0]["words"][1]["speaker"] == 0
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
# check through segments
assert len(topics[0]["segments"]) == 1
assert topics[0]["segments"][0]["speaker"] == 0
assert len(topics[1]["segments"]) == 1
assert topics[1]["segments"][0]["speaker"] == 0
# check through words
assert topics[0]["words"][0]["speaker"] == 0
assert topics[0]["words"][1]["speaker"] == 0
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
# check through segments
assert len(topics[0]["segments"]) == 1
assert topics[0]["segments"][0]["speaker"] == 0
assert len(topics[1]["segments"]) == 1
assert topics[1]["segments"][0]["speaker"] == 0
# reassign speaker
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"speaker": 1,
"timestamp_from": 0,
"timestamp_to": 1,
},
)
assert response.status_code == 200
# reassign speaker
response = await client.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"speaker": 1,
"timestamp_from": 0,
"timestamp_to": 1,
},
)
assert response.status_code == 200
# check topics again
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check topics again
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 1
assert topics[0]["words"][1]["speaker"] == 1
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
# check segments
assert len(topics[0]["segments"]) == 1
assert topics[0]["segments"][0]["speaker"] == 1
assert len(topics[1]["segments"]) == 1
assert topics[1]["segments"][0]["speaker"] == 0
# check through words
assert topics[0]["words"][0]["speaker"] == 1
assert topics[0]["words"][1]["speaker"] == 1
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
# check segments
assert len(topics[0]["segments"]) == 1
assert topics[0]["segments"][0]["speaker"] == 1
assert len(topics[1]["segments"]) == 1
assert topics[1]["segments"][0]["speaker"] == 0
# reassign speaker, middle of 2 topics
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"speaker": 2,
"timestamp_from": 1,
"timestamp_to": 2.5,
},
)
assert response.status_code == 200
# reassign speaker, middle of 2 topics
response = await client.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"speaker": 2,
"timestamp_from": 1,
"timestamp_to": 2.5,
},
)
assert response.status_code == 200
# check topics again
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check topics again
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 1
assert topics[0]["words"][1]["speaker"] == 2
assert topics[1]["words"][0]["speaker"] == 2
assert topics[1]["words"][1]["speaker"] == 0
# check segments
assert len(topics[0]["segments"]) == 2
assert topics[0]["segments"][0]["speaker"] == 1
assert topics[0]["segments"][1]["speaker"] == 2
assert len(topics[1]["segments"]) == 2
assert topics[1]["segments"][0]["speaker"] == 2
assert topics[1]["segments"][1]["speaker"] == 0
# check through words
assert topics[0]["words"][0]["speaker"] == 1
assert topics[0]["words"][1]["speaker"] == 2
assert topics[1]["words"][0]["speaker"] == 2
assert topics[1]["words"][1]["speaker"] == 0
# check segments
assert len(topics[0]["segments"]) == 2
assert topics[0]["segments"][0]["speaker"] == 1
assert topics[0]["segments"][1]["speaker"] == 2
assert len(topics[1]["segments"]) == 2
assert topics[1]["segments"][0]["speaker"] == 2
assert topics[1]["segments"][1]["speaker"] == 0
# reassign speaker, everything
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"speaker": 4,
"timestamp_from": 0,
"timestamp_to": 100,
},
)
assert response.status_code == 200
# reassign speaker, everything
response = await client.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"speaker": 4,
"timestamp_from": 0,
"timestamp_to": 100,
},
)
assert response.status_code == 200
# check topics again
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check topics again
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 4
assert topics[0]["words"][1]["speaker"] == 4
assert topics[1]["words"][0]["speaker"] == 4
assert topics[1]["words"][1]["speaker"] == 4
# check segments
assert len(topics[0]["segments"]) == 1
assert topics[0]["segments"][0]["speaker"] == 4
assert len(topics[1]["segments"]) == 1
assert topics[1]["segments"][0]["speaker"] == 4
# check through words
assert topics[0]["words"][0]["speaker"] == 4
assert topics[0]["words"][1]["speaker"] == 4
assert topics[1]["words"][0]["speaker"] == 4
assert topics[1]["words"][1]["speaker"] == 4
# check segments
assert len(topics[0]["segments"]) == 1
assert topics[0]["segments"][0]["speaker"] == 4
assert len(topics[1]["segments"]) == 1
assert topics[1]["segments"][0]["speaker"] == 4
@pytest.mark.asyncio
async def test_transcript_merge_speaker(fake_transcript_with_topics):
from reflector.app import app
async def test_transcript_merge_speaker(fake_transcript_with_topics, client):
transcript_id = fake_transcript_with_topics.id
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
# check the transcript exists
response = await ac.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
# check the transcript exists
response = await client.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
# check initial topics of the transcript
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check initial topics of the transcript
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 0
assert topics[0]["words"][1]["speaker"] == 0
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
# check through words
assert topics[0]["words"][0]["speaker"] == 0
assert topics[0]["words"][1]["speaker"] == 0
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
# reassign speaker
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"speaker": 1,
"timestamp_from": 0,
"timestamp_to": 1,
},
)
assert response.status_code == 200
# reassign speaker
response = await client.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"speaker": 1,
"timestamp_from": 0,
"timestamp_to": 1,
},
)
assert response.status_code == 200
# check topics again
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check topics again
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 1
assert topics[0]["words"][1]["speaker"] == 1
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
# check through words
assert topics[0]["words"][0]["speaker"] == 1
assert topics[0]["words"][1]["speaker"] == 1
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
# merge speakers
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/merge",
json={
"speaker_from": 1,
"speaker_to": 0,
},
)
assert response.status_code == 200
# merge speakers
response = await client.patch(
f"/transcripts/{transcript_id}/speaker/merge",
json={
"speaker_from": 1,
"speaker_to": 0,
},
)
assert response.status_code == 200
# check topics again
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check topics again
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 0
assert topics[0]["words"][1]["speaker"] == 0
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
# check through words
assert topics[0]["words"][0]["speaker"] == 0
assert topics[0]["words"][1]["speaker"] == 0
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
@pytest.mark.asyncio
async def test_transcript_reassign_with_participant(fake_transcript_with_topics):
from reflector.app import app
async def test_transcript_reassign_with_participant(
fake_transcript_with_topics, client
):
transcript_id = fake_transcript_with_topics.id
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
# check the transcript exists
response = await ac.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
transcript = response.json()
assert len(transcript["participants"]) == 0
# check the transcript exists
response = await client.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
transcript = response.json()
assert len(transcript["participants"]) == 0
# create 2 participants
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={
"name": "Participant 1",
},
)
assert response.status_code == 200
participant1_id = response.json()["id"]
# create 2 participants
response = await client.post(
f"/transcripts/{transcript_id}/participants",
json={
"name": "Participant 1",
},
)
assert response.status_code == 200
participant1_id = response.json()["id"]
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={
"name": "Participant 2",
},
)
assert response.status_code == 200
participant2_id = response.json()["id"]
response = await client.post(
f"/transcripts/{transcript_id}/participants",
json={
"name": "Participant 2",
},
)
assert response.status_code == 200
participant2_id = response.json()["id"]
# check participants speakers
response = await ac.get(f"/transcripts/{transcript_id}/participants")
assert response.status_code == 200
participants = response.json()
assert len(participants) == 2
assert participants[0]["name"] == "Participant 1"
assert participants[0]["speaker"] is None
assert participants[1]["name"] == "Participant 2"
assert participants[1]["speaker"] is None
# check participants speakers
response = await client.get(f"/transcripts/{transcript_id}/participants")
assert response.status_code == 200
participants = response.json()
assert len(participants) == 2
assert participants[0]["name"] == "Participant 1"
assert participants[0]["speaker"] is None
assert participants[1]["name"] == "Participant 2"
assert participants[1]["speaker"] is None
# check initial topics of the transcript
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check initial topics of the transcript
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 0
assert topics[0]["words"][1]["speaker"] == 0
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
# check through segments
assert len(topics[0]["segments"]) == 1
assert topics[0]["segments"][0]["speaker"] == 0
assert len(topics[1]["segments"]) == 1
assert topics[1]["segments"][0]["speaker"] == 0
# check through words
assert topics[0]["words"][0]["speaker"] == 0
assert topics[0]["words"][1]["speaker"] == 0
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
# check through segments
assert len(topics[0]["segments"]) == 1
assert topics[0]["segments"][0]["speaker"] == 0
assert len(topics[1]["segments"]) == 1
assert topics[1]["segments"][0]["speaker"] == 0
# reassign speaker from a participant
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"participant": participant1_id,
"timestamp_from": 0,
"timestamp_to": 1,
},
)
assert response.status_code == 200
# reassign speaker from a participant
response = await client.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"participant": participant1_id,
"timestamp_from": 0,
"timestamp_to": 1,
},
)
assert response.status_code == 200
# check participants if speaker has been assigned
# first participant should have 1, because it's not used yet.
response = await ac.get(f"/transcripts/{transcript_id}/participants")
assert response.status_code == 200
participants = response.json()
assert len(participants) == 2
assert participants[0]["name"] == "Participant 1"
assert participants[0]["speaker"] == 1
assert participants[1]["name"] == "Participant 2"
assert participants[1]["speaker"] is None
# check participants if speaker has been assigned
# first participant should have 1, because it's not used yet.
response = await client.get(f"/transcripts/{transcript_id}/participants")
assert response.status_code == 200
participants = response.json()
assert len(participants) == 2
assert participants[0]["name"] == "Participant 1"
assert participants[0]["speaker"] == 1
assert participants[1]["name"] == "Participant 2"
assert participants[1]["speaker"] is None
# check topics again
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check topics again
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 1
assert topics[0]["words"][1]["speaker"] == 1
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
# check segments
assert len(topics[0]["segments"]) == 1
assert topics[0]["segments"][0]["speaker"] == 1
assert len(topics[1]["segments"]) == 1
assert topics[1]["segments"][0]["speaker"] == 0
# check through words
assert topics[0]["words"][0]["speaker"] == 1
assert topics[0]["words"][1]["speaker"] == 1
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
# check segments
assert len(topics[0]["segments"]) == 1
assert topics[0]["segments"][0]["speaker"] == 1
assert len(topics[1]["segments"]) == 1
assert topics[1]["segments"][0]["speaker"] == 0
# reassign participant, middle of 2 topics
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"participant": participant2_id,
"timestamp_from": 1,
"timestamp_to": 2.5,
},
)
assert response.status_code == 200
# reassign participant, middle of 2 topics
response = await client.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"participant": participant2_id,
"timestamp_from": 1,
"timestamp_to": 2.5,
},
)
assert response.status_code == 200
# check participants if speaker has been assigned
# first participant should have 1, because it's not used yet.
response = await ac.get(f"/transcripts/{transcript_id}/participants")
assert response.status_code == 200
participants = response.json()
assert len(participants) == 2
assert participants[0]["name"] == "Participant 1"
assert participants[0]["speaker"] == 1
assert participants[1]["name"] == "Participant 2"
assert participants[1]["speaker"] == 2
# check participants if speaker has been assigned
# first participant should have 1, because it's not used yet.
response = await client.get(f"/transcripts/{transcript_id}/participants")
assert response.status_code == 200
participants = response.json()
assert len(participants) == 2
assert participants[0]["name"] == "Participant 1"
assert participants[0]["speaker"] == 1
assert participants[1]["name"] == "Participant 2"
assert participants[1]["speaker"] == 2
# check topics again
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check topics again
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 1
assert topics[0]["words"][1]["speaker"] == 2
assert topics[1]["words"][0]["speaker"] == 2
assert topics[1]["words"][1]["speaker"] == 0
# check segments
assert len(topics[0]["segments"]) == 2
assert topics[0]["segments"][0]["speaker"] == 1
assert topics[0]["segments"][1]["speaker"] == 2
assert len(topics[1]["segments"]) == 2
assert topics[1]["segments"][0]["speaker"] == 2
assert topics[1]["segments"][1]["speaker"] == 0
# check through words
assert topics[0]["words"][0]["speaker"] == 1
assert topics[0]["words"][1]["speaker"] == 2
assert topics[1]["words"][0]["speaker"] == 2
assert topics[1]["words"][1]["speaker"] == 0
# check segments
assert len(topics[0]["segments"]) == 2
assert topics[0]["segments"][0]["speaker"] == 1
assert topics[0]["segments"][1]["speaker"] == 2
assert len(topics[1]["segments"]) == 2
assert topics[1]["segments"][0]["speaker"] == 2
assert topics[1]["segments"][1]["speaker"] == 0
# reassign speaker, everything
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"participant": participant1_id,
"timestamp_from": 0,
"timestamp_to": 100,
},
)
assert response.status_code == 200
# reassign speaker, everything
response = await client.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"participant": participant1_id,
"timestamp_from": 0,
"timestamp_to": 100,
},
)
assert response.status_code == 200
# check topics again
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check topics again
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 1
assert topics[0]["words"][1]["speaker"] == 1
assert topics[1]["words"][0]["speaker"] == 1
assert topics[1]["words"][1]["speaker"] == 1
# check segments
assert len(topics[0]["segments"]) == 1
assert topics[0]["segments"][0]["speaker"] == 1
assert len(topics[1]["segments"]) == 1
assert topics[1]["segments"][0]["speaker"] == 1
# check through words
assert topics[0]["words"][0]["speaker"] == 1
assert topics[0]["words"][1]["speaker"] == 1
assert topics[1]["words"][0]["speaker"] == 1
assert topics[1]["words"][1]["speaker"] == 1
# check segments
assert len(topics[0]["segments"]) == 1
assert topics[0]["segments"][0]["speaker"] == 1
assert len(topics[1]["segments"]) == 1
assert topics[1]["segments"][0]["speaker"] == 1
@pytest.mark.asyncio
async def test_transcript_reassign_edge_cases(fake_transcript_with_topics):
from reflector.app import app
async def test_transcript_reassign_edge_cases(fake_transcript_with_topics, client):
transcript_id = fake_transcript_with_topics.id
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
# check the transcript exists
response = await ac.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
transcript = response.json()
assert len(transcript["participants"]) == 0
# check the transcript exists
response = await client.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
transcript = response.json()
assert len(transcript["participants"]) == 0
# try reassign without any participant_id or speaker
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"timestamp_from": 0,
"timestamp_to": 1,
},
)
assert response.status_code == 400
# try reassign without any participant_id or speaker
response = await client.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"timestamp_from": 0,
"timestamp_to": 1,
},
)
assert response.status_code == 400
# try reassing with both participant_id and speaker
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"participant": "123",
"speaker": 1,
"timestamp_from": 0,
"timestamp_to": 1,
},
)
assert response.status_code == 400
# try reassing with both participant_id and speaker
response = await client.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"participant": "123",
"speaker": 1,
"timestamp_from": 0,
"timestamp_to": 1,
},
)
assert response.status_code == 400
# try reassing with non-existing participant_id
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"participant": "123",
"timestamp_from": 0,
"timestamp_to": 1,
},
)
assert response.status_code == 404
# try reassing with non-existing participant_id
response = await client.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"participant": "123",
"timestamp_from": 0,
"timestamp_to": 1,
},
)
assert response.status_code == 404

View File

@@ -1,26 +1,22 @@
import pytest
from httpx import AsyncClient
@pytest.mark.asyncio
async def test_transcript_topics(fake_transcript_with_topics):
from reflector.app import app
async def test_transcript_topics(fake_transcript_with_topics, client):
transcript_id = fake_transcript_with_topics.id
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
# check the transcript exists
response = await ac.get(f"/transcripts/{transcript_id}/topics")
assert response.status_code == 200
assert len(response.json()) == 2
topic_id = response.json()[0]["id"]
# check the transcript exists
response = await client.get(f"/transcripts/{transcript_id}/topics")
assert response.status_code == 200
assert len(response.json()) == 2
topic_id = response.json()[0]["id"]
# get words per speakers
response = await ac.get(
f"/transcripts/{transcript_id}/topics/{topic_id}/words-per-speaker"
)
assert response.status_code == 200
data = response.json()
assert len(data["words_per_speaker"]) == 1
assert data["words_per_speaker"][0]["speaker"] == 0
assert len(data["words_per_speaker"][0]["words"]) == 2
# get words per speakers
response = await client.get(
f"/transcripts/{transcript_id}/topics/{topic_id}/words-per-speaker"
)
assert response.status_code == 200
data = response.json()
assert len(data["words_per_speaker"]) == 1
assert data["words_per_speaker"][0]["speaker"] == 0
assert len(data["words_per_speaker"][0]["words"]) == 2

View File

@@ -1,63 +1,53 @@
import pytest
from httpx import AsyncClient
@pytest.mark.asyncio
async def test_transcript_create_default_translation():
from reflector.app import app
async def test_transcript_create_default_translation(client):
response = await client.post("/transcripts", json={"name": "test en"})
assert response.status_code == 200
assert response.json()["name"] == "test en"
assert response.json()["source_language"] == "en"
assert response.json()["target_language"] == "en"
tid = response.json()["id"]
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test en"})
assert response.status_code == 200
assert response.json()["name"] == "test en"
assert response.json()["source_language"] == "en"
assert response.json()["target_language"] == "en"
tid = response.json()["id"]
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["name"] == "test en"
assert response.json()["source_language"] == "en"
assert response.json()["target_language"] == "en"
response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["name"] == "test en"
assert response.json()["source_language"] == "en"
assert response.json()["target_language"] == "en"
@pytest.mark.asyncio
async def test_transcript_create_en_fr_translation():
from reflector.app import app
async def test_transcript_create_en_fr_translation(client):
response = await client.post(
"/transcripts", json={"name": "test en/fr", "target_language": "fr"}
)
assert response.status_code == 200
assert response.json()["name"] == "test en/fr"
assert response.json()["source_language"] == "en"
assert response.json()["target_language"] == "fr"
tid = response.json()["id"]
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post(
"/transcripts", json={"name": "test en/fr", "target_language": "fr"}
)
assert response.status_code == 200
assert response.json()["name"] == "test en/fr"
assert response.json()["source_language"] == "en"
assert response.json()["target_language"] == "fr"
tid = response.json()["id"]
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["name"] == "test en/fr"
assert response.json()["source_language"] == "en"
assert response.json()["target_language"] == "fr"
response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["name"] == "test en/fr"
assert response.json()["source_language"] == "en"
assert response.json()["target_language"] == "fr"
@pytest.mark.asyncio
async def test_transcript_create_fr_en_translation():
from reflector.app import app
async def test_transcript_create_fr_en_translation(client):
response = await client.post(
"/transcripts", json={"name": "test fr/en", "source_language": "fr"}
)
assert response.status_code == 200
assert response.json()["name"] == "test fr/en"
assert response.json()["source_language"] == "fr"
assert response.json()["target_language"] == "en"
tid = response.json()["id"]
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post(
"/transcripts", json={"name": "test fr/en", "source_language": "fr"}
)
assert response.status_code == 200
assert response.json()["name"] == "test fr/en"
assert response.json()["source_language"] == "fr"
assert response.json()["target_language"] == "en"
tid = response.json()["id"]
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["name"] == "test fr/en"
assert response.json()["source_language"] == "fr"
assert response.json()["target_language"] == "en"
response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["name"] == "test fr/en"
assert response.json()["source_language"] == "fr"
assert response.json()["target_language"] == "en"

View File

@@ -2,7 +2,6 @@ import asyncio
import time
import pytest
from httpx import AsyncClient
@pytest.mark.usefixtures("setup_database")
@@ -15,19 +14,16 @@ async def test_transcript_upload_file(
dummy_processors,
dummy_diarization,
dummy_storage,
client,
):
from reflector.app import app
ac = AsyncClient(app=app, base_url="http://test/v1")
# create a transcript
response = await ac.post("/transcripts", json={"name": "test"})
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["status"] == "idle"
tid = response.json()["id"]
# upload mp3
response = await ac.post(
response = await client.post(
f"/transcripts/{tid}/record/upload?chunk_number=0&total_chunks=1",
files={
"chunk": (
@@ -45,7 +41,7 @@ async def test_transcript_upload_file(
start_time = time.monotonic()
while (time.monotonic() - start_time) < timeout_seconds:
# fetch the transcript and check if it is ended
resp = await ac.get(f"/transcripts/{tid}")
resp = await client.get(f"/transcripts/{tid}")
assert resp.status_code == 200
if resp.json()["status"] in ("ended", "error"):
break
@@ -60,7 +56,7 @@ async def test_transcript_upload_file(
assert transcript["title"] == "Llm Title"
# check topics and transcript
response = await ac.get(f"/transcripts/{tid}/topics")
response = await client.get(f"/transcripts/{tid}/topics")
assert response.status_code == 200
assert len(response.json()) == 1
assert "want to share" in response.json()[0]["transcript"]

View File

@@ -2,7 +2,7 @@
import pytest
from reflector.db import database
from reflector.db import get_database
from reflector.db.transcripts import (
SourceKind,
TranscriptController,
@@ -26,7 +26,7 @@ class TestWebVTTAutoUpdate:
)
try:
result = await database.fetch_one(
result = await get_database().fetch_one(
transcripts.select().where(transcripts.c.id == transcript.id)
)
@@ -58,7 +58,7 @@ class TestWebVTTAutoUpdate:
await controller.upsert_topic(transcript, topic)
result = await database.fetch_one(
result = await get_database().fetch_one(
transcripts.select().where(transcripts.c.id == transcript.id)
)
@@ -99,7 +99,7 @@ class TestWebVTTAutoUpdate:
await controller.update(transcript, {"topics": topics_data})
# Fetch from DB
result = await database.fetch_one(
result = await get_database().fetch_one(
transcripts.select().where(transcripts.c.id == transcript.id)
)
@@ -141,7 +141,7 @@ class TestWebVTTAutoUpdate:
await controller.update(transcript, values)
# Fetch from DB
result = await database.fetch_one(
result = await get_database().fetch_one(
transcripts.select().where(transcripts.c.id == transcript.id)
)
@@ -216,7 +216,7 @@ class TestWebVTTAutoUpdate:
await controller.update(transcript, values)
# Fetch from DB
result = await database.fetch_one(
result = await get_database().fetch_one(
transcripts.select().where(transcripts.c.id == transcript.id)
)

2739
server/uv.lock generated

File diff suppressed because it is too large Load Diff