From f3ae187274277a58049b5f2a14a34915f31b243e Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Tue, 15 Jul 2025 20:46:19 -0600 Subject: [PATCH] fix: waveform can generate NaN in json database (#481) * refactor: fixes transcript duration type, NaN in waveform, and prepare for postgres migration * fix: ensure we don't have NaN in waveform * fix: missing assertionerror Co-authored-by: pr-agent-monadical[bot] <198624643+pr-agent-monadical[bot]@users.noreply.github.com> * fix: potential empty array --------- Co-authored-by: pr-agent-monadical[bot] <198624643+pr-agent-monadical[bot]@users.noreply.github.com> --- server/migrations/env.py | 4 + ...f0b60a9d34_fix_transcript_duration_type.py | 40 ++++++++++ ...92678ba2_fix_transcript_json_nan_values.py | 73 +++++++++++++++++++ ...a9c9c229ee36_transcript_composite_index.py | 39 ++++++++++ server/reflector/db/__init__.py | 7 +- server/reflector/db/transcripts.py | 22 +++--- server/reflector/utils/audio_waveform.py | 5 +- 7 files changed, 177 insertions(+), 13 deletions(-) create mode 100644 server/migrations/versions/2cf0b60a9d34_fix_transcript_duration_type.py create mode 100644 server/migrations/versions/88d292678ba2_fix_transcript_json_nan_values.py create mode 100644 server/migrations/versions/a9c9c229ee36_transcript_composite_index.py diff --git a/server/migrations/env.py b/server/migrations/env.py index 226b95b5..6960f0aa 100644 --- a/server/migrations/env.py +++ b/server/migrations/env.py @@ -24,6 +24,10 @@ target_metadata = metadata # ... etc. +# don't use asyncpg for the moment +settings.DATABASE_URL = settings.DATABASE_URL.replace("+asyncpg", "") + + def run_migrations_offline() -> None: """Run migrations in 'offline' mode. diff --git a/server/migrations/versions/2cf0b60a9d34_fix_transcript_duration_type.py b/server/migrations/versions/2cf0b60a9d34_fix_transcript_duration_type.py new file mode 100644 index 00000000..5ce9ea89 --- /dev/null +++ b/server/migrations/versions/2cf0b60a9d34_fix_transcript_duration_type.py @@ -0,0 +1,40 @@ +"""fix transcript duration type + +Revision ID: 2cf0b60a9d34 +Revises: ccd68dc784ff +Create Date: 2025-07-15 16:53:40.397394 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '2cf0b60a9d34' +down_revision: Union[str, None] = 'ccd68dc784ff' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('transcript', schema=None) as batch_op: + batch_op.alter_column('duration', + existing_type=sa.INTEGER(), + type_=sa.Float(), + existing_nullable=True) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('transcript', schema=None) as batch_op: + batch_op.alter_column('duration', + existing_type=sa.Float(), + type_=sa.INTEGER(), + existing_nullable=True) + + # ### end Alembic commands ### diff --git a/server/migrations/versions/88d292678ba2_fix_transcript_json_nan_values.py b/server/migrations/versions/88d292678ba2_fix_transcript_json_nan_values.py new file mode 100644 index 00000000..e28416bd --- /dev/null +++ b/server/migrations/versions/88d292678ba2_fix_transcript_json_nan_values.py @@ -0,0 +1,73 @@ +"""fix_transcript_json_nan_values + +Revision ID: 88d292678ba2 +Revises: 2cf0b60a9d34 +Create Date: 2025-07-15 19:30:19.876332 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "88d292678ba2" +down_revision: Union[str, None] = "2cf0b60a9d34" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + import json + import re + from sqlalchemy import text + + # Get database connection + conn = op.get_bind() + + # Fetch all transcript records with events data + result = conn.execute( + text("SELECT id, events FROM transcript WHERE events IS NOT NULL") + ) + + def fix_nan(obj): + if isinstance(obj, dict): + for key, value in obj.items(): + if isinstance(value, (dict, list)): + fix_nan(value) + elif isinstance(value, float) and value != value: + obj[key] = None + elif isinstance(obj, list): + for i in range(len(obj)): + if isinstance(obj[i], (dict, list)): + fix_nan(obj[i]) + elif isinstance(obj[i], float) and obj[i] != obj[i]: + obj[i] = None + + for transcript_id, events in result: + if not events: + continue + if "NaN" not in events: + continue + + try: + jevents = json.loads(events) + fix_nan(jevents) + fixed_events = json.dumps(jevents) + assert "NaN" not in fixed_events + except (json.JSONDecodeError, AssertionError) as e: + print(f"Warning: Invalid JSON for transcript {transcript_id}, skipping: {e}") + continue + + # Update the record with fixed JSON + conn.execute( + text("UPDATE transcript SET events = :events WHERE id = :id"), + {"events": fixed_events, "id": transcript_id}, + ) + + +def downgrade() -> None: + # No downgrade needed - this is a data fix + pass diff --git a/server/migrations/versions/a9c9c229ee36_transcript_composite_index.py b/server/migrations/versions/a9c9c229ee36_transcript_composite_index.py new file mode 100644 index 00000000..a206732a --- /dev/null +++ b/server/migrations/versions/a9c9c229ee36_transcript_composite_index.py @@ -0,0 +1,39 @@ +"""transcript composite index + +Revision ID: a9c9c229ee36 +Revises: 88d292678ba2 +Create Date: 2025-07-15 20:09:40.253018 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "a9c9c229ee36" +down_revision: Union[str, None] = "88d292678ba2" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("transcript", schema=None) as batch_op: + batch_op.create_index( + "idx_transcript_user_id_recording_id", + ["user_id", "recording_id"], + unique=False, + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("transcript", schema=None) as batch_op: + batch_op.drop_index("idx_transcript_user_id_recording_id") + + # ### end Alembic commands ### diff --git a/server/reflector/db/__init__.py b/server/reflector/db/__init__.py index 5693c06b..c3e08a2f 100644 --- a/server/reflector/db/__init__.py +++ b/server/reflector/db/__init__.py @@ -12,9 +12,10 @@ import reflector.db.recordings # noqa import reflector.db.rooms # noqa import reflector.db.transcripts # noqa -engine = sqlalchemy.create_engine( - settings.DATABASE_URL, connect_args={"check_same_thread": False} -) +kwargs = {} +if "sqlite" in settings.DATABASE_URL: + kwargs["connect_args"] = {"check_same_thread": False} +engine = sqlalchemy.create_engine(settings.DATABASE_URL, **kwargs) @subscribers_startup.append diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py index 85d4bbb2..dc832850 100644 --- a/server/reflector/db/transcripts.py +++ b/server/reflector/db/transcripts.py @@ -32,16 +32,16 @@ transcripts = sqlalchemy.Table( sqlalchemy.Column("name", sqlalchemy.String), sqlalchemy.Column("status", sqlalchemy.String), sqlalchemy.Column("locked", sqlalchemy.Boolean), - sqlalchemy.Column("duration", sqlalchemy.Integer), + sqlalchemy.Column("duration", sqlalchemy.Float), sqlalchemy.Column("created_at", sqlalchemy.DateTime), - sqlalchemy.Column("title", sqlalchemy.String, nullable=True), - sqlalchemy.Column("short_summary", sqlalchemy.String, nullable=True), - sqlalchemy.Column("long_summary", sqlalchemy.String, nullable=True), + sqlalchemy.Column("title", sqlalchemy.String), + sqlalchemy.Column("short_summary", sqlalchemy.String), + sqlalchemy.Column("long_summary", sqlalchemy.String), sqlalchemy.Column("topics", sqlalchemy.JSON), sqlalchemy.Column("events", sqlalchemy.JSON), sqlalchemy.Column("participants", sqlalchemy.JSON), - sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True), - sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True), + sqlalchemy.Column("source_language", sqlalchemy.String), + sqlalchemy.Column("target_language", sqlalchemy.String), sqlalchemy.Column( "reviewed", sqlalchemy.Boolean, nullable=False, server_default=false() ), @@ -63,8 +63,8 @@ transcripts = sqlalchemy.Table( "meeting_id", sqlalchemy.String, ), - sqlalchemy.Column("recording_id", sqlalchemy.String, nullable=True), - sqlalchemy.Column("zulip_message_id", sqlalchemy.Integer, nullable=True), + sqlalchemy.Column("recording_id", sqlalchemy.String), + sqlalchemy.Column("zulip_message_id", sqlalchemy.Integer), sqlalchemy.Column( "source_kind", Enum(SourceKind, values_callable=lambda obj: [e.value for e in obj]), @@ -73,10 +73,11 @@ transcripts = sqlalchemy.Table( # indicative field: whether associated audio is deleted # the main "audio deleted" is the presence of the audio itself / consents not-given # same field could've been in recording/meeting, and it's maybe even ok to dupe it at need - sqlalchemy.Column("audio_deleted", sqlalchemy.Boolean, nullable=True), + sqlalchemy.Column("audio_deleted", sqlalchemy.Boolean), sqlalchemy.Index("idx_transcript_recording_id", "recording_id"), sqlalchemy.Index("idx_transcript_user_id", "user_id"), sqlalchemy.Index("idx_transcript_created_at", "created_at"), + sqlalchemy.Index("idx_transcript_user_id_recording_id", "user_id", "recording_id"), ) @@ -336,6 +337,7 @@ class TranscriptController: .join(meetings, recordings.c.meeting_id == meetings.c.id, isouter=True) .join(rooms, meetings.c.room_id == rooms.c.id, isouter=True) ) + if user_id: query = query.where( or_(transcripts.c.user_id == user_id, rooms.c.is_shared) @@ -377,6 +379,8 @@ class TranscriptController: if filter_recording: query = query.filter(transcripts.c.status != "recording") + # print(query.compile(compile_kwargs={"literal_binds": True})) + if return_query: return query diff --git a/server/reflector/utils/audio_waveform.py b/server/reflector/utils/audio_waveform.py index d9f6b05c..b4412e05 100644 --- a/server/reflector/utils/audio_waveform.py +++ b/server/reflector/utils/audio_waveform.py @@ -57,7 +57,10 @@ def get_audio_waveform(path: Path | str, segments_count: int = 256) -> list[int] # number of decimals to use when rounding the peak value digits = 2 - volumes = np.round(volumes / volumes.max(), digits) + if len(volumes) > 0 and volumes.max() > 0: + volumes = np.round(volumes / volumes.max(), digits) + else: + volumes = np.zeros_like(volumes) if len(volumes) > 0 else np.array([]) return volumes.tolist()