mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 12:19:06 +00:00
feat: search backend (#537)
* docs: transient docs
* chore: cleanup
* webvtt WIP
* webvtt field
* chore: webvtt tests comments
* chore: remove useless tests
* feat: search TASK.md
* feat: full text search by title/webvtt
* chore: search api task
* feat: search api
* feat: search API
* chore: rm task md
* chore: roll back unnecessary validators
* chore: pr review WIP
* chore: pr review WIP
* chore: pr review
* chore: top imports
* feat: better lint + ci
* feat: better lint + ci
* feat: better lint + ci
* feat: better lint + ci
* chore: lint
* chore: lint
* fix: db datetime definitions
* fix: flush() params
* fix: update transcript mutability expectation / test
* fix: update transcript mutability expectation / test
* chore: auto review
* chore: new controller extraction
* chore: new controller extraction
* chore: cleanup
* chore: review WIP
* chore: pr WIP
* chore: remove ci lint
* chore: openapi regeneration
* chore: openapi regeneration
* chore: postgres test doc
* fix: .dockerignore for arm binaries
* fix: .dockerignore for arm binaries
* fix: cap test loops
* fix: cap test loops
* fix: cap test loops
* fix: get_transcript_topics
* chore: remove flow.md docs and claude guidance
* chore: remove claude.md db doc
* chore: remove claude.md db doc
* chore: remove claude.md db doc
* chore: remove claude.md db doc
.gitignore (vendored, 1 change)
@@ -14,3 +14,4 @@ data/
www/REFACTOR.md
www/reload-frontend
server/test.sqlite
CLAUDE.local.md
@@ -3,10 +3,10 @@
repos:
  - repo: local
    hooks:
      - id: yarn-format
        name: run yarn format
      - id: format
        name: run format
        language: system
        entry: bash -c 'cd www && yarn format'
        entry: bash -c 'cd www && npx prettier --write .'
        pass_filenames: false
        files: ^www/
@@ -23,8 +23,7 @@ repos:
      - id: ruff
        args:
          - --fix
          - --select
          - I,F401
        # Uses select rules from server/pyproject.toml
        files: ^server/
      - id: ruff-format
        files: ^server/
@@ -44,6 +44,7 @@ services:
    working_dir: /app
    volumes:
      - ./www:/app/
      - /app/node_modules
    env_file:
      - ./www/.env.local
@@ -1 +1,3 @@
Generic single-database configuration.

Both data migrations and schema migrations must be in migrations.
@@ -0,0 +1,27 @@
"""add_webvtt_field_to_transcript

Revision ID: 0bc0f3ff0111
Revises: b7df9609542c
Create Date: 2025-08-05 19:36:41.740957

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


revision: str = '0bc0f3ff0111'
down_revision: Union[str, None] = 'b7df9609542c'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.add_column('transcript',
        sa.Column('webvtt', sa.Text(), nullable=True)
    )


def downgrade() -> None:
    op.drop_column('transcript', 'webvtt')
@@ -0,0 +1,47 @@
"""add_full_text_search

Revision ID: 116b2f287eab
Revises: 0bc0f3ff0111
Create Date: 2025-08-07 11:27:38.473517

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


revision: str = '116b2f287eab'
down_revision: Union[str, None] = '0bc0f3ff0111'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    conn = op.get_bind()
    if conn.dialect.name != 'postgresql':
        return

    op.execute("""
        ALTER TABLE transcript ADD COLUMN search_vector_en tsvector
        GENERATED ALWAYS AS (
            setweight(to_tsvector('english', coalesce(title, '')), 'A') ||
            setweight(to_tsvector('english', coalesce(webvtt, '')), 'B')
        ) STORED
    """)

    op.create_index(
        'idx_transcript_search_vector_en',
        'transcript',
        ['search_vector_en'],
        postgresql_using='gin'
    )


def downgrade() -> None:
    conn = op.get_bind()
    if conn.dialect.name != 'postgresql':
        return

    op.drop_index('idx_transcript_search_vector_en', table_name='transcript')
    op.drop_column('transcript', 'search_vector_en')
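For reference, a minimal sketch (not part of this diff) of how the generated column is queried. It mirrors what SearchController.search_transcripts does further below, using the same list-style select this codebase uses; the search_vector_en column only exists on PostgreSQL:

    import sqlalchemy
    from reflector.db.transcripts import transcripts

    # websearch_to_tsquery safely parses user input (AND by default, OR, quoted phrases)
    tsq = sqlalchemy.func.websearch_to_tsquery("english", "planning meeting")
    rank = sqlalchemy.func.ts_rank(transcripts.c.search_vector_en, tsq, 32)
    query = (
        sqlalchemy.select([transcripts.c.id, transcripts.c.title, rank.label("rank")])
        .where(transcripts.c.search_vector_en.op("@@")(tsq))
        .order_by(sqlalchemy.desc(sqlalchemy.text("rank")))
    )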
@@ -0,0 +1,109 @@
"""populate_webvtt_from_topics

Revision ID: 8120ebc75366
Revises: 116b2f287eab
Create Date: 2025-08-11 19:11:01.316947

"""
import json
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy import text


# revision identifiers, used by Alembic.
revision: str = '8120ebc75366'
down_revision: Union[str, None] = '116b2f287eab'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def topics_to_webvtt(topics):
    """Convert topics list to WebVTT format string."""
    if not topics:
        return None

    lines = ["WEBVTT", ""]

    for topic in topics:
        start_time = format_timestamp(topic.get("start"))
        end_time = format_timestamp(topic.get("end"))
        text = topic.get("text", "").strip()

        if start_time and end_time and text:
            lines.append(f"{start_time} --> {end_time}")
            lines.append(text)
            lines.append("")

    return "\n".join(lines).strip()


def format_timestamp(seconds):
    """Format seconds to WebVTT timestamp format (HH:MM:SS.mmm)."""
    if seconds is None:
        return None

    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = seconds % 60

    return f"{hours:02d}:{minutes:02d}:{secs:06.3f}"


def upgrade() -> None:
    """Populate WebVTT field for all transcripts with topics."""

    # Get connection
    connection = op.get_bind()

    # Query all transcripts with topics
    result = connection.execute(
        text("SELECT id, topics FROM transcript WHERE topics IS NOT NULL")
    )

    rows = result.fetchall()
    print(f"Found {len(rows)} transcripts with topics")

    updated_count = 0
    error_count = 0

    for row in rows:
        transcript_id = row[0]
        topics_data = row[1]

        if not topics_data:
            continue

        try:
            # Parse JSON if it's a string
            if isinstance(topics_data, str):
                topics_data = json.loads(topics_data)

            # Convert topics to WebVTT format
            webvtt_content = topics_to_webvtt(topics_data)

            if webvtt_content:
                # Update the webvtt field
                connection.execute(
                    text("UPDATE transcript SET webvtt = :webvtt WHERE id = :id"),
                    {"webvtt": webvtt_content, "id": transcript_id}
                )
                updated_count += 1
                print(f"✓ Updated transcript {transcript_id}")

        except Exception as e:
            error_count += 1
            print(f"✗ Error updating transcript {transcript_id}: {e}")

    print(f"\nMigration complete!")
    print(f"  Updated: {updated_count}")
    print(f"  Errors: {error_count}")


def downgrade() -> None:
    """Clear WebVTT field for all transcripts."""
    op.execute(
        text("UPDATE transcript SET webvtt = NULL")
    )
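To make the helpers above concrete, a hand-computed example (not part of the migration):

    format_timestamp(3661.5)
    # -> "01:01:01.500"  (1 h, 1 min, and 1.5 s rendered via {secs:06.3f})

    topics_to_webvtt([{"start": 0, "end": 10, "text": "Hello"}])
    # -> "WEBVTT\n\n00:00:00.000 --> 00:00:10.000\nHello"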
@@ -40,6 +40,7 @@ dependencies = [
    "llama-index>=0.12.52",
    "llama-index-llms-openai-like>=0.4.0",
    "pytest-env>=1.1.5",
    "webvtt-py>=0.5.0",
]

[dependency-groups]
@@ -92,5 +93,12 @@ addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v"
testpaths = ["tests"]
asyncio_mode = "auto"

[tool.ruff.lint]
select = [
    "I",        # isort - import sorting
    "F401",     # unused imports
    "PLC0415",  # import-outside-top-level - detect inline imports
]

[tool.ruff.lint.per-file-ignores]
"reflector/processors/summary/summary_builder.py" = ["E501"]
server/reflector/db/search.py (new file, 233 lines)
@@ -0,0 +1,233 @@
"""Search functionality for transcripts and other entities."""

import logging
from datetime import datetime
from io import StringIO
from typing import Annotated, Any, Dict

import sqlalchemy
import webvtt
from pydantic import BaseModel, Field, constr, field_serializer

from reflector.db import database
from reflector.db.transcripts import SourceKind, transcripts
from reflector.db.utils import is_postgresql

logger = logging.getLogger(__name__)

DEFAULT_SEARCH_LIMIT = 20
SNIPPET_CONTEXT_LENGTH = 50  # Characters before/after match to include
DEFAULT_SNIPPET_MAX_LENGTH = 150
DEFAULT_MAX_SNIPPETS = 3

SearchQueryBase = constr(min_length=1, strip_whitespace=True)
SearchLimitBase = Annotated[int, Field(ge=1, le=100)]
SearchOffsetBase = Annotated[int, Field(ge=0)]
SearchTotalBase = Annotated[int, Field(ge=0)]

SearchQuery = Annotated[SearchQueryBase, Field(description="Search query text")]
SearchLimit = Annotated[SearchLimitBase, Field(description="Results per page")]
SearchOffset = Annotated[
    SearchOffsetBase, Field(description="Number of results to skip")
]
SearchTotal = Annotated[
    SearchTotalBase, Field(description="Total number of search results")
]


class SearchParameters(BaseModel):
    """Validated search parameters for full-text search."""

    query_text: SearchQuery
    limit: SearchLimit = DEFAULT_SEARCH_LIMIT
    offset: SearchOffset = 0
    user_id: str | None = None
    room_id: str | None = None


class SearchResultDB(BaseModel):
    """Intermediate model for validating raw database results."""

    id: str = Field(..., min_length=1)
    created_at: datetime
    status: str = Field(..., min_length=1)
    duration: float | None = Field(None, ge=0)
    user_id: str | None = None
    title: str | None = None
    source_kind: SourceKind
    room_id: str | None = None
    rank: float = Field(..., ge=0, le=1)


class SearchResult(BaseModel):
    """Public search result model with computed fields."""

    id: str = Field(..., min_length=1)
    title: str | None = None
    user_id: str | None = None
    room_id: str | None = None
    created_at: datetime
    status: str = Field(..., min_length=1)
    rank: float = Field(..., ge=0, le=1)
    duration: float | None = Field(..., ge=0, description="Duration in seconds")
    search_snippets: list[str] = Field(
        description="Text snippets around search matches"
    )

    @field_serializer("created_at", when_used="json")
    def serialize_datetime(self, dt: datetime) -> str:
        if dt.tzinfo is None:
            return dt.isoformat() + "Z"
        return dt.isoformat()


class SearchController:
    """Controller for search operations across different entities."""

    @staticmethod
    def _extract_webvtt_text(webvtt_content: str) -> str:
        """Extract plain text from WebVTT content using webvtt library."""
        if not webvtt_content:
            return ""

        try:
            buffer = StringIO(webvtt_content)
            vtt = webvtt.read_buffer(buffer)
            return " ".join(caption.text for caption in vtt if caption.text)
        except (webvtt.errors.MalformedFileError, UnicodeDecodeError, ValueError) as e:
            logger.warning(f"Failed to parse WebVTT content: {e}", exc_info=e)
            return ""
        except AttributeError as e:
            logger.warning(f"WebVTT parsing error - unexpected format: {e}", exc_info=e)
            return ""

    @staticmethod
    def _generate_snippets(
        text: str,
        q: SearchQuery,
        max_length: int = DEFAULT_SNIPPET_MAX_LENGTH,
        max_snippets: int = DEFAULT_MAX_SNIPPETS,
    ) -> list[str]:
        """Generate multiple snippets around all occurrences of search term."""
        if not text or not q:
            return []

        snippets = []
        lower_text = text.lower()
        search_lower = q.lower()

        last_snippet_end = 0
        start_pos = 0

        while len(snippets) < max_snippets:
            match_pos = lower_text.find(search_lower, start_pos)

            if match_pos == -1:
                if not snippets and search_lower.split():
                    first_word = search_lower.split()[0]
                    match_pos = lower_text.find(first_word, start_pos)
                    if match_pos == -1:
                        break
                else:
                    break

            snippet_start = max(0, match_pos - SNIPPET_CONTEXT_LENGTH)
            snippet_end = min(
                len(text), match_pos + max_length - SNIPPET_CONTEXT_LENGTH
            )

            if snippet_start < last_snippet_end:
                start_pos = match_pos + len(search_lower)
                continue

            snippet = text[snippet_start:snippet_end]

            if snippet_start > 0:
                snippet = "..." + snippet
            if snippet_end < len(text):
                snippet = snippet + "..."

            snippet = snippet.strip()

            if snippet:
                snippets.append(snippet)
                last_snippet_end = snippet_end

            start_pos = match_pos + len(search_lower)
            if start_pos >= len(text):
                break

        return snippets

    @classmethod
    async def search_transcripts(
        cls, params: SearchParameters
    ) -> tuple[list[SearchResult], int]:
        """
        Full-text search for transcripts using PostgreSQL tsvector.
        Returns (results, total_count).
        """

        if not is_postgresql():
            logger.warning(
                "Full-text search requires PostgreSQL. Returning empty results."
            )
            return [], 0

        search_query = sqlalchemy.func.websearch_to_tsquery(
            "english", params.query_text
        )

        base_query = sqlalchemy.select(
            [
                transcripts.c.id,
                transcripts.c.title,
                transcripts.c.created_at,
                transcripts.c.duration,
                transcripts.c.status,
                transcripts.c.user_id,
                transcripts.c.room_id,
                transcripts.c.source_kind,
                transcripts.c.webvtt,
                sqlalchemy.func.ts_rank(
                    transcripts.c.search_vector_en,
                    search_query,
                    32,  # normalization flag: rank/(rank+1) for 0-1 range
                ).label("rank"),
            ]
        ).where(transcripts.c.search_vector_en.op("@@")(search_query))

        if params.user_id:
            base_query = base_query.where(transcripts.c.user_id == params.user_id)
        if params.room_id:
            base_query = base_query.where(transcripts.c.room_id == params.room_id)

        query = (
            base_query.order_by(sqlalchemy.desc(sqlalchemy.text("rank")))
            .limit(params.limit)
            .offset(params.offset)
        )
        rs = await database.fetch_all(query)

        count_query = sqlalchemy.select([sqlalchemy.func.count()]).select_from(
            base_query.alias("search_results")
        )
        total = await database.fetch_val(count_query)

        def _process_result(r) -> SearchResult:
            r_dict: Dict[str, Any] = dict(r)
            webvtt: str | None = r_dict.pop("webvtt", None)
            db_result = SearchResultDB.model_validate(r_dict)

            snippets = []
            if webvtt:
                plain_text = cls._extract_webvtt_text(webvtt)
                snippets = cls._generate_snippets(plain_text, params.query_text)

            return SearchResult(**db_result.model_dump(), search_snippets=snippets)

        results = [_process_result(r) for r in rs]
        return results, total


search_controller = SearchController()
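A minimal usage sketch for the controller above (not part of this diff); it assumes a connected PostgreSQL-backed database, since on other backends search_transcripts logs a warning and returns ([], 0):

    from reflector.db.search import SearchParameters, search_controller

    async def find_planning_meetings():
        params = SearchParameters(query_text="quarterly planning", limit=10)
        results, total = await search_controller.search_transcripts(params)
        for r in results:
            print(r.id, round(r.rank, 3), r.search_snippets[:1])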
@@ -1,9 +1,10 @@
import enum
import json
import logging
import os
import shutil
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Literal

@@ -11,13 +12,19 @@ import sqlalchemy
from fastapi import HTTPException
from pydantic import BaseModel, ConfigDict, Field, field_serializer
from sqlalchemy import Enum
from sqlalchemy.dialects.postgresql import TSVECTOR
from sqlalchemy.sql import false, or_

from reflector.db import database, metadata
from reflector.db.rooms import rooms
from reflector.db.utils import is_postgresql
from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings
from reflector.storage import get_transcripts_storage
from reflector.utils import generate_uuid4
from reflector.utils.webvtt import topics_to_webvtt

logger = logging.getLogger(__name__)


class SourceKind(enum.StrEnum):
@@ -76,6 +83,7 @@ transcripts = sqlalchemy.Table(
    # same field could've been in recording/meeting, and it's maybe even ok to dupe it at need
    sqlalchemy.Column("audio_deleted", sqlalchemy.Boolean),
    sqlalchemy.Column("room_id", sqlalchemy.String),
    sqlalchemy.Column("webvtt", sqlalchemy.Text),
    sqlalchemy.Index("idx_transcript_recording_id", "recording_id"),
    sqlalchemy.Index("idx_transcript_user_id", "user_id"),
    sqlalchemy.Index("idx_transcript_created_at", "created_at"),
@@ -83,6 +91,29 @@ transcripts = sqlalchemy.Table(
    sqlalchemy.Index("idx_transcript_room_id", "room_id"),
)

# Add PostgreSQL-specific full-text search column
# This matches the migration in migrations/versions/116b2f287eab_add_full_text_search.py
if is_postgresql():
    transcripts.append_column(
        sqlalchemy.Column(
            "search_vector_en",
            TSVECTOR,
            sqlalchemy.Computed(
                "setweight(to_tsvector('english', coalesce(title, '')), 'A') || "
                "setweight(to_tsvector('english', coalesce(webvtt, '')), 'B')",
                persisted=True,
            ),
        )
    )
    # Add GIN index for the search vector
    transcripts.append_constraint(
        sqlalchemy.Index(
            "idx_transcript_search_vector_en",
            "search_vector_en",
            postgresql_using="gin",
        )
    )

def generate_transcript_name() -> str:
    now = datetime.now(timezone.utc)
@@ -147,14 +178,18 @@ class TranscriptParticipant(BaseModel):


class Transcript(BaseModel):
    """Full transcript model with all fields."""

    id: str = Field(default_factory=generate_uuid4)
    user_id: str | None = None
    name: str = Field(default_factory=generate_transcript_name)
    status: str = "idle"
    locked: bool = False
    duration: float = 0
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    title: str | None = None
    source_kind: SourceKind
    room_id: str | None = None
    locked: bool = False
    short_summary: str | None = None
    long_summary: str | None = None
    topics: list[TranscriptTopic] = []
@@ -168,9 +203,8 @@ class Transcript(BaseModel):
    meeting_id: str | None = None
    recording_id: str | None = None
    zulip_message_id: int | None = None
    source_kind: SourceKind
    audio_deleted: bool | None = None
    room_id: str | None = None
    webvtt: str | None = None

    @field_serializer("created_at", when_used="json")
    def serialize_datetime(self, dt: datetime) -> str:
@@ -271,10 +305,12 @@ class Transcript(BaseModel):
        # we need to create an url to be used for diarization
        # we can't use the audio_mp3_filename because it's not accessible
        # from the diarization processor
        from datetime import timedelta

        from reflector.app import app
        from reflector.views.transcripts import create_access_token
        # TODO don't import app in db
        from reflector.app import app  # noqa: PLC0415

        # TODO a util + don't import views in db
        from reflector.views.transcripts import create_access_token  # noqa: PLC0415

        path = app.url_path_for(
            "transcript_get_audio_mp3",
@@ -335,7 +371,6 @@ class TranscriptController:
        - `room_id`: filter transcripts by room ID
        - `search_term`: filter transcripts by search term
        """
        from reflector.db.rooms import rooms

        query = transcripts.select().join(
            rooms, transcripts.c.room_id == rooms.c.id, isouter=True
@@ -502,10 +537,17 @@ class TranscriptController:
        await database.execute(query)
        return transcript

    async def update(self, transcript: Transcript, values: dict, mutate=True):
    # TODO investigate why mutate= is used. it's used in one place currently, maybe because of ORM field updates.
    # using mutate=True is discouraged
    async def update(
        self, transcript: Transcript, values: dict, mutate=False
    ) -> Transcript:
        """
        Update a transcript fields with key/values in values
        Update a transcript fields with key/values in values.
        Returns a copy of the transcript with updated values.
        """
        values = TranscriptController._handle_topics_update(values)

        query = (
            transcripts.update()
            .where(transcripts.c.id == transcript.id)
@@ -516,6 +558,28 @@ class TranscriptController:
            for key, value in values.items():
                setattr(transcript, key, value)

        updated_transcript = transcript.model_copy(update=values)
        return updated_transcript

    @staticmethod
    def _handle_topics_update(values: dict) -> dict:
        """Auto-update WebVTT when topics are updated."""

        if values.get("webvtt") is not None:
            logger.warn("trying to update read-only webvtt column")
            pass

        topics_data = values.get("topics")
        if topics_data is None:
            return values

        return {
            **values,
            "webvtt": topics_to_webvtt(
                [TranscriptTopic(**topic_dict) for topic_dict in topics_data]
            ),
        }

    async def remove_by_id(
        self,
        transcript_id: str,
@@ -558,11 +622,7 @@ class TranscriptController:
        Append an event to a transcript
        """
        resp = transcript.add_event(event=event, data=data)
        await self.update(
            transcript,
            {"events": transcript.events_dump()},
            mutate=False,
        )
        await self.update(transcript, {"events": transcript.events_dump()})
        return resp

    async def upsert_topic(
@@ -574,11 +634,7 @@ class TranscriptController:
        Upsert topics to a transcript
        """
        transcript.upsert_topic(topic)
        await self.update(
            transcript,
            {"topics": transcript.topics_dump()},
            mutate=False,
        )
        await self.update(transcript, {"topics": transcript.topics_dump()})

    async def move_mp3_to_storage(self, transcript: Transcript):
        """
@@ -603,7 +659,8 @@ class TranscriptController:
        )

        # indicate on the transcript that the audio is now on storage
        await self.update(transcript, {"audio_location": "storage"})
        # mutates transcript argument
        await self.update(transcript, {"audio_location": "storage"}, mutate=True)

        # unlink the local file
        transcript.audio_mp3_filename.unlink(missing_ok=True)
@@ -627,11 +684,7 @@ class TranscriptController:
        Add/update a participant to a transcript
        """
        result = transcript.upsert_participant(participant)
        await self.update(
            transcript,
            {"participants": transcript.participants_dump()},
            mutate=False,
        )
        await self.update(transcript, {"participants": transcript.participants_dump()})
        return result

    async def delete_participant(
@@ -643,11 +696,7 @@ class TranscriptController:
        Delete a participant from a transcript
        """
        transcript.delete_participant(participant_id)
        await self.update(
            transcript,
            {"participants": transcript.participants_dump()},
            mutate=False,
        )
        await self.update(transcript, {"participants": transcript.participants_dump()})


transcripts_controller = TranscriptController()
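The mutate flag change above is easy to misread, so a quick sketch of the new semantics (hypothetical usage, not part of this diff): with the new default mutate=False, update() persists the values and returns a copy, leaving the passed-in object untouched:

    updated = await transcripts_controller.update(transcript, {"title": "New title"})
    assert updated.title == "New title"  # the returned copy carries the change
    # `transcript` itself is only modified when mutate=True is passed explicitly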
server/reflector/db/utils.py (new file, 7 lines)
@@ -0,0 +1,7 @@
"""Database utility functions."""

from reflector.db import database


def is_postgresql() -> bool:
    return database.url.scheme and database.url.scheme.startswith('postgresql')
@@ -14,12 +14,15 @@ It is directly linked to our data model.
import asyncio
import functools
from contextlib import asynccontextmanager
from typing import Generic

import av
import boto3
from celery import chord, current_task, group, shared_task
from pydantic import BaseModel
from structlog import BoundLogger as Logger

from reflector.db import database
from reflector.db.meetings import meeting_consent_controller, meetings_controller
from reflector.db.recordings import recordings_controller
from reflector.db.rooms import rooms_controller
@@ -35,7 +38,7 @@ from reflector.db.transcripts import (
    transcripts_controller,
)
from reflector.logger import logger
from reflector.pipelines.runner import PipelineRunner
from reflector.pipelines.runner import PipelineMessage, PipelineRunner
from reflector.processors import (
    AudioChunkerProcessor,
    AudioDiarizationAutoProcessor,
@@ -69,8 +72,6 @@ def asynctask(f):
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        async def run_with_db():
            from reflector.db import database

            await database.connect()
            try:
                return await f(*args, **kwargs)
@@ -144,7 +145,7 @@ class StrValue(BaseModel):
    value: str


class PipelineMainBase(PipelineRunner):
class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]):
    transcript_id: str
    ws_room_id: str | None = None
    ws_manager: WebsocketManager | None = None
@@ -164,7 +165,11 @@ class PipelineMainBase(PipelineRunner):
            raise Exception("Transcript not found")
        return result

    def get_transcript_topics(self, transcript: Transcript) -> list[TranscriptTopic]:
    @staticmethod
    def wrap_transcript_topics(
        topics: list[TranscriptTopic],
    ) -> list[TitleSummaryWithIdProcessorType]:
        # transformation to a pipe-supported format
        return [
            TitleSummaryWithIdProcessorType(
                id=topic.id,
@@ -174,7 +179,7 @@ class PipelineMainBase(PipelineRunner):
                duration=topic.duration,
                transcript=TranscriptProcessorType(words=topic.words),
            )
            for topic in transcript.topics
            for topic in topics
        ]

    @asynccontextmanager
@@ -380,7 +385,7 @@ class PipelineMainLive(PipelineMainBase):
        pipeline_post(transcript_id=self.transcript_id)


class PipelineMainDiarization(PipelineMainBase):
class PipelineMainDiarization(PipelineMainBase[AudioDiarizationInput]):
    """
    Diarize the audio and update topics
    """
@@ -404,11 +409,10 @@ class PipelineMainDiarization(PipelineMainBase):
            pipeline.logger.info("Audio is local, skipping diarization")
            return

        topics = self.get_transcript_topics(transcript)
        audio_url = await transcript.get_audio_url()
        audio_diarization_input = AudioDiarizationInput(
            audio_url=audio_url,
            topics=topics,
            topics=self.wrap_transcript_topics(transcript.topics),
        )

        # as tempting to use pipeline.push, prefer to use the runner
@@ -421,7 +425,7 @@ class PipelineMainDiarization(PipelineMainBase):
        return pipeline


class PipelineMainFromTopics(PipelineMainBase):
class PipelineMainFromTopics(PipelineMainBase[TitleSummaryWithIdProcessorType]):
    """
    Pseudo class for generating a pipeline from topics
    """
@@ -443,7 +447,7 @@ class PipelineMainFromTopics(PipelineMainBase):
        pipeline.logger.info(f"{self.__class__.__name__} pipeline created")

        # push topics
        topics = self.get_transcript_topics(transcript)
        topics = PipelineMainBase.wrap_transcript_topics(transcript.topics)
        for topic in topics:
            await self.push(topic)

@@ -524,8 +528,6 @@ async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
    # Convert to mp3
    mp3_filename = transcript.audio_mp3_filename

    import av

    with av.open(wav_filename.as_posix()) as in_container:
        in_stream = in_container.streams.audio[0]
        with av.open(mp3_filename.as_posix(), "w") as out_container:
@@ -604,7 +606,7 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
                meeting.id
            )
        except Exception as e:
            logger.error(f"Failed to get fetch consent: {e}")
            logger.error(f"Failed to get fetch consent: {e}", exc_info=e)
            consent_denied = True

    if not consent_denied:
@@ -627,7 +629,7 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
                f"Deleted original Whereby recording: {recording.bucket_name}/{recording.object_key}"
            )
        except Exception as e:
            logger.error(f"Failed to delete Whereby recording: {e}")
            logger.error(f"Failed to delete Whereby recording: {e}", exc_info=e)

    # non-transactional, files marked for deletion not actually deleted is possible
    await transcripts_controller.update(transcript, {"audio_deleted": True})
@@ -640,7 +642,7 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
            f"Deleted processed audio from storage: {transcript.storage_audio_path}"
        )
    except Exception as e:
        logger.error(f"Failed to delete processed audio: {e}")
        logger.error(f"Failed to delete processed audio: {e}", exc_info=e)

    # 3. Delete local audio files
    try:
@@ -649,7 +651,7 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
        if hasattr(transcript, "audio_wav_filename") and transcript.audio_wav_filename:
            transcript.audio_wav_filename.unlink(missing_ok=True)
    except Exception as e:
        logger.error(f"Failed to delete local audio files: {e}")
        logger.error(f"Failed to delete local audio files: {e}", exc_info=e)

    logger.info("Consent cleanup done")

@@ -794,8 +796,6 @@ def pipeline_post(*, transcript_id: str):

@get_transcript
async def pipeline_process(transcript: Transcript, logger: Logger):
    import av

    try:
        if transcript.audio_location == "storage":
            await transcripts_controller.download_mp3_from_storage(transcript)
@@ -16,14 +16,17 @@ During its lifecycle, it will emit the following status:
"""

import asyncio
from typing import Generic, TypeVar

from pydantic import BaseModel, ConfigDict

from reflector.logger import logger
from reflector.processors import Pipeline

PipelineMessage = TypeVar("PipelineMessage")


class PipelineRunner(BaseModel):
class PipelineRunner(BaseModel, Generic[PipelineMessage]):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    status: str = "idle"
@@ -67,7 +70,7 @@ class PipelineRunner(BaseModel):
        coro = self.run()
        asyncio.run(coro)

    async def push(self, data):
    async def push(self, data: PipelineMessage):
        """
        Push data to the pipeline
        """
@@ -92,7 +95,11 @@ class PipelineRunner(BaseModel):
        pass

    async def _add_cmd(
        self, cmd: str, data, max_retries: int = 3, retry_time_limit: int = 3
        self,
        cmd: str,
        data: PipelineMessage,
        max_retries: int = 3,
        retry_time_limit: int = 3,
    ):
        """
        Enqueue a command to be executed in the runner.
@@ -143,7 +150,10 @@ class PipelineRunner(BaseModel):
                cmd, data = await self._q_cmd.get()
                func = getattr(self, f"cmd_{cmd.lower()}")
                if func:
                    await func(data)
                    if cmd.upper() == "FLUSH":
                        await func()
                    else:
                        await func(data)
                else:
                    raise Exception(f"Unknown command {cmd}")
            except Exception:
@@ -152,13 +162,13 @@ class PipelineRunner(BaseModel):
            self._ev_done.set()
            raise

    async def cmd_push(self, data):
    async def cmd_push(self, data: PipelineMessage):
        if self._is_first_push:
            await self._set_status("push")
            self._is_first_push = False
        await self.pipeline.push(data)

    async def cmd_flush(self, data):
    async def cmd_flush(self):
        await self._set_status("flush")
        await self.pipeline.flush()
        await self._set_status("ended")
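A small sketch of what the new Generic[PipelineMessage] parameter buys (hypothetical subclass, not part of this diff): the message type flows through push(), cmd_push(), and _add_cmd(), so a type checker can catch mismatched payloads:

    class TopicsRunner(PipelineRunner[TitleSummaryWithIdProcessorType]):
        ...

    # runner.push(topic)         # OK: matches the declared message type
    # runner.push("raw string")  # flagged by a type checker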
@@ -2,9 +2,10 @@ import io
import re
import tempfile
from pathlib import Path
from typing import Annotated

from profanityfilter import ProfanityFilter
from pydantic import BaseModel, PrivateAttr
from pydantic import BaseModel, Field, PrivateAttr

from reflector.redis_cache import redis_cache

@@ -48,20 +49,70 @@ class AudioFile(BaseModel):
        self._path.unlink()


# non-negative seconds with float part
Seconds = Annotated[float, Field(ge=0.0, description="Time in seconds with float part")]


class Word(BaseModel):
    text: str
    start: float
    end: float
    start: Seconds
    end: Seconds
    speaker: int = 0


class TranscriptSegment(BaseModel):
    text: str
    start: float
    end: float
    start: Seconds
    end: Seconds
    speaker: int = 0

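The Annotated Seconds type above turns plain floats into validated fields; a quick sketch of what that buys (pydantic v2 behavior assumed, not part of this diff):

    from pydantic import ValidationError

    try:
        Word(text="hi", start=-1.0, end=0.5)
    except ValidationError:
        pass  # negative seconds are rejected by Field(ge=0.0)

    Word(text="hi", start=0.0, end=0.5)  # valid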
def words_to_segments(words: list[Word]) -> list[TranscriptSegment]:
    # from a list of words, create a list of segments
    # join the words that are less than 2 seconds apart
    # but separate if the speaker changes, or if the punctuation is a . , ; : ? !
    segments = []
    current_segment = None
    MAX_SEGMENT_LENGTH = 120

    for word in words:
        if current_segment is None:
            current_segment = TranscriptSegment(
                text=word.text,
                start=word.start,
                end=word.end,
                speaker=word.speaker,
            )
            continue

        # If the word is attached to another speaker, push the current segment
        # and start a new one
        if word.speaker != current_segment.speaker:
            segments.append(current_segment)
            current_segment = TranscriptSegment(
                text=word.text,
                start=word.start,
                end=word.end,
                speaker=word.speaker,
            )
            continue

        # if the word is the end of a sentence, and we have enough content,
        # add the word to the current segment and push it
        current_segment.text += word.text
        current_segment.end = word.end

        have_punc = PUNC_RE.search(word.text)
        if have_punc and (len(current_segment.text) > MAX_SEGMENT_LENGTH):
            segments.append(current_segment)
            current_segment = None

    if current_segment:
        segments.append(current_segment)

    return segments


class Transcript(BaseModel):
    translation: str | None = None
    words: list[Word] = None
@@ -117,49 +168,7 @@ class Transcript(BaseModel):
        return Transcript(text=self.text, translation=self.translation, words=words)

    def as_segments(self) -> list[TranscriptSegment]:
        # from a list of words, create a list of segments
        # join the words that are less than 2 seconds apart
        # but separate if the speaker changes, or if the punctuation is a . , ; : ? !
        segments = []
        current_segment = None
        MAX_SEGMENT_LENGTH = 120

        for word in self.words:
            if current_segment is None:
                current_segment = TranscriptSegment(
                    text=word.text,
                    start=word.start,
                    end=word.end,
                    speaker=word.speaker,
                )
                continue

            # If the word is attached to another speaker, push the current segment
            # and start a new one
            if word.speaker != current_segment.speaker:
                segments.append(current_segment)
                current_segment = TranscriptSegment(
                    text=word.text,
                    start=word.start,
                    end=word.end,
                    speaker=word.speaker,
                )
                continue

            # if the word is the end of a sentence, and we have enough content,
            # add the word to the current segment and push it
            current_segment.text += word.text
            current_segment.end = word.end

            have_punc = PUNC_RE.search(word.text)
            if have_punc and (len(current_segment.text) > MAX_SEGMENT_LENGTH):
                segments.append(current_segment)
                current_segment = None

        if current_segment:
            segments.append(current_segment)

        return segments
        return words_to_segments(self.words)


class TitleSummary(BaseModel):

server/reflector/utils/webvtt.py (new file, 63 lines)
@@ -0,0 +1,63 @@
"""WebVTT utilities for generating subtitle files from transcript data."""

from typing import TYPE_CHECKING, Annotated

import webvtt

from reflector.processors.types import Seconds, Word, words_to_segments

if TYPE_CHECKING:
    from reflector.db.transcripts import TranscriptTopic

VttTimestamp = Annotated[str, "vtt_timestamp"]
WebVTTStr = Annotated[str, "webvtt_str"]


def _seconds_to_timestamp(seconds: Seconds) -> VttTimestamp:
    # lib doesn't do that
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = int(seconds % 60)
    milliseconds = int((seconds % 1) * 1000)

    return f"{hours:02d}:{minutes:02d}:{secs:02d}.{milliseconds:03d}"


def words_to_webvtt(words: list[Word]) -> WebVTTStr:
    """Convert words to WebVTT using existing segmentation logic."""
    vtt = webvtt.WebVTT()
    if not words:
        return vtt.content

    segments = words_to_segments(words)

    for segment in segments:
        text = segment.text.strip()
        # lib doesn't do that
        text = f"<v Speaker{segment.speaker}>{text}"

        caption = webvtt.Caption(
            start=_seconds_to_timestamp(segment.start),
            end=_seconds_to_timestamp(segment.end),
            text=text,
        )
        vtt.captions.append(caption)

    return vtt.content


def topics_to_webvtt(topics: list["TranscriptTopic"]) -> WebVTTStr:
    if not topics:
        return webvtt.WebVTT().content

    all_words: list[Word] = []
    for topic in topics:
        all_words.extend(topic.words)

    # assert it's in sequence
    for i in range(len(all_words) - 1):
        assert (
            all_words[i].start <= all_words[i + 1].start
        ), f"Words are not in sequence: {all_words[i].text} and {all_words[i + 1].text} are not consecutive: {all_words[i].start} > {all_words[i + 1].start}"

    return words_to_webvtt(all_words)
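A small usage sketch for the helpers above (not part of this diff), with the expected output shape worked out from words_to_segments and _seconds_to_timestamp:

    words = [
        Word(text="Hello", start=0.0, end=0.4, speaker=0),
        Word(text=" world.", start=0.4, end=0.9, speaker=0),
    ]
    print(words_to_webvtt(words))
    # One cue, with the speaker voice tag prepended:
    # WEBVTT
    #
    # 00:00:00.000 --> 00:00:00.900
    # <v Speaker0>Hello world.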
@@ -1,15 +1,29 @@
from datetime import datetime, timedelta, timezone
from typing import Annotated, Literal, Optional

from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi_pagination import Page
from fastapi_pagination.ext.databases import paginate
from jose import jwt
from pydantic import BaseModel, Field, field_serializer

import reflector.auth as auth
from reflector.db import database
from reflector.db.meetings import meetings_controller
from reflector.db.rooms import rooms_controller
from reflector.db.search import (
    DEFAULT_SEARCH_LIMIT,
    SearchLimit,
    SearchLimitBase,
    SearchOffset,
    SearchOffsetBase,
    SearchParameters,
    SearchQuery,
    SearchQueryBase,
    SearchResult,
    SearchTotal,
    search_controller,
)
from reflector.db.transcripts import (
    SourceKind,
    TranscriptParticipant,
@@ -100,6 +114,21 @@ class DeletionStatus(BaseModel):
    status: str


SearchQueryParam = Annotated[SearchQueryBase, Query(description="Search query text")]
SearchLimitParam = Annotated[SearchLimitBase, Query(description="Results per page")]
SearchOffsetParam = Annotated[
    SearchOffsetBase, Query(description="Number of results to skip")
]


class SearchResponse(BaseModel):
    results: list[SearchResult]
    total: SearchTotal
    query: SearchQuery
    limit: SearchLimit
    offset: SearchOffset


@router.get("/transcripts", response_model=Page[GetTranscriptMinimal])
async def transcripts_list(
    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
@@ -107,8 +136,6 @@ async def transcripts_list(
    room_id: str | None = None,
    search_term: str | None = None,
):
    from reflector.db import database

    if not user and not settings.PUBLIC_MODE:
        raise HTTPException(status_code=401, detail="Not authenticated")

@@ -127,6 +154,39 @@ async def transcripts_list(
    )


@router.get("/transcripts/search", response_model=SearchResponse)
async def transcripts_search(
    q: SearchQueryParam,
    limit: SearchLimitParam = DEFAULT_SEARCH_LIMIT,
    offset: SearchOffsetParam = 0,
    room_id: Optional[str] = None,
    user: Annotated[
        Optional[auth.UserInfo], Depends(auth.current_user_optional)
    ] = None,
):
    """
    Full-text search across transcript titles and content.
    """
    if not user and not settings.PUBLIC_MODE:
        raise HTTPException(status_code=401, detail="Not authenticated")

    user_id = user["sub"] if user else None

    search_params = SearchParameters(
        query_text=q, limit=limit, offset=offset, user_id=user_id, room_id=room_id
    )

    results, total = await search_controller.search_transcripts(search_params)

    return SearchResponse(
        results=results,
        total=total,
        query=search_params.query_text,
        limit=search_params.limit,
        offset=search_params.offset,
    )

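A hypothetical client-side call for the endpoint above (httpx assumed; host, port, and any auth header are placeholders, and the route prefix depends on how the router is mounted):

    import httpx

    resp = httpx.get(
        "http://localhost:8000/transcripts/search",
        params={"q": "quarterly planning", "limit": 10, "offset": 0},
    )
    payload = resp.json()
    # {"results": [...], "total": ..., "query": ..., "limit": ..., "offset": ...}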
@router.post("/transcripts", response_model=GetTranscript)
async def transcripts_create(
    info: CreateTranscript,
@@ -273,8 +333,8 @@ async def transcript_update(
    if not transcript:
        raise HTTPException(status_code=404, detail="Transcript not found")
    values = info.dict(exclude_unset=True)
    await transcripts_controller.update(transcript, values)
    return transcript
    updated_transcript = await transcripts_controller.update(transcript, values)
    return updated_transcript


@router.delete("/transcripts/{transcript_id}", response_model=DeletionStatus)
server/tests/test_search.py (new file, 163 lines)
@@ -0,0 +1,163 @@
"""Tests for full-text search functionality."""

import json
from datetime import datetime

import pytest
from pydantic import ValidationError

from reflector.db import database
from reflector.db.search import SearchParameters, search_controller
from reflector.db.transcripts import transcripts
from reflector.db.utils import is_postgresql


@pytest.mark.asyncio
async def test_search_postgresql_only():
    await database.connect()

    try:
        params = SearchParameters(query_text="any query here")
        results, total = await search_controller.search_transcripts(params)
        assert results == []
        assert total == 0

        try:
            SearchParameters(query_text="")
            assert False, "Should have raised validation error"
        except ValidationError:
            pass  # Expected

        # Test that whitespace query raises validation error
        try:
            SearchParameters(query_text=" ")
            assert False, "Should have raised validation error"
        except ValidationError:
            pass  # Expected

    finally:
        await database.disconnect()


@pytest.mark.asyncio
async def test_search_input_validation():
    await database.connect()

    try:
        try:
            SearchParameters(query_text="")
            assert False, "Should have raised ValidationError"
        except ValidationError:
            pass  # Expected

        # Test that whitespace query raises validation error
        try:
            SearchParameters(query_text=" \t\n ")
            assert False, "Should have raised ValidationError"
        except ValidationError:
            pass  # Expected
    finally:
        await database.disconnect()


@pytest.mark.asyncio
async def test_postgresql_search_with_data():
    """Test full-text search with actual data in PostgreSQL.

    Example how to run: DATABASE_URL=postgresql://reflector:reflector@localhost:5432/reflector_test uv run pytest tests/test_search.py::test_postgresql_search_with_data -v -p no:env
    """
    # Skip if not PostgreSQL
    if not is_postgresql():
        pytest.skip("Test requires PostgreSQL. Set DATABASE_URL=postgresql://...")

    await database.connect()

    # collision is improbable
    test_id = "test-search-e2e-7f3a9b2c"

    try:
        await database.execute(transcripts.delete().where(transcripts.c.id == test_id))

        test_data = {
            "id": test_id,
            "name": "Test Search Transcript",
            "title": "Engineering Planning Meeting Q4 2024",
            "status": "completed",
            "locked": False,
            "duration": 1800.0,
            "created_at": datetime.now(),
            "short_summary": "Team discussed search implementation",
            "long_summary": "The engineering team met to plan the search feature",
            "topics": json.dumps([]),
            "events": json.dumps([]),
            "participants": json.dumps([]),
            "source_language": "en",
            "target_language": "en",
            "reviewed": False,
            "audio_location": "local",
            "share_mode": "private",
            "source_kind": "room",
            "webvtt": """WEBVTT

00:00:00.000 --> 00:00:10.000
Welcome to our engineering planning meeting for Q4 2024.

00:00:10.000 --> 00:00:20.000
Today we'll discuss the implementation of full-text search.

00:00:20.000 --> 00:00:30.000
The search feature should support complex queries with ranking.

00:00:30.000 --> 00:00:40.000
We need to implement PostgreSQL tsvector for better performance.""",
        }

        await database.execute(transcripts.insert().values(**test_data))

        # Test 1: Search for a word in title
        params = SearchParameters(query_text="planning")
        results, total = await search_controller.search_transcripts(params)
        assert total >= 1
        found = any(r.id == test_id for r in results)
        assert found, "Should find test transcript by title word"

        # Test 2: Search for a word in webvtt content
        params = SearchParameters(query_text="tsvector")
        results, total = await search_controller.search_transcripts(params)
        assert total >= 1
        found = any(r.id == test_id for r in results)
        assert found, "Should find test transcript by webvtt content"

        # Test 3: Search with multiple words
        params = SearchParameters(query_text="engineering planning")
        results, total = await search_controller.search_transcripts(params)
        assert total >= 1
        found = any(r.id == test_id for r in results)
        assert found, "Should find test transcript by multiple words"

        # Test 4: Verify SearchResult structure
        test_result = next((r for r in results if r.id == test_id), None)
        if test_result:
            assert test_result.title == "Engineering Planning Meeting Q4 2024"
            assert test_result.status == "completed"
            assert test_result.duration == 1800.0
            assert test_result.source_kind == "room"
            assert 0 <= test_result.rank <= 1, "Rank should be normalized to 0-1"

        # Test 5: Search with OR operator
        params = SearchParameters(query_text="tsvector OR nosuchword")
        results, total = await search_controller.search_transcripts(params)
        assert total >= 1
        found = any(r.id == test_id for r in results)
        assert found, "Should find test transcript with OR query"

        # Test 6: Quoted phrase search
        params = SearchParameters(query_text='"full-text search"')
        results, total = await search_controller.search_transcripts(params)
        assert total >= 1
        found = any(r.id == test_id for r in results)
        assert found, "Should find test transcript by exact phrase"

    finally:
        await database.execute(transcripts.delete().where(transcripts.c.id == test_id))
        await database.disconnect()
server/tests/test_search_snippets.py (new file, 198 lines)
@@ -0,0 +1,198 @@
|
||||
"""Unit tests for search snippet generation."""
|
||||
|
||||
from reflector.db.search import SearchController
|
||||
|
||||
|
||||
class TestExtractWebVTT:
|
||||
"""Test WebVTT text extraction."""
|
||||
|
||||
def test_extract_webvtt_with_speakers(self):
|
||||
"""Test extraction removes speaker tags and timestamps."""
|
||||
webvtt = """WEBVTT
|
||||
|
||||
00:00:00.000 --> 00:00:10.000
|
||||
<v Speaker0>Hello world, this is a test.
|
||||
|
||||
00:00:10.000 --> 00:00:20.000
|
||||
<v Speaker1>Indeed it is a test of WebVTT parsing.
|
||||
"""
|
||||
result = SearchController._extract_webvtt_text(webvtt)
|
||||
assert "Hello world, this is a test" in result
|
||||
assert "Indeed it is a test" in result
|
||||
assert "<v Speaker" not in result
|
||||
assert "00:00" not in result
|
||||
assert "-->" not in result
|
||||
|
||||
def test_extract_empty_webvtt(self):
|
||||
"""Test empty WebVTT returns empty string."""
|
||||
assert SearchController._extract_webvtt_text("") == ""
|
||||
assert SearchController._extract_webvtt_text(None) == ""
|
||||
|
||||
def test_extract_malformed_webvtt(self):
|
||||
"""Test malformed WebVTT returns empty string."""
|
||||
result = SearchController._extract_webvtt_text("Not a valid WebVTT")
|
||||
assert result == ""
|
||||
|
||||
|
||||
class TestGenerateSnippets:
|
||||
"""Test snippet generation from plain text."""
|
||||
|
||||
def test_multiple_matches(self):
|
||||
"""Test finding multiple occurrences of search term in long text."""
|
||||
# Create text with Python mentions far apart to get separate snippets
|
||||
separator = " This is filler text. " * 20 # ~400 chars of padding
|
||||
text = (
|
||||
"Python is great for machine learning."
|
||||
+ separator
|
||||
+ "Many companies use Python for data science."
|
||||
+ separator
|
||||
+ "Python has excellent libraries for analysis."
|
||||
+ separator
|
||||
+ "The Python community is very supportive."
|
||||
)
|
||||
|
||||
snippets = SearchController._generate_snippets(text, "Python")
|
||||
# With enough separation, we should get multiple snippets
|
||||
assert len(snippets) >= 2 # At least 2 distinct snippets
|
||||
|
||||
# Each snippet should contain "Python"
|
||||
for snippet in snippets:
|
||||
assert "python" in snippet.lower()
|
||||
|
    def test_single_match(self):
        """Test single occurrence returns one snippet."""
        text = "This document discusses artificial intelligence and its applications."
        snippets = SearchController._generate_snippets(text, "artificial intelligence")

        assert len(snippets) == 1
        assert "artificial intelligence" in snippets[0].lower()

    def test_no_matches(self):
        """Test no matches returns empty list."""
        text = "This is some random text without the search term."
        snippets = SearchController._generate_snippets(text, "machine learning")

        assert snippets == []

    def test_case_insensitive_search(self):
        """Test search is case insensitive."""
        # Add enough text between matches to get separate snippets
        text = (
            "MACHINE LEARNING is important for modern applications. "
            + "It requires lots of data and computational resources. " * 5  # Padding
            + "Machine Learning rocks and transforms industries. "
            + "Deep learning is a subset of it. " * 5  # More padding
            + "Finally, machine learning will shape our future."
        )

        snippets = SearchController._generate_snippets(text, "machine learning")

        # Should find at least 2 (might be 3 if text is long enough)
        assert len(snippets) >= 2
        for snippet in snippets:
            assert "machine learning" in snippet.lower()

    def test_partial_match_fallback(self):
        """Test fallback to first word when exact phrase not found."""
        text = "We use machine intelligence for processing."
        snippets = SearchController._generate_snippets(text, "machine learning")

        # Should fall back to finding "machine"
        assert len(snippets) == 1
        assert "machine" in snippets[0].lower()

    def test_snippet_ellipsis(self):
        """Test ellipsis added for truncated snippets."""
        # Long text where match is in the middle
        text = "a " * 100 + "TARGET_WORD special content here" + " b" * 100
        snippets = SearchController._generate_snippets(text, "TARGET_WORD")

        assert len(snippets) == 1
        assert "..." in snippets[0]  # Should have ellipsis
        assert "TARGET_WORD" in snippets[0]

    def test_overlapping_snippets_deduplicated(self):
        """Test overlapping matches don't create duplicate snippets."""
        text = "test test test word" * 10  # Repeated pattern
        snippets = SearchController._generate_snippets(text, "test")

        # Should get unique snippets, not duplicates
        assert len(snippets) <= 3
        assert len(snippets) == len(set(snippets))  # All unique

    def test_empty_inputs(self):
        """Test empty text or search term returns empty list."""
        assert SearchController._generate_snippets("", "search") == []
        assert SearchController._generate_snippets("text", "") == []
        assert SearchController._generate_snippets("", "") == []

    def test_max_snippets_limit(self):
        """Test respects max_snippets parameter."""
        # Create text with well-separated occurrences
        separator = " filler " * 50  # Ensure snippets don't overlap
        text = ("Python is amazing" + separator) * 10  # 10 occurrences

        # Test with different limits
        snippets_1 = SearchController._generate_snippets(text, "Python", max_snippets=1)
        assert len(snippets_1) == 1

        snippets_2 = SearchController._generate_snippets(text, "Python", max_snippets=2)
        assert len(snippets_2) == 2

        snippets_5 = SearchController._generate_snippets(text, "Python", max_snippets=5)
        assert len(snippets_5) == 5  # Should get exactly 5 with enough separation

    def test_snippet_length(self):
        """Test snippet length is reasonable."""
        text = "word " * 200  # Long text
        snippets = SearchController._generate_snippets(text, "word")

        for snippet in snippets:
            # Default max_length is 150 + some context
            assert len(snippet) <= 200  # Some buffer for ellipsis


class TestFullPipeline:
    """Test the complete WebVTT to snippets pipeline."""

    def test_webvtt_to_snippets_integration(self):
        """Test full pipeline from WebVTT to search snippets."""
        # Create WebVTT with well-separated content for multiple snippets
        webvtt = (
            """WEBVTT

00:00:00.000 --> 00:00:10.000
<v Speaker0>Let's discuss machine learning applications in modern technology.

00:00:10.000 --> 00:00:20.000
<v Speaker1>"""
            + "Various industries are adopting new technologies. " * 10
            + """

00:00:20.000 --> 00:00:30.000
<v Speaker2>Machine learning is revolutionizing healthcare and diagnostics.

00:00:30.000 --> 00:00:40.000
<v Speaker3>"""
            + "Financial markets show interesting patterns. " * 10
            + """

00:00:40.000 --> 00:00:50.000
<v Speaker0>Machine learning in education provides personalized experiences.
"""
        )

        # Extract and generate snippets
        plain_text = SearchController._extract_webvtt_text(webvtt)
        snippets = SearchController._generate_snippets(plain_text, "machine learning")

        # Nearby matches may be grouped into one snippet, so expect
        # between one and three snippets (three is the default cap)
        assert len(snippets) >= 1  # at minimum one snippet containing matches
        assert len(snippets) <= 3  # at most 3 by default

        # No WebVTT artifacts in snippets
        for snippet in snippets:
            assert "machine learning" in snippet.lower()
            assert "<v Speaker" not in snippet
            assert "00:00" not in snippet
            assert "-->" not in snippet
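For orientation, here is a minimal sketch of a snippet generator that satisfies the contract these tests pin down. It is a reconstruction, not the code from this commit; the signature (max_snippets, the ~75-character context window) is assumed from the tests.

def generate_snippets(text, query, max_snippets=3, context=75):
    """Hypothetical reference: case-insensitive scan with first-word fallback."""
    if not text or not query:
        return []
    lowered, needle = text.lower(), query.lower()
    start = lowered.find(needle)
    if start == -1:  # exact phrase missing: fall back to the first query word
        needle = needle.split()[0]
        start = lowered.find(needle)
    snippets = []
    while start != -1 and len(snippets) < max_snippets:
        lo = max(0, start - context)
        hi = min(len(text), start + len(needle) + context)
        snippet = text[lo:hi].strip()
        if lo > 0:
            snippet = "..." + snippet
        if hi < len(text):
            snippet += "..."
        if snippet not in snippets:  # overlapping windows collapse to one snippet
            snippets.append(snippet)
        start = lowered.find(needle, hi)  # resume past this window
    return snippets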
@@ -1,4 +1,5 @@
 import asyncio
+import time

 import pytest
 from httpx import AsyncClient
@@ -39,14 +40,18 @@ async def test_transcript_process(
     assert response.status_code == 200
     assert response.json()["status"] == "ok"

-    # wait for processing to finish
-    while True:
+    # wait for processing to finish (max 10 minutes)
+    timeout_seconds = 600  # 10 minutes
+    start_time = time.monotonic()
+    while (time.monotonic() - start_time) < timeout_seconds:
         # fetch the transcript and check if it is ended
         resp = await ac.get(f"/transcripts/{tid}")
         assert resp.status_code == 200
         if resp.json()["status"] in ("ended", "error"):
             break
         await asyncio.sleep(1)
+    else:
+        pytest.fail(f"Initial processing timed out after {timeout_seconds} seconds")

     # restart the processing
     response = await ac.post(
@@ -55,14 +60,18 @@ async def test_transcript_process(
     assert response.status_code == 200
     assert response.json()["status"] == "ok"

-    # wait for processing to finish
-    while True:
+    # wait for processing to finish (max 10 minutes)
+    timeout_seconds = 600  # 10 minutes
+    start_time = time.monotonic()
+    while (time.monotonic() - start_time) < timeout_seconds:
         # fetch the transcript and check if it is ended
         resp = await ac.get(f"/transcripts/{tid}")
         assert resp.status_code == 200
         if resp.json()["status"] in ("ended", "error"):
             break
         await asyncio.sleep(1)
+    else:
+        pytest.fail(f"Restart processing timed out after {timeout_seconds} seconds")

     # check the transcript is ended
     transcript = resp.json()
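The replacement loops above lean on Python's while/else: the else suite runs only when the loop condition goes false without a break being hit, which is exactly the timeout case here. A tiny illustration:

n = 0
while n < 3:  # condition eventually goes false, no break is hit
    n += 1
else:
    print("loop exhausted without break")  # runs; a `break` would skip it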
@@ -6,6 +6,7 @@
 import asyncio
 import json
 import threading
+import time
 from pathlib import Path

 import pytest
@@ -21,14 +22,31 @@ class ThreadedUvicorn:

     async def start(self):
         self.thread.start()
-        while not self.server.started:
+        timeout_seconds = 600  # 10 minutes
+        start_time = time.monotonic()
+        while (
+            not self.server.started
+            and (time.monotonic() - start_time) < timeout_seconds
+        ):
             await asyncio.sleep(0.1)
+        if not self.server.started:
+            raise TimeoutError(
+                f"Server failed to start after {timeout_seconds} seconds"
+            )

     def stop(self):
         if self.thread.is_alive():
             self.server.should_exit = True
-            while self.thread.is_alive():
-                continue
+            timeout_seconds = 600  # 10 minutes
+            start_time = time.time()
+            while (
+                self.thread.is_alive() and (time.time() - start_time) < timeout_seconds
+            ):
+                time.sleep(0.1)
+            if self.thread.is_alive():
+                raise TimeoutError(
+                    f"Thread failed to stop after {timeout_seconds} seconds"
+                )


 @pytest.fixture
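One detail worth flagging: start() measures elapsed time with time.monotonic(), which cannot jump when the system clock is adjusted, while stop() uses time.time(), which follows the wall clock. Both work here, but monotonic is the safer default for timeouts:

import time

start = time.monotonic()
time.sleep(0.2)
elapsed = time.monotonic() - start  # ~0.2s even if the wall clock is changed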
@@ -92,12 +110,16 @@ async def test_transcript_rtc_and_websocket(
     async with aconnect_ws(f"{base_url}/transcripts/{tid}/events") as ws:
         print("Test websocket: CONNECTED")
         try:
-            while True:
+            timeout_seconds = 600  # 10 minutes
+            start_time = time.monotonic()
+            while (time.monotonic() - start_time) < timeout_seconds:
                 msg = await ws.receive_json()
                 print(f"Test websocket: JSON {msg}")
                 if msg is None:
                     break
                 events.append(msg)
+            else:
+                print(f"Test websocket: TIMEOUT after {timeout_seconds} seconds")
         except Exception as e:
             print(f"Test websocket: EXCEPTION {e}")
         finally:
@@ -145,9 +167,12 @@ async def test_transcript_rtc_and_websocket(
         if resp.json()["status"] in ("ended", "error"):
             break
         await asyncio.sleep(1)
+        timeout -= 1
+        if timeout < 0:
+            raise TimeoutError("Timeout while waiting for transcript to be ended")

     if resp.json()["status"] != "ended":
-        raise TimeoutError("Timeout while waiting for transcript to be ended")
+        raise TimeoutError("Transcript processing failed")

     # stop websocket task
     websocket_task.cancel()
@@ -253,12 +278,16 @@ async def test_transcript_rtc_and_websocket_and_fr(
     async with aconnect_ws(f"{base_url}/transcripts/{tid}/events") as ws:
         print("Test websocket: CONNECTED")
         try:
-            while True:
+            timeout_seconds = 600  # 10 minutes
+            start_time = time.monotonic()
+            while (time.monotonic() - start_time) < timeout_seconds:
                 msg = await ws.receive_json()
                 print(f"Test websocket: JSON {msg}")
                 if msg is None:
                     break
                 events.append(msg)
+            else:
+                print(f"Test websocket: TIMEOUT after {timeout_seconds} seconds")
         except Exception as e:
             print(f"Test websocket: EXCEPTION {e}")
         finally:
@@ -310,9 +339,12 @@ async def test_transcript_rtc_and_websocket_and_fr(
         if resp.json()["status"] == "ended":
             break
         await asyncio.sleep(1)
+        timeout -= 1
+        if timeout < 0:
+            raise TimeoutError("Timeout while waiting for transcript to be ended")

     if resp.json()["status"] != "ended":
-        raise TimeoutError("Timeout while waiting for transcript to be ended")
+        raise TimeoutError("Transcript processing failed")

     await asyncio.sleep(2)
@@ -1,4 +1,5 @@
 import asyncio
+import time

 import pytest
 from httpx import AsyncClient
@@ -39,14 +40,18 @@ async def test_transcript_upload_file(
     assert response.status_code == 200
     assert response.json()["status"] == "ok"

-    # wait the processing to finish
-    while True:
+    # wait the processing to finish (max 10 minutes)
+    timeout_seconds = 600  # 10 minutes
+    start_time = time.monotonic()
+    while (time.monotonic() - start_time) < timeout_seconds:
         # fetch the transcript and check if it is ended
         resp = await ac.get(f"/transcripts/{tid}")
         assert resp.status_code == 200
         if resp.json()["status"] in ("ended", "error"):
             break
         await asyncio.sleep(1)
+    else:
+        pytest.fail(f"Processing timed out after {timeout_seconds} seconds")

     # check the transcript is ended
     transcript = resp.json()
server/tests/test_webvtt.py (new file, 151 lines)
@@ -0,0 +1,151 @@
"""Tests for WebVTT utilities."""

import pytest

from reflector.processors.types import Transcript, Word, words_to_segments
from reflector.utils.webvtt import topics_to_webvtt, words_to_webvtt


class TestWordsToWebVTT:
    """Test words_to_webvtt function with TDD approach."""

    def test_empty_words_returns_empty_webvtt(self):
        """Should return empty WebVTT structure for empty words list."""
        result = words_to_webvtt([])

        assert "WEBVTT" in result
        assert result.strip() == "WEBVTT"

    def test_single_word_creates_single_caption(self):
        """Should create one caption for a single word."""
        words = [Word(text="Hello", start=0.0, end=1.0, speaker=0)]
        result = words_to_webvtt(words)

        assert "WEBVTT" in result
        assert "00:00:00.000 --> 00:00:01.000" in result
        assert "Hello" in result
        assert "<v Speaker0>" in result

    def test_multiple_words_same_speaker_groups_properly(self):
        """Should group consecutive words from same speaker."""
        words = [
            Word(text="Hello", start=0.0, end=0.5, speaker=0),
            Word(text=" world", start=0.5, end=1.0, speaker=0),
        ]
        result = words_to_webvtt(words)

        assert "WEBVTT" in result
        assert "Hello world" in result
        assert "<v Speaker0>" in result

    def test_speaker_change_creates_new_caption(self):
        """Should create new caption when speaker changes."""
        words = [
            Word(text="Hello", start=0.0, end=0.5, speaker=0),
            Word(text="Hi", start=0.6, end=1.0, speaker=1),
        ]
        result = words_to_webvtt(words)

        assert "<v Speaker0>" in result
        assert "<v Speaker1>" in result
        assert "Hello" in result
        assert "Hi" in result

    def test_punctuation_creates_segment_boundary(self):
        """Should respect punctuation boundaries from segmentation logic."""
        words = [
            Word(text="Hello.", start=0.0, end=0.5, speaker=0),
            Word(text=" How", start=0.6, end=1.0, speaker=0),
            Word(text=" are", start=1.0, end=1.3, speaker=0),
            Word(text=" you?", start=1.3, end=1.8, speaker=0),
        ]
        result = words_to_webvtt(words)

        assert "WEBVTT" in result
        assert "<v Speaker0>" in result


class TestTopicsToWebVTT:
    """Test topics_to_webvtt function."""

    def test_empty_topics_returns_empty_webvtt(self):
        """Should handle empty topics list."""
        result = topics_to_webvtt([])
        assert "WEBVTT" in result
        assert result.strip() == "WEBVTT"

    def test_extracts_words_from_topics(self):
        """Should extract all words from topics in sequence."""

        class MockTopic:
            def __init__(self, words):
                self.words = words

        topics = [
            MockTopic(
                [
                    Word(text="First", start=0.0, end=0.5, speaker=1),
                    Word(text="Second", start=1.0, end=1.5, speaker=0),
                ]
            )
        ]

        result = topics_to_webvtt(topics)

        assert "WEBVTT" in result
        first_pos = result.find("First")
        second_pos = result.find("Second")
        assert first_pos < second_pos

    def test_non_sequential_topics_raises_assertion(self):
        """Should raise assertion error when words are not in chronological sequence."""

        class MockTopic:
            def __init__(self, words):
                self.words = words

        topics = [
            MockTopic(
                [
                    Word(text="Second", start=1.0, end=1.5, speaker=0),
                    Word(text="First", start=0.0, end=0.5, speaker=1),
                ]
            )
        ]

        with pytest.raises(AssertionError) as exc_info:
            topics_to_webvtt(topics)

        assert "Words are not in sequence" in str(exc_info.value)
        assert "Second and First" in str(exc_info.value)


class TestTranscriptWordsToSegments:
    """Test static words_to_segments method (TDD for making it static)."""

    def test_static_method_exists(self):
        """Should have static words_to_segments method."""
        words = [Word(text="Hello", start=0.0, end=1.0, speaker=0)]
        segments = words_to_segments(words)

        assert isinstance(segments, list)
        assert len(segments) == 1
        assert segments[0].text == "Hello"
        assert segments[0].speaker == 0

    def test_backward_compatibility(self):
        """Should maintain backward compatibility with instance method."""
        words = [Word(text="Hello", start=0.0, end=1.0, speaker=0)]
        transcript = Transcript(words=words)

        segments = transcript.as_segments()
        assert isinstance(segments, list)
        assert len(segments) == 1
        assert segments[0].text == "Hello"
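For reference, the cue shape these assertions describe looks like the following (illustrative output reconstructed from the test expectations, not captured from the code):

WEBVTT

00:00:00.000 --> 00:00:00.500
<v Speaker0>Hello

00:00:00.600 --> 00:00:01.000
<v Speaker1>Hi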
server/tests/test_webvtt_implementation.py (new file, 34 lines)
@@ -0,0 +1,34 @@
"""Test WebVTT auto-update functionality and edge cases."""

import pytest

from reflector.db.transcripts import (
    TranscriptController,
)


@pytest.mark.asyncio
class TestWebVTTAutoUpdateImplementation:
    async def test_handle_topics_update_handles_dict_conversion(self):
        """
        Verify that _handle_topics_update() properly converts dict data to TranscriptTopic objects.
        """
        values = {
            "topics": [
                {
                    "id": "topic1",
                    "title": "Test",
                    "summary": "Test",
                    "timestamp": 0.0,
                    "words": [
                        {"text": "Hello", "start": 0.0, "end": 1.0, "speaker": 0}
                    ],
                }
            ]
        }

        updated_values = TranscriptController._handle_topics_update(values)

        assert "webvtt" in updated_values
        assert updated_values["webvtt"] is not None
        assert "WEBVTT" in updated_values["webvtt"]
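As a reading aid, a minimal sketch of what a hook like _handle_topics_update() has to do to satisfy this test; the real method on TranscriptController may differ in detail:

from reflector.db.transcripts import TranscriptTopic
from reflector.utils.webvtt import topics_to_webvtt

def handle_topics_update(values: dict) -> dict:
    topics = values.get("topics")
    if topics:
        # accept both plain dicts and TranscriptTopic instances
        topics = [TranscriptTopic(**t) if isinstance(t, dict) else t for t in topics]
        values["webvtt"] = topics_to_webvtt(topics)  # asserts words are in sequence
    return values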
server/tests/test_webvtt_integration.py (new file, 234 lines)
@@ -0,0 +1,234 @@
"""Integration tests for WebVTT auto-update functionality in Transcript model."""

import pytest

from reflector.db import database
from reflector.db.transcripts import (
    SourceKind,
    TranscriptController,
    TranscriptTopic,
    transcripts,
)
from reflector.processors.types import Word


@pytest.mark.asyncio
class TestWebVTTAutoUpdate:
    """Test that WebVTT field auto-updates when Transcript is created or modified."""

    async def test_webvtt_not_updated_on_transcript_creation_without_topics(self):
        """WebVTT should be None when creating transcript without topics."""
        controller = TranscriptController()

        transcript = await controller.add(
            name="Test Transcript",
            source_kind=SourceKind.FILE,
        )

        try:
            result = await database.fetch_one(
                transcripts.select().where(transcripts.c.id == transcript.id)
            )

            assert result is not None
            assert result["webvtt"] is None
        finally:
            await controller.remove_by_id(transcript.id)

    async def test_webvtt_updated_on_upsert_topic(self):
        """WebVTT should update when upserting topics via upsert_topic method."""
        controller = TranscriptController()

        transcript = await controller.add(
            name="Test Transcript",
            source_kind=SourceKind.FILE,
        )

        try:
            topic = TranscriptTopic(
                id="topic1",
                title="Test Topic",
                summary="Test summary",
                timestamp=0.0,
                words=[
                    Word(text="Hello", start=0.0, end=0.5, speaker=0),
                    Word(text=" world", start=0.5, end=1.0, speaker=0),
                ],
            )

            await controller.upsert_topic(transcript, topic)

            result = await database.fetch_one(
                transcripts.select().where(transcripts.c.id == transcript.id)
            )

            assert result is not None
            webvtt = result["webvtt"]

            assert webvtt is not None
            assert "WEBVTT" in webvtt
            assert "Hello world" in webvtt
            assert "<v Speaker0>" in webvtt

        finally:
            await controller.remove_by_id(transcript.id)

    async def test_webvtt_updated_on_direct_topics_update(self):
        """WebVTT should update when updating topics field directly."""
        controller = TranscriptController()

        transcript = await controller.add(
            name="Test Transcript",
            source_kind=SourceKind.FILE,
        )

        try:
            topics_data = [
                {
                    "id": "topic1",
                    "title": "First Topic",
                    "summary": "First sentence test",
                    "timestamp": 0.0,
                    "words": [
                        {"text": "First", "start": 0.0, "end": 0.5, "speaker": 0},
                        {"text": " sentence", "start": 0.5, "end": 1.0, "speaker": 0},
                    ],
                }
            ]

            await controller.update(transcript, {"topics": topics_data})

            # Fetch from DB
            result = await database.fetch_one(
                transcripts.select().where(transcripts.c.id == transcript.id)
            )

            assert result is not None
            webvtt = result["webvtt"]

            assert webvtt is not None
            assert "WEBVTT" in webvtt
            assert "First sentence" in webvtt

        finally:
            await controller.remove_by_id(transcript.id)

    async def test_webvtt_updated_manually_with_handle_topics_update(self):
        """Test that _handle_topics_update works when called manually."""
        controller = TranscriptController()

        transcript = await controller.add(
            name="Test Transcript",
            source_kind=SourceKind.FILE,
        )

        try:
            topic1 = TranscriptTopic(
                id="topic1",
                title="Topic 1",
                summary="Manual test",
                timestamp=0.0,
                words=[
                    Word(text="Manual", start=0.0, end=0.5, speaker=0),
                    Word(text=" test", start=0.5, end=1.0, speaker=0),
                ],
            )

            transcript.upsert_topic(topic1)

            values = {"topics": transcript.topics_dump()}

            await controller.update(transcript, values)

            # Fetch from DB
            result = await database.fetch_one(
                transcripts.select().where(transcripts.c.id == transcript.id)
            )

            assert result is not None
            webvtt = result["webvtt"]

            assert webvtt is not None
            assert "WEBVTT" in webvtt
            assert "Manual test" in webvtt
            assert "<v Speaker0>" in webvtt

        finally:
            await controller.remove_by_id(transcript.id)

    async def test_webvtt_update_with_non_sequential_topics_fails(self):
        """Test that non-sequential topics raise assertion error."""
        controller = TranscriptController()

        transcript = await controller.add(
            name="Test Transcript",
            source_kind=SourceKind.FILE,
        )

        try:
            topic1 = TranscriptTopic(
                id="topic1",
                title="Bad Topic",
                summary="Bad order test",
                timestamp=1.0,
                words=[
                    Word(text="Second", start=2.0, end=2.5, speaker=0),
                    Word(text="First", start=1.0, end=1.5, speaker=0),
                ],
            )

            transcript.upsert_topic(topic1)
            values = {"topics": transcript.topics_dump()}

            with pytest.raises(AssertionError) as exc_info:
                TranscriptController._handle_topics_update(values)

            assert "Words are not in sequence" in str(exc_info.value)

        finally:
            await controller.remove_by_id(transcript.id)

    async def test_multiple_speakers_in_webvtt(self):
        """Test WebVTT generation with multiple speakers."""
        controller = TranscriptController()

        transcript = await controller.add(
            name="Test Transcript",
            source_kind=SourceKind.FILE,
        )

        try:
            topic = TranscriptTopic(
                id="topic1",
                title="Multi Speaker",
                summary="Multi speaker test",
                timestamp=0.0,
                words=[
                    Word(text="Hello", start=0.0, end=0.5, speaker=0),
                    Word(text="Hi", start=1.0, end=1.5, speaker=1),
                    Word(text="Goodbye", start=2.0, end=2.5, speaker=0),
                ],
            )

            transcript.upsert_topic(topic)
            values = {"topics": transcript.topics_dump()}

            await controller.update(transcript, values)

            # Fetch from DB
            result = await database.fetch_one(
                transcripts.select().where(transcripts.c.id == transcript.id)
            )

            assert result is not None
            webvtt = result["webvtt"]

            assert webvtt is not None
            assert "<v Speaker0>" in webvtt
            assert "<v Speaker1>" in webvtt
            assert "Hello" in webvtt
            assert "Hi" in webvtt
            assert "Goodbye" in webvtt

        finally:
            await controller.remove_by_id(transcript.id)
server/uv.lock (generated, 2717 lines)
File diff suppressed because it is too large
@@ -209,7 +209,6 @@ export const $GetTranscript = {
     },
     created_at: {
       type: "string",
-      format: "date-time",
       title: "Created At",
     },
     share_mode: {
@@ -395,7 +394,6 @@ export const $GetTranscriptMinimal = {
     },
     created_at: {
       type: "string",
-      format: "date-time",
       title: "Created At",
     },
     share_mode: {
@@ -758,8 +756,15 @@ export const $Page_GetTranscriptMinimal_ = {
       title: "Items",
     },
     total: {
-      type: "integer",
-      minimum: 0,
+      anyOf: [
+        {
+          type: "integer",
+          minimum: 0,
+        },
+        {
+          type: "null",
+        },
+      ],
       title: "Total",
     },
     page: {
@@ -800,7 +805,7 @@ export const $Page_GetTranscriptMinimal_ = {
     },
   },
   type: "object",
-  required: ["items", "total", "page", "size"],
+  required: ["items", "page", "size"],
   title: "Page[GetTranscriptMinimal]",
 } as const;
@@ -814,8 +819,15 @@ export const $Page_Room_ = {
       title: "Items",
     },
     total: {
-      type: "integer",
-      minimum: 0,
+      anyOf: [
+        {
+          type: "integer",
+          minimum: 0,
+        },
+        {
+          type: "null",
+        },
+      ],
       title: "Total",
     },
     page: {
@@ -856,7 +868,7 @@ export const $Page_Room_ = {
     },
   },
   type: "object",
-  required: ["items", "total", "page", "size"],
+  required: ["items", "page", "size"],
   title: "Page[Room]",
 } as const;
@@ -973,6 +985,136 @@ export const $RtcOffer = {
   title: "RtcOffer",
 } as const;

+export const $SearchResponse = {
+  properties: {
+    results: {
+      items: {
+        $ref: "#/components/schemas/SearchResult",
+      },
+      type: "array",
+      title: "Results",
+    },
+    total: {
+      type: "integer",
+      minimum: 0,
+      title: "Total",
+      description: "Total number of search results",
+    },
+    query: {
+      type: "string",
+      minLength: 1,
+      title: "Query",
+      description: "Search query text",
+    },
+    limit: {
+      type: "integer",
+      maximum: 100,
+      minimum: 1,
+      title: "Limit",
+      description: "Results per page",
+    },
+    offset: {
+      type: "integer",
+      minimum: 0,
+      title: "Offset",
+      description: "Number of results to skip",
+    },
+  },
+  type: "object",
+  required: ["results", "total", "query", "limit", "offset"],
+  title: "SearchResponse",
+} as const;
+
+export const $SearchResult = {
+  properties: {
+    id: {
+      type: "string",
+      minLength: 1,
+      title: "Id",
+    },
+    title: {
+      anyOf: [
+        {
+          type: "string",
+        },
+        {
+          type: "null",
+        },
+      ],
+      title: "Title",
+    },
+    user_id: {
+      anyOf: [
+        {
+          type: "string",
+        },
+        {
+          type: "null",
+        },
+      ],
+      title: "User Id",
+    },
+    room_id: {
+      anyOf: [
+        {
+          type: "string",
+        },
+        {
+          type: "null",
+        },
+      ],
+      title: "Room Id",
+    },
+    created_at: {
+      type: "string",
+      title: "Created At",
+    },
+    status: {
+      type: "string",
+      minLength: 1,
+      title: "Status",
+    },
+    rank: {
+      type: "number",
+      maximum: 1,
+      minimum: 0,
+      title: "Rank",
+    },
+    duration: {
+      anyOf: [
+        {
+          type: "number",
+          minimum: 0,
+        },
+        {
+          type: "null",
+        },
+      ],
+      title: "Duration",
+      description: "Duration in seconds",
+    },
+    search_snippets: {
+      items: {
+        type: "string",
+      },
+      type: "array",
+      title: "Search Snippets",
+      description: "Text snippets around search matches",
+    },
+  },
+  type: "object",
+  required: [
+    "id",
+    "created_at",
+    "status",
+    "rank",
+    "duration",
+    "search_snippets",
+  ],
+  title: "SearchResult",
+  description: "Public search result model with computed fields.",
+} as const;
+
 export const $SourceKind = {
   type: "string",
   enum: ["room", "live", "file"],
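The rank field is constrained to [0, 1]. Postgres can produce such a normalized rank directly: ts_rank_cd with normalization flag 32 divides the raw rank by rank + 1. A hedged sketch of a query in that spirit (the actual SQL in this commit may differ):

SEARCH_SQL = """
SELECT id, title,
       ts_rank_cd(
           to_tsvector('english', coalesce(title, '') || ' ' || coalesce(webvtt, '')),
           plainto_tsquery('english', :q),
           32  -- normalization: rank / (rank + 1), always within [0, 1)
       ) AS rank
FROM transcript
WHERE to_tsvector('english', coalesce(title, '') || ' ' || coalesce(webvtt, ''))
      @@ plainto_tsquery('english', :q)
ORDER BY rank DESC
LIMIT :limit OFFSET :offset
"""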
@@ -1397,6 +1539,7 @@ export const $WherebyWebhookEvent = {
       title: "Type",
     },
     data: {
+      additionalProperties: true,
       type: "object",
       title: "Data",
     },
@@ -1414,11 +1557,15 @@ export const $Word = {
     },
     start: {
       type: "number",
+      minimum: 0,
       title: "Start",
+      description: "Time in seconds with float part",
     },
     end: {
       type: "number",
+      minimum: 0,
       title: "End",
+      description: "Time in seconds with float part",
     },
     speaker: {
       type: "integer",
@@ -20,6 +20,8 @@ import type {
   V1TranscriptsListResponse,
   V1TranscriptsCreateData,
   V1TranscriptsCreateResponse,
+  V1TranscriptsSearchData,
+  V1TranscriptsSearchResponse,
   V1TranscriptGetData,
   V1TranscriptGetResponse,
   V1TranscriptUpdateData,
@@ -276,6 +278,35 @@ export class DefaultService {
     });
   }

+  /**
+   * Transcripts Search
+   * Full-text search across transcript titles and content.
+   * @param data The data for the request.
+   * @param data.q Search query text
+   * @param data.limit Results per page
+   * @param data.offset Number of results to skip
+   * @param data.roomId
+   * @returns SearchResponse Successful Response
+   * @throws ApiError
+   */
+  public v1TranscriptsSearch(
+    data: V1TranscriptsSearchData,
+  ): CancelablePromise<V1TranscriptsSearchResponse> {
+    return this.httpRequest.request({
+      method: "GET",
+      url: "/v1/transcripts/search",
+      query: {
+        q: data.q,
+        limit: data.limit,
+        offset: data.offset,
+        room_id: data.roomId,
+      },
+      errors: {
+        422: "Validation Error",
+      },
+    });
+  }
+
   /**
    * Transcript Get
    * @param data The data for the request.
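A quick way to exercise the endpoint this client method wraps, using httpx as the server tests do (the base URL and port are assumptions):

import asyncio
from httpx import AsyncClient

async def main():
    async with AsyncClient(base_url="http://localhost:8000/v1") as ac:
        resp = await ac.get(
            "/transcripts/search",
            params={"q": "machine learning", "limit": 20, "offset": 0},
        )
        resp.raise_for_status()
        body = resp.json()  # SearchResponse: results, total, query, limit, offset
        for r in body["results"]:
            print(r["id"], r["rank"], r["search_snippets"])

asyncio.run(main())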
@@ -141,7 +141,7 @@ export type MeetingConsentRequest = {

 export type Page_GetTranscriptMinimal_ = {
   items: Array<GetTranscriptMinimal>;
-  total: number;
+  total?: number | null;
   page: number | null;
   size: number | null;
   pages?: number | null;
@@ -149,7 +149,7 @@ export type Page_GetTranscriptMinimal_ = {

 export type Page_Room_ = {
   items: Array<Room>;
-  total: number;
+  total?: number | null;
   page: number | null;
   size: number | null;
   pages?: number | null;
@@ -181,6 +181,47 @@ export type RtcOffer = {
   type: string;
 };

+export type SearchResponse = {
+  results: Array<SearchResult>;
+  /**
+   * Total number of search results
+   */
+  total: number;
+  /**
+   * Search query text
+   */
+  query: string;
+  /**
+   * Results per page
+   */
+  limit: number;
+  /**
+   * Number of results to skip
+   */
+  offset: number;
+};
+
+/**
+ * Public search result model with computed fields.
+ */
+export type SearchResult = {
+  id: string;
+  title?: string | null;
+  user_id?: string | null;
+  room_id?: string | null;
+  created_at: string;
+  status: string;
+  rank: number;
+  /**
+   * Duration in seconds
+   */
+  duration: number | null;
+  /**
+   * Text snippets around search matches
+   */
+  search_snippets: Array<string>;
+};
+
 export type SourceKind = "room" | "live" | "file";

 export type SpeakerAssignment = {
@@ -272,7 +313,13 @@ export type WherebyWebhookEvent = {

 export type Word = {
   text: string;
+  /**
+   * Time in seconds with float part
+   */
   start: number;
+  /**
+   * Time in seconds with float part
+   */
   end: number;
   speaker?: number;
 };
@@ -346,6 +393,24 @@ export type V1TranscriptsCreateData = {

 export type V1TranscriptsCreateResponse = GetTranscript;

+export type V1TranscriptsSearchData = {
+  /**
+   * Results per page
+   */
+  limit?: number;
+  /**
+   * Number of results to skip
+   */
+  offset?: number;
+  /**
+   * Search query text
+   */
+  q: string;
+  roomId?: string | null;
+};
+
+export type V1TranscriptsSearchResponse = SearchResponse;
+
 export type V1TranscriptGetData = {
   transcriptId: string;
 };
@@ -633,6 +698,21 @@ export type $OpenApiTs = {
       };
     };
   };
+  "/v1/transcripts/search": {
+    get: {
+      req: V1TranscriptsSearchData;
+      res: {
+        /**
+         * Successful Response
+         */
+        200: SearchResponse;
+        /**
+         * Validation Error
+         */
+        422: HTTPValidationError;
+      };
+    };
+  };
   "/v1/transcripts/{transcript_id}": {
     get: {
       req: V1TranscriptGetData;
@@ -23,8 +23,8 @@ export const formatTimeDifference = (seconds: number): string => {
     hours > 0
       ? `${hours < 10 ? "\u00A0" : ""}${hours}h ago`
       : minutes > 0
-          ? `${minutes < 10 ? "\u00A0" : ""}${minutes}m ago`
-          : `<1m ago`;
+        ? `${minutes < 10 ? "\u00A0" : ""}${minutes}m ago`
+        : `<1m ago`;

   return timeString;
 };