feat: search backend (#537)

* docs: transient docs

* chore: cleanup

* webvtt WIP

* webvtt field

* chore: webvtt tests comments

* chore: remove useless tests

* feat: search TASK.md

* feat: full text search by title/webvtt

* chore: search api task

* feat: search api

* feat: search API

* chore: rm task md

* chore: roll back unnecessary validators

* chore: pr review WIP

* chore: pr review WIP

* chore: pr review

* chore: top imports

* feat: better lint + ci

* feat: better lint + ci

* feat: better lint + ci

* feat: better lint + ci

* chore: lint

* chore: lint

* fix: db datetime definitions

* fix: flush() params

* fix: update transcript mutability expectation / test

* fix: update transcript mutability expectation / test

* chore: auto review

* chore: new controller extraction

* chore: new controller extraction

* chore: cleanup

* chore: review WIP

* chore: pr WIP

* chore: remove ci lint

* chore: openapi regeneration

* chore: openapi regeneration

* chore: postgres test doc

* fix: .dockerignore for arm binaries

* fix: .dockerignore for arm binaries

* fix: cap test loops

* fix: cap test loops

* fix: cap test loops

* fix: get_transcript_topics

* chore: remove flow.md docs and claude guidance

* chore: remove claude.md db doc

* chore: remove claude.md db doc

* chore: remove claude.md db doc

* chore: remove claude.md db doc
Igor Loskutov
2025-08-13 10:03:38 -04:00
committed by GitHub
parent a42ed12982
commit 6fb5cb21c2
29 changed files with 3213 additions and 1493 deletions

.gitignore

@@ -14,3 +14,4 @@ data/
 www/REFACTOR.md
 www/reload-frontend
 server/test.sqlite
+CLAUDE.local.md


@@ -3,10 +3,10 @@
 repos:
   - repo: local
     hooks:
-      - id: yarn-format
-        name: run yarn format
+      - id: format
+        name: run format
         language: system
-        entry: bash -c 'cd www && yarn format'
+        entry: bash -c 'cd www && npx prettier --write .'
         pass_filenames: false
         files: ^www/
@@ -23,8 +23,7 @@ repos:
       - id: ruff
         args:
           - --fix
-          - --select
-          - I,F401
+          # Uses select rules from server/pyproject.toml
         files: ^server/
       - id: ruff-format
         files: ^server/


@@ -44,6 +44,7 @@ services:
     working_dir: /app
     volumes:
       - ./www:/app/
+      - /app/node_modules
     env_file:
       - ./www/.env.local
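
Note: the added /app/node_modules entry is the standard anonymous-volume pattern: it keeps the dependencies installed inside the image from being shadowed when ./www is bind-mounted over /app.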


@@ -1 +1,3 @@
 Generic single-database configuration.
+
+Both data migrations and schema migrations must live in this migrations directory.


@@ -0,0 +1,27 @@
"""add_webvtt_field_to_transcript
Revision ID: 0bc0f3ff0111
Revises: b7df9609542c
Create Date: 2025-08-05 19:36:41.740957
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = '0bc0f3ff0111'
down_revision: Union[str, None] = 'b7df9609542c'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column('transcript',
sa.Column('webvtt', sa.Text(), nullable=True)
)
def downgrade() -> None:
op.drop_column('transcript', 'webvtt')


@@ -0,0 +1,47 @@
"""add_full_text_search
Revision ID: 116b2f287eab
Revises: 0bc0f3ff0111
Create Date: 2025-08-07 11:27:38.473517
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = '116b2f287eab'
down_revision: Union[str, None] = '0bc0f3ff0111'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
conn = op.get_bind()
if conn.dialect.name != 'postgresql':
return
op.execute("""
ALTER TABLE transcript ADD COLUMN search_vector_en tsvector
GENERATED ALWAYS AS (
setweight(to_tsvector('english', coalesce(title, '')), 'A') ||
setweight(to_tsvector('english', coalesce(webvtt, '')), 'B')
) STORED
""")
op.create_index(
'idx_transcript_search_vector_en',
'transcript',
['search_vector_en'],
postgresql_using='gin'
)
def downgrade() -> None:
conn = op.get_bind()
if conn.dialect.name != 'postgresql':
return
op.drop_index('idx_transcript_search_vector_en', table_name='transcript')
op.drop_column('transcript', 'search_vector_en')
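
For reference, a sketch of how this generated column is typically queried (illustrative only; the 32 flag normalizes ts_rank into the 0-1 range):

import sqlalchemy as sa

# Parameterized full-text query against the column created above
fts_query = sa.text("""
    SELECT id, title,
           ts_rank(search_vector_en, websearch_to_tsquery('english', :q), 32) AS rank
    FROM transcript
    WHERE search_vector_en @@ websearch_to_tsquery('english', :q)
    ORDER BY rank DESC
""")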


@@ -0,0 +1,109 @@
"""populate_webvtt_from_topics
Revision ID: 8120ebc75366
Revises: 116b2f287eab
Create Date: 2025-08-11 19:11:01.316947
"""
import json
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy import text
# revision identifiers, used by Alembic.
revision: str = '8120ebc75366'
down_revision: Union[str, None] = '116b2f287eab'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def topics_to_webvtt(topics):
"""Convert topics list to WebVTT format string."""
if not topics:
return None
lines = ["WEBVTT", ""]
for topic in topics:
start_time = format_timestamp(topic.get("start"))
end_time = format_timestamp(topic.get("end"))
text = topic.get("text", "").strip()
if start_time and end_time and text:
lines.append(f"{start_time} --> {end_time}")
lines.append(text)
lines.append("")
return "\n".join(lines).strip()
def format_timestamp(seconds):
"""Format seconds to WebVTT timestamp format (HH:MM:SS.mmm)."""
if seconds is None:
return None
hours = int(seconds // 3600)
minutes = int((seconds % 3600) // 60)
secs = seconds % 60
return f"{hours:02d}:{minutes:02d}:{secs:06.3f}"
def upgrade() -> None:
"""Populate WebVTT field for all transcripts with topics."""
# Get connection
connection = op.get_bind()
# Query all transcripts with topics
result = connection.execute(
text("SELECT id, topics FROM transcript WHERE topics IS NOT NULL")
)
rows = result.fetchall()
print(f"Found {len(rows)} transcripts with topics")
updated_count = 0
error_count = 0
for row in rows:
transcript_id = row[0]
topics_data = row[1]
if not topics_data:
continue
try:
# Parse JSON if it's a string
if isinstance(topics_data, str):
topics_data = json.loads(topics_data)
# Convert topics to WebVTT format
webvtt_content = topics_to_webvtt(topics_data)
if webvtt_content:
# Update the webvtt field
connection.execute(
text("UPDATE transcript SET webvtt = :webvtt WHERE id = :id"),
{"webvtt": webvtt_content, "id": transcript_id}
)
updated_count += 1
print(f"✓ Updated transcript {transcript_id}")
except Exception as e:
error_count += 1
print(f"✗ Error updating transcript {transcript_id}: {e}")
print(f"\nMigration complete!")
print(f" Updated: {updated_count}")
print(f" Errors: {error_count}")
def downgrade() -> None:
"""Clear WebVTT field for all transcripts."""
op.execute(
text("UPDATE transcript SET webvtt = NULL")
)
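
For illustration, a single topic with start=0, end=90.5, and text "Hello world" yields the following WebVTT from the helpers above:

WEBVTT

00:00:00.000 --> 00:01:30.500
Hello world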


@@ -40,6 +40,7 @@ dependencies = [
     "llama-index>=0.12.52",
     "llama-index-llms-openai-like>=0.4.0",
     "pytest-env>=1.1.5",
+    "webvtt-py>=0.5.0",
 ]

 [dependency-groups]
@@ -92,5 +93,12 @@ addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v"
 testpaths = ["tests"]
 asyncio_mode = "auto"

+[tool.ruff.lint]
+select = [
+    "I",        # isort - import sorting
+    "F401",     # unused imports
+    "PLC0415",  # import-outside-top-level - detect inline imports
+]
+
 [tool.ruff.lint.per-file-ignores]
 "reflector/processors/summary/summary_builder.py" = ["E501"]


@@ -0,0 +1,233 @@
"""Search functionality for transcripts and other entities."""
import logging
from datetime import datetime
from io import StringIO
from typing import Annotated, Any, Dict
import sqlalchemy
import webvtt
from pydantic import BaseModel, Field, constr, field_serializer
from reflector.db import database
from reflector.db.transcripts import SourceKind, transcripts
from reflector.db.utils import is_postgresql
logger = logging.getLogger(__name__)
DEFAULT_SEARCH_LIMIT = 20
SNIPPET_CONTEXT_LENGTH = 50 # Characters before/after match to include
DEFAULT_SNIPPET_MAX_LENGTH = 150
DEFAULT_MAX_SNIPPETS = 3
SearchQueryBase = constr(min_length=1, strip_whitespace=True)
SearchLimitBase = Annotated[int, Field(ge=1, le=100)]
SearchOffsetBase = Annotated[int, Field(ge=0)]
SearchTotalBase = Annotated[int, Field(ge=0)]
SearchQuery = Annotated[SearchQueryBase, Field(description="Search query text")]
SearchLimit = Annotated[SearchLimitBase, Field(description="Results per page")]
SearchOffset = Annotated[
SearchOffsetBase, Field(description="Number of results to skip")
]
SearchTotal = Annotated[
SearchTotalBase, Field(description="Total number of search results")
]
class SearchParameters(BaseModel):
"""Validated search parameters for full-text search."""
query_text: SearchQuery
limit: SearchLimit = DEFAULT_SEARCH_LIMIT
offset: SearchOffset = 0
user_id: str | None = None
room_id: str | None = None
class SearchResultDB(BaseModel):
"""Intermediate model for validating raw database results."""
id: str = Field(..., min_length=1)
created_at: datetime
status: str = Field(..., min_length=1)
duration: float | None = Field(None, ge=0)
user_id: str | None = None
title: str | None = None
source_kind: SourceKind
room_id: str | None = None
rank: float = Field(..., ge=0, le=1)
class SearchResult(BaseModel):
"""Public search result model with computed fields."""
id: str = Field(..., min_length=1)
title: str | None = None
user_id: str | None = None
room_id: str | None = None
created_at: datetime
status: str = Field(..., min_length=1)
rank: float = Field(..., ge=0, le=1)
duration: float | None = Field(..., ge=0, description="Duration in seconds")
search_snippets: list[str] = Field(
description="Text snippets around search matches"
)
@field_serializer("created_at", when_used="json")
def serialize_datetime(self, dt: datetime) -> str:
if dt.tzinfo is None:
return dt.isoformat() + "Z"
return dt.isoformat()
class SearchController:
"""Controller for search operations across different entities."""
@staticmethod
def _extract_webvtt_text(webvtt_content: str) -> str:
"""Extract plain text from WebVTT content using webvtt library."""
if not webvtt_content:
return ""
try:
buffer = StringIO(webvtt_content)
vtt = webvtt.read_buffer(buffer)
return " ".join(caption.text for caption in vtt if caption.text)
except (webvtt.errors.MalformedFileError, UnicodeDecodeError, ValueError) as e:
logger.warning(f"Failed to parse WebVTT content: {e}", exc_info=e)
return ""
except AttributeError as e:
logger.warning(f"WebVTT parsing error - unexpected format: {e}", exc_info=e)
return ""
@staticmethod
def _generate_snippets(
text: str,
q: SearchQuery,
max_length: int = DEFAULT_SNIPPET_MAX_LENGTH,
max_snippets: int = DEFAULT_MAX_SNIPPETS,
) -> list[str]:
"""Generate multiple snippets around all occurrences of search term."""
if not text or not q:
return []
snippets = []
lower_text = text.lower()
search_lower = q.lower()
last_snippet_end = 0
start_pos = 0
while len(snippets) < max_snippets:
match_pos = lower_text.find(search_lower, start_pos)
if match_pos == -1:
if not snippets and search_lower.split():
first_word = search_lower.split()[0]
match_pos = lower_text.find(first_word, start_pos)
if match_pos == -1:
break
else:
break
snippet_start = max(0, match_pos - SNIPPET_CONTEXT_LENGTH)
snippet_end = min(
len(text), match_pos + max_length - SNIPPET_CONTEXT_LENGTH
)
if snippet_start < last_snippet_end:
start_pos = match_pos + len(search_lower)
continue
snippet = text[snippet_start:snippet_end]
if snippet_start > 0:
snippet = "..." + snippet
if snippet_end < len(text):
snippet = snippet + "..."
snippet = snippet.strip()
if snippet:
snippets.append(snippet)
last_snippet_end = snippet_end
start_pos = match_pos + len(search_lower)
if start_pos >= len(text):
break
return snippets
@classmethod
async def search_transcripts(
cls, params: SearchParameters
) -> tuple[list[SearchResult], int]:
"""
Full-text search for transcripts using PostgreSQL tsvector.
Returns (results, total_count).
"""
if not is_postgresql():
logger.warning(
"Full-text search requires PostgreSQL. Returning empty results."
)
return [], 0
search_query = sqlalchemy.func.websearch_to_tsquery(
"english", params.query_text
)
base_query = sqlalchemy.select(
[
transcripts.c.id,
transcripts.c.title,
transcripts.c.created_at,
transcripts.c.duration,
transcripts.c.status,
transcripts.c.user_id,
transcripts.c.room_id,
transcripts.c.source_kind,
transcripts.c.webvtt,
sqlalchemy.func.ts_rank(
transcripts.c.search_vector_en,
search_query,
32, # normalization flag: rank/(rank+1) for 0-1 range
).label("rank"),
]
).where(transcripts.c.search_vector_en.op("@@")(search_query))
if params.user_id:
base_query = base_query.where(transcripts.c.user_id == params.user_id)
if params.room_id:
base_query = base_query.where(transcripts.c.room_id == params.room_id)
query = (
base_query.order_by(sqlalchemy.desc(sqlalchemy.text("rank")))
.limit(params.limit)
.offset(params.offset)
)
rs = await database.fetch_all(query)
count_query = sqlalchemy.select([sqlalchemy.func.count()]).select_from(
base_query.alias("search_results")
)
total = await database.fetch_val(count_query)
def _process_result(r) -> SearchResult:
r_dict: Dict[str, Any] = dict(r)
webvtt: str | None = r_dict.pop("webvtt", None)
db_result = SearchResultDB.model_validate(r_dict)
snippets = []
if webvtt:
plain_text = cls._extract_webvtt_text(webvtt)
snippets = cls._generate_snippets(plain_text, params.query_text)
return SearchResult(**db_result.model_dump(), search_snippets=snippets)
results = [_process_result(r) for r in rs]
return results, total
search_controller = SearchController()
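
A minimal usage sketch for the controller (assumes a connected database and a PostgreSQL backend; the query string is illustrative):

from reflector.db.search import SearchParameters, search_controller

async def find_meetings() -> None:
    params = SearchParameters(query_text="quarterly planning", limit=10)
    results, total = await search_controller.search_transcripts(params)
    for r in results:
        print(r.id, round(r.rank, 2), r.search_snippets)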


@@ -1,9 +1,10 @@
import enum import enum
import json import json
import logging
import os import os
import shutil import shutil
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import datetime, timezone from datetime import datetime, timedelta, timezone
from pathlib import Path from pathlib import Path
from typing import Any, Literal from typing import Any, Literal
@@ -11,13 +12,19 @@ import sqlalchemy
from fastapi import HTTPException from fastapi import HTTPException
from pydantic import BaseModel, ConfigDict, Field, field_serializer from pydantic import BaseModel, ConfigDict, Field, field_serializer
from sqlalchemy import Enum from sqlalchemy import Enum
from sqlalchemy.dialects.postgresql import TSVECTOR
from sqlalchemy.sql import false, or_ from sqlalchemy.sql import false, or_
from reflector.db import database, metadata from reflector.db import database, metadata
from reflector.db.rooms import rooms
from reflector.db.utils import is_postgresql
from reflector.processors.types import Word as ProcessorWord from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings from reflector.settings import settings
from reflector.storage import get_transcripts_storage from reflector.storage import get_transcripts_storage
from reflector.utils import generate_uuid4 from reflector.utils import generate_uuid4
from reflector.utils.webvtt import topics_to_webvtt
logger = logging.getLogger(__name__)
class SourceKind(enum.StrEnum): class SourceKind(enum.StrEnum):
@@ -76,6 +83,7 @@ transcripts = sqlalchemy.Table(
     # same field could've been in recording/meeting, and it's maybe even ok to dupe it at need
     sqlalchemy.Column("audio_deleted", sqlalchemy.Boolean),
     sqlalchemy.Column("room_id", sqlalchemy.String),
+    sqlalchemy.Column("webvtt", sqlalchemy.Text),
     sqlalchemy.Index("idx_transcript_recording_id", "recording_id"),
     sqlalchemy.Index("idx_transcript_user_id", "user_id"),
     sqlalchemy.Index("idx_transcript_created_at", "created_at"),
@@ -83,6 +91,29 @@ transcripts = sqlalchemy.Table(
     sqlalchemy.Index("idx_transcript_room_id", "room_id"),
 )

+# Add PostgreSQL-specific full-text search column
+# This matches the migration in migrations/versions/116b2f287eab_add_full_text_search.py
+if is_postgresql():
+    transcripts.append_column(
+        sqlalchemy.Column(
+            "search_vector_en",
+            TSVECTOR,
+            sqlalchemy.Computed(
+                "setweight(to_tsvector('english', coalesce(title, '')), 'A') || "
+                "setweight(to_tsvector('english', coalesce(webvtt, '')), 'B')",
+                persisted=True,
+            ),
+        )
+    )
+    # Add GIN index for the search vector
+    transcripts.append_constraint(
+        sqlalchemy.Index(
+            "idx_transcript_search_vector_en",
+            "search_vector_en",
+            postgresql_using="gin",
+        )
+    )

 def generate_transcript_name() -> str:
     now = datetime.now(timezone.utc)
@@ -147,14 +178,18 @@ class TranscriptParticipant(BaseModel):

 class Transcript(BaseModel):
+    """Full transcript model with all fields."""
+
     id: str = Field(default_factory=generate_uuid4)
     user_id: str | None = None
     name: str = Field(default_factory=generate_transcript_name)
     status: str = "idle"
-    locked: bool = False
     duration: float = 0
     created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
     title: str | None = None
+    source_kind: SourceKind
+    room_id: str | None = None
+    locked: bool = False
     short_summary: str | None = None
     long_summary: str | None = None
     topics: list[TranscriptTopic] = []
@@ -168,9 +203,8 @@ class Transcript(BaseModel):
     meeting_id: str | None = None
     recording_id: str | None = None
     zulip_message_id: int | None = None
-    source_kind: SourceKind
     audio_deleted: bool | None = None
-    room_id: str | None = None
+    webvtt: str | None = None

     @field_serializer("created_at", when_used="json")
     def serialize_datetime(self, dt: datetime) -> str:
@@ -271,10 +305,12 @@ class Transcript(BaseModel):
         # we need to create an url to be used for diarization
         # we can't use the audio_mp3_filename because it's not accessible
         # from the diarization processor
-        from datetime import timedelta
-        from reflector.app import app
-        from reflector.views.transcripts import create_access_token
+        # TODO don't import app in db
+        from reflector.app import app  # noqa: PLC0415
+
+        # TODO a util + don't import views in db
+        from reflector.views.transcripts import create_access_token  # noqa: PLC0415

         path = app.url_path_for(
             "transcript_get_audio_mp3",
@@ -335,7 +371,6 @@ class TranscriptController:
         - `room_id`: filter transcripts by room ID
         - `search_term`: filter transcripts by search term
         """
-        from reflector.db.rooms import rooms

         query = transcripts.select().join(
             rooms, transcripts.c.room_id == rooms.c.id, isouter=True
@@ -502,10 +537,17 @@ class TranscriptController:
         await database.execute(query)
         return transcript

-    async def update(self, transcript: Transcript, values: dict, mutate=True):
+    # TODO investigate why mutate= is used. it's used in one place currently,
+    # maybe because of ORM field updates. using mutate=True is discouraged
+    async def update(
+        self, transcript: Transcript, values: dict, mutate=False
+    ) -> Transcript:
         """
-        Update a transcript fields with key/values in values
+        Update a transcript's fields with the key/values in values.
+        Returns a copy of the transcript with updated values.
         """
+        values = TranscriptController._handle_topics_update(values)
         query = (
             transcripts.update()
             .where(transcripts.c.id == transcript.id)
@@ -516,6 +558,28 @@ class TranscriptController:
             for key, value in values.items():
                 setattr(transcript, key, value)

+        updated_transcript = transcript.model_copy(update=values)
+        return updated_transcript
+
+    @staticmethod
+    def _handle_topics_update(values: dict) -> dict:
+        """Auto-update WebVTT when topics are updated."""
+        if values.get("webvtt") is not None:
+            logger.warning("trying to update read-only webvtt column")
+        topics_data = values.get("topics")
+        if topics_data is None:
+            return values
+        return {
+            **values,
+            "webvtt": topics_to_webvtt(
+                [TranscriptTopic(**topic_dict) for topic_dict in topics_data]
+            ),
+        }
     async def remove_by_id(
         self,
         transcript_id: str,
@@ -558,11 +622,7 @@ class TranscriptController:
         Append an event to a transcript
         """
         resp = transcript.add_event(event=event, data=data)
-        await self.update(
-            transcript,
-            {"events": transcript.events_dump()},
-            mutate=False,
-        )
+        await self.update(transcript, {"events": transcript.events_dump()})
         return resp

     async def upsert_topic(
@@ -574,11 +634,7 @@ class TranscriptController:
         Upsert topics to a transcript
         """
         transcript.upsert_topic(topic)
-        await self.update(
-            transcript,
-            {"topics": transcript.topics_dump()},
-            mutate=False,
-        )
+        await self.update(transcript, {"topics": transcript.topics_dump()})

     async def move_mp3_to_storage(self, transcript: Transcript):
         """
@@ -603,7 +659,8 @@ class TranscriptController:
         )

         # indicate on the transcript that the audio is now on storage
-        await self.update(transcript, {"audio_location": "storage"})
+        # mutates transcript argument
+        await self.update(transcript, {"audio_location": "storage"}, mutate=True)

         # unlink the local file
         transcript.audio_mp3_filename.unlink(missing_ok=True)
@@ -627,11 +684,7 @@ class TranscriptController:
         Add/update a participant to a transcript
         """
         result = transcript.upsert_participant(participant)
-        await self.update(
-            transcript,
-            {"participants": transcript.participants_dump()},
-            mutate=False,
-        )
+        await self.update(transcript, {"participants": transcript.participants_dump()})
         return result

     async def delete_participant(
@@ -643,11 +696,7 @@ class TranscriptController:
         Delete a participant from a transcript
         """
         transcript.delete_participant(participant_id)
-        await self.update(
-            transcript,
-            {"participants": transcript.participants_dump()},
-            mutate=False,
-        )
+        await self.update(transcript, {"participants": transcript.participants_dump()})

 transcripts_controller = TranscriptController()


@@ -0,0 +1,7 @@
"""Database utility functions."""
from reflector.db import database
def is_postgresql() -> bool:
return database.url.scheme and database.url.scheme.startswith('postgresql')
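
Note that the prefix check also matches driver-qualified schemes (URL values below are illustrative):

# postgresql://...           -> True
# postgresql+asyncpg://...   -> True
# sqlite:///test.sqlite      -> False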


@@ -14,12 +14,15 @@ It is directly linked to our data model.
 import asyncio
 import functools
 from contextlib import asynccontextmanager
+from typing import Generic

+import av
 import boto3
 from celery import chord, current_task, group, shared_task
 from pydantic import BaseModel
 from structlog import BoundLogger as Logger

+from reflector.db import database
 from reflector.db.meetings import meeting_consent_controller, meetings_controller
 from reflector.db.recordings import recordings_controller
 from reflector.db.rooms import rooms_controller
@@ -35,7 +38,7 @@ from reflector.db.transcripts import (
     transcripts_controller,
 )
 from reflector.logger import logger
-from reflector.pipelines.runner import PipelineRunner
+from reflector.pipelines.runner import PipelineMessage, PipelineRunner
 from reflector.processors import (
     AudioChunkerProcessor,
     AudioDiarizationAutoProcessor,
@@ -69,8 +72,6 @@ def asynctask(f):
     @functools.wraps(f)
     def wrapper(*args, **kwargs):
         async def run_with_db():
-            from reflector.db import database
-
             await database.connect()
             try:
                 return await f(*args, **kwargs)
@@ -144,7 +145,7 @@ class StrValue(BaseModel):
     value: str

-class PipelineMainBase(PipelineRunner):
+class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]):
     transcript_id: str
     ws_room_id: str | None = None
     ws_manager: WebsocketManager | None = None
@@ -164,7 +165,11 @@ class PipelineMainBase(PipelineRunner):
             raise Exception("Transcript not found")
         return result

-    def get_transcript_topics(self, transcript: Transcript) -> list[TranscriptTopic]:
+    @staticmethod
+    def wrap_transcript_topics(
+        topics: list[TranscriptTopic],
+    ) -> list[TitleSummaryWithIdProcessorType]:
+        # transformation to a pipe-supported format
         return [
             TitleSummaryWithIdProcessorType(
                 id=topic.id,
@@ -174,7 +179,7 @@ class PipelineMainBase(PipelineRunner):
                 duration=topic.duration,
                 transcript=TranscriptProcessorType(words=topic.words),
             )
-            for topic in transcript.topics
+            for topic in topics
         ]
     @asynccontextmanager
@@ -380,7 +385,7 @@ class PipelineMainLive(PipelineMainBase):
         pipeline_post(transcript_id=self.transcript_id)

-class PipelineMainDiarization(PipelineMainBase):
+class PipelineMainDiarization(PipelineMainBase[AudioDiarizationInput]):
     """
     Diarize the audio and update topics
     """
@@ -404,11 +409,10 @@ class PipelineMainDiarization(PipelineMainBase):
             pipeline.logger.info("Audio is local, skipping diarization")
             return

-        topics = self.get_transcript_topics(transcript)
         audio_url = await transcript.get_audio_url()
         audio_diarization_input = AudioDiarizationInput(
             audio_url=audio_url,
-            topics=topics,
+            topics=self.wrap_transcript_topics(transcript.topics),
         )

         # as tempting to use pipeline.push, prefer to use the runner
@@ -421,7 +425,7 @@ class PipelineMainDiarization(PipelineMainBase):
         return pipeline

-class PipelineMainFromTopics(PipelineMainBase):
+class PipelineMainFromTopics(PipelineMainBase[TitleSummaryWithIdProcessorType]):
     """
     Pseudo class for generating a pipeline from topics
     """
@@ -443,7 +447,7 @@ class PipelineMainFromTopics(PipelineMainBase):
         pipeline.logger.info(f"{self.__class__.__name__} pipeline created")

         # push topics
-        topics = self.get_transcript_topics(transcript)
+        topics = PipelineMainBase.wrap_transcript_topics(transcript.topics)
         for topic in topics:
             await self.push(topic)
@@ -524,8 +528,6 @@ async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
     # Convert to mp3
     mp3_filename = transcript.audio_mp3_filename

-    import av
-
     with av.open(wav_filename.as_posix()) as in_container:
         in_stream = in_container.streams.audio[0]
         with av.open(mp3_filename.as_posix(), "w") as out_container:
@@ -604,7 +606,7 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
             meeting.id
         )
     except Exception as e:
-        logger.error(f"Failed to fetch consent: {e}")
+        logger.error(f"Failed to fetch consent: {e}", exc_info=e)
         consent_denied = True

     if not consent_denied:
@@ -627,7 +629,7 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
                     f"Deleted original Whereby recording: {recording.bucket_name}/{recording.object_key}"
                 )
             except Exception as e:
-                logger.error(f"Failed to delete Whereby recording: {e}")
+                logger.error(f"Failed to delete Whereby recording: {e}", exc_info=e)

         # non-transactional, files marked for deletion not actually deleted is possible
         await transcripts_controller.update(transcript, {"audio_deleted": True})
@@ -640,7 +642,7 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
                 f"Deleted processed audio from storage: {transcript.storage_audio_path}"
             )
         except Exception as e:
-            logger.error(f"Failed to delete processed audio: {e}")
+            logger.error(f"Failed to delete processed audio: {e}", exc_info=e)

     # 3. Delete local audio files
     try:
@@ -649,7 +651,7 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
         if hasattr(transcript, "audio_wav_filename") and transcript.audio_wav_filename:
             transcript.audio_wav_filename.unlink(missing_ok=True)
     except Exception as e:
-        logger.error(f"Failed to delete local audio files: {e}")
+        logger.error(f"Failed to delete local audio files: {e}", exc_info=e)

     logger.info("Consent cleanup done")
@@ -794,8 +796,6 @@ def pipeline_post(*, transcript_id: str):
 @get_transcript
 async def pipeline_process(transcript: Transcript, logger: Logger):
-    import av
-
     try:
         if transcript.audio_location == "storage":
             await transcripts_controller.download_mp3_from_storage(transcript)


@@ -16,14 +16,17 @@ During its lifecycle, it will emit the following status:
 """

 import asyncio
+from typing import Generic, TypeVar

 from pydantic import BaseModel, ConfigDict

 from reflector.logger import logger
 from reflector.processors import Pipeline

+PipelineMessage = TypeVar("PipelineMessage")

-class PipelineRunner(BaseModel):
+
+class PipelineRunner(BaseModel, Generic[PipelineMessage]):
     model_config = ConfigDict(arbitrary_types_allowed=True)

     status: str = "idle"
@@ -67,7 +70,7 @@ class PipelineRunner(BaseModel):
         coro = self.run()
         asyncio.run(coro)

-    async def push(self, data):
+    async def push(self, data: PipelineMessage):
         """
         Push data to the pipeline
         """
@@ -92,7 +95,11 @@ class PipelineRunner(BaseModel):
         pass

     async def _add_cmd(
-        self, cmd: str, data, max_retries: int = 3, retry_time_limit: int = 3
+        self,
+        cmd: str,
+        data: PipelineMessage,
+        max_retries: int = 3,
+        retry_time_limit: int = 3,
     ):
         """
         Enqueue a command to be executed in the runner.
@@ -143,6 +150,9 @@ class PipelineRunner(BaseModel):
                 cmd, data = await self._q_cmd.get()
                 func = getattr(self, f"cmd_{cmd.lower()}")
                 if func:
-                    await func(data)
+                    if cmd.upper() == "FLUSH":
+                        await func()
+                    else:
+                        await func(data)
                 else:
                     raise Exception(f"Unknown command {cmd}")
@@ -152,13 +162,13 @@ class PipelineRunner(BaseModel):
             self._ev_done.set()
             raise

-    async def cmd_push(self, data):
+    async def cmd_push(self, data: PipelineMessage):
         if self._is_first_push:
             await self._set_status("push")
             self._is_first_push = False
         await self.pipeline.push(data)

-    async def cmd_flush(self, data):
+    async def cmd_flush(self):
         await self._set_status("flush")
         await self.pipeline.flush()
         await self._set_status("ended")


@@ -2,9 +2,10 @@ import io
 import re
 import tempfile
 from pathlib import Path
+from typing import Annotated

 from profanityfilter import ProfanityFilter
-from pydantic import BaseModel, PrivateAttr
+from pydantic import BaseModel, Field, PrivateAttr

 from reflector.redis_cache import redis_cache
@@ -48,20 +49,70 @@ class AudioFile(BaseModel):
         self._path.unlink()

+# non-negative seconds with float part
+Seconds = Annotated[float, Field(ge=0.0, description="Time in seconds with float part")]
+
+
 class Word(BaseModel):
     text: str
-    start: float
-    end: float
+    start: Seconds
+    end: Seconds
     speaker: int = 0


 class TranscriptSegment(BaseModel):
     text: str
-    start: float
-    end: float
+    start: Seconds
+    end: Seconds
     speaker: int = 0


+def words_to_segments(words: list[Word]) -> list[TranscriptSegment]:
+    # from a list of words, create a list of segments:
+    # join words that are less than 2 seconds apart,
+    # but split if the speaker changes or the punctuation is a . , ; : ? !
+    segments = []
+    current_segment = None
+    MAX_SEGMENT_LENGTH = 120
+    for word in words:
+        if current_segment is None:
+            current_segment = TranscriptSegment(
+                text=word.text,
+                start=word.start,
+                end=word.end,
+                speaker=word.speaker,
+            )
+            continue
+
+        # if the word is attached to another speaker, push the current
+        # segment and start a new one
+        if word.speaker != current_segment.speaker:
+            segments.append(current_segment)
+            current_segment = TranscriptSegment(
+                text=word.text,
+                start=word.start,
+                end=word.end,
+                speaker=word.speaker,
+            )
+            continue
+
+        # if the word is the end of a sentence, and we have enough content,
+        # add the word to the current segment and push it
+        current_segment.text += word.text
+        current_segment.end = word.end
+        have_punc = PUNC_RE.search(word.text)
+        if have_punc and (len(current_segment.text) > MAX_SEGMENT_LENGTH):
+            segments.append(current_segment)
+            current_segment = None
+    if current_segment:
+        segments.append(current_segment)
+    return segments


 class Transcript(BaseModel):
     translation: str | None = None
     words: list[Word] = None
@@ -117,49 +168,7 @@ class Transcript(BaseModel):
         return Transcript(text=self.text, translation=self.translation, words=words)

     def as_segments(self) -> list[TranscriptSegment]:
-        # from a list of word, create a list of segments
-        # join the word that are less than 2 seconds apart
-        # but separate if the speaker changes, or if the punctuation is a . , ; : ? !
-        segments = []
-        current_segment = None
-        MAX_SEGMENT_LENGTH = 120
-        for word in self.words:
-            if current_segment is None:
-                current_segment = TranscriptSegment(
-                    text=word.text,
-                    start=word.start,
-                    end=word.end,
-                    speaker=word.speaker,
-                )
-                continue
-            # If the word is attach to another speaker, push the current segment
-            # and start a new one
-            if word.speaker != current_segment.speaker:
-                segments.append(current_segment)
-                current_segment = TranscriptSegment(
-                    text=word.text,
-                    start=word.start,
-                    end=word.end,
-                    speaker=word.speaker,
-                )
-                continue
-            # if the word is the end of a sentence, and we have enough content,
-            # add the word to the current segment and push it
-            current_segment.text += word.text
-            current_segment.end = word.end
-            have_punc = PUNC_RE.search(word.text)
-            if have_punc and (len(current_segment.text) > MAX_SEGMENT_LENGTH):
-                segments.append(current_segment)
-                current_segment = None
-        if current_segment:
-            segments.append(current_segment)
-        return segments
+        return words_to_segments(self.words)


 class TitleSummary(BaseModel):
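
A small sketch of the extracted helper's behavior (values are illustrative): a speaker change forces a segment break, while same-speaker words are merged until end-of-sentence punctuation past the length cap:

from reflector.processors.types import Word, words_to_segments

words = [
    Word(text="Hello", start=0.0, end=0.4, speaker=0),
    Word(text=" there.", start=0.4, end=0.8, speaker=0),
    Word(text=" Hi!", start=1.0, end=1.3, speaker=1),
]
segments = words_to_segments(words)
# -> two segments: "Hello there." (speaker 0) and " Hi!" (speaker 1)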


@@ -0,0 +1,63 @@
"""WebVTT utilities for generating subtitle files from transcript data."""
from typing import TYPE_CHECKING, Annotated
import webvtt
from reflector.processors.types import Seconds, Word, words_to_segments
if TYPE_CHECKING:
from reflector.db.transcripts import TranscriptTopic
VttTimestamp = Annotated[str, "vtt_timestamp"]
WebVTTStr = Annotated[str, "webvtt_str"]
def _seconds_to_timestamp(seconds: Seconds) -> VttTimestamp:
# webvtt-py has no seconds-to-timestamp helper, so build it manually
hours = int(seconds // 3600)
minutes = int((seconds % 3600) // 60)
secs = int(seconds % 60)
milliseconds = int((seconds % 1) * 1000)
return f"{hours:02d}:{minutes:02d}:{secs:02d}.{milliseconds:03d}"
def words_to_webvtt(words: list[Word]) -> WebVTTStr:
"""Convert words to WebVTT using existing segmentation logic."""
vtt = webvtt.WebVTT()
if not words:
return vtt.content
segments = words_to_segments(words)
for segment in segments:
text = segment.text.strip()
# webvtt-py doesn't add voice tags, so prepend the speaker tag manually
text = f"<v Speaker{segment.speaker}>{text}"
caption = webvtt.Caption(
start=_seconds_to_timestamp(segment.start),
end=_seconds_to_timestamp(segment.end),
text=text,
)
vtt.captions.append(caption)
return vtt.content
def topics_to_webvtt(topics: list["TranscriptTopic"]) -> WebVTTStr:
if not topics:
return webvtt.WebVTT().content
all_words: list[Word] = []
for topic in topics:
all_words.extend(topic.words)
# assert it's in sequence
for i in range(len(all_words) - 1):
assert (
all_words[i].start <= all_words[i + 1].start
), f"Words are not in sequence: {all_words[i].text} and {all_words[i + 1].text} are not consecutive: {all_words[i].start} > {all_words[i + 1].start}"
return words_to_webvtt(all_words)
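
For illustration (assuming a single segment from speaker 0 spanning 0-1.5 s with text "Hello world."), the generated document would look like:

WEBVTT

00:00:00.000 --> 00:00:01.500
<v Speaker0>Hello world.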


@@ -1,15 +1,29 @@
 from datetime import datetime, timedelta, timezone
 from typing import Annotated, Literal, Optional

-from fastapi import APIRouter, Depends, HTTPException
+from fastapi import APIRouter, Depends, HTTPException, Query
 from fastapi_pagination import Page
 from fastapi_pagination.ext.databases import paginate
 from jose import jwt
 from pydantic import BaseModel, Field, field_serializer

 import reflector.auth as auth
+from reflector.db import database
 from reflector.db.meetings import meetings_controller
 from reflector.db.rooms import rooms_controller
+from reflector.db.search import (
+    DEFAULT_SEARCH_LIMIT,
+    SearchLimit,
+    SearchLimitBase,
+    SearchOffset,
+    SearchOffsetBase,
+    SearchParameters,
+    SearchQuery,
+    SearchQueryBase,
+    SearchResult,
+    SearchTotal,
+    search_controller,
+)
 from reflector.db.transcripts import (
     SourceKind,
     TranscriptParticipant,
@@ -100,6 +114,21 @@ class DeletionStatus(BaseModel):
     status: str

+SearchQueryParam = Annotated[SearchQueryBase, Query(description="Search query text")]
+SearchLimitParam = Annotated[SearchLimitBase, Query(description="Results per page")]
+SearchOffsetParam = Annotated[
+    SearchOffsetBase, Query(description="Number of results to skip")
+]
+
+
+class SearchResponse(BaseModel):
+    results: list[SearchResult]
+    total: SearchTotal
+    query: SearchQuery
+    limit: SearchLimit
+    offset: SearchOffset
@router.get("/transcripts", response_model=Page[GetTranscriptMinimal]) @router.get("/transcripts", response_model=Page[GetTranscriptMinimal])
async def transcripts_list( async def transcripts_list(
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
@@ -107,8 +136,6 @@ async def transcripts_list(
room_id: str | None = None, room_id: str | None = None,
search_term: str | None = None, search_term: str | None = None,
): ):
from reflector.db import database
if not user and not settings.PUBLIC_MODE: if not user and not settings.PUBLIC_MODE:
raise HTTPException(status_code=401, detail="Not authenticated") raise HTTPException(status_code=401, detail="Not authenticated")
@@ -127,6 +154,39 @@ async def transcripts_list(
     )

+@router.get("/transcripts/search", response_model=SearchResponse)
+async def transcripts_search(
+    q: SearchQueryParam,
+    limit: SearchLimitParam = DEFAULT_SEARCH_LIMIT,
+    offset: SearchOffsetParam = 0,
+    room_id: Optional[str] = None,
+    user: Annotated[
+        Optional[auth.UserInfo], Depends(auth.current_user_optional)
+    ] = None,
+):
+    """
+    Full-text search across transcript titles and content.
+    """
+    if not user and not settings.PUBLIC_MODE:
+        raise HTTPException(status_code=401, detail="Not authenticated")
+
+    user_id = user["sub"] if user else None
+    search_params = SearchParameters(
+        query_text=q, limit=limit, offset=offset, user_id=user_id, room_id=room_id
+    )
+    results, total = await search_controller.search_transcripts(search_params)
+
+    return SearchResponse(
+        results=results,
+        total=total,
+        query=search_params.query_text,
+        limit=search_params.limit,
+        offset=search_params.offset,
+    )
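
A request sketch for the new endpoint (base URL and auth header are assumptions; in PUBLIC_MODE the token can be omitted):

import httpx

resp = httpx.get(
    "http://localhost:8000/transcripts/search",  # base URL is an assumption
    params={"q": "quarterly planning", "limit": 10, "offset": 0},
    headers={"Authorization": "Bearer <token>"},  # omit in PUBLIC_MODE
)
data = resp.json()  # keys: results, total, query, limit, offset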
@router.post("/transcripts", response_model=GetTranscript) @router.post("/transcripts", response_model=GetTranscript)
async def transcripts_create( async def transcripts_create(
info: CreateTranscript, info: CreateTranscript,
@@ -273,8 +333,8 @@ async def transcript_update(
if not transcript: if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found") raise HTTPException(status_code=404, detail="Transcript not found")
values = info.dict(exclude_unset=True) values = info.dict(exclude_unset=True)
await transcripts_controller.update(transcript, values) updated_transcript = await transcripts_controller.update(transcript, values)
return transcript return updated_transcript
@router.delete("/transcripts/{transcript_id}", response_model=DeletionStatus) @router.delete("/transcripts/{transcript_id}", response_model=DeletionStatus)

server/tests/test_search.py

@@ -0,0 +1,163 @@
"""Tests for full-text search functionality."""
import json
from datetime import datetime
import pytest
from pydantic import ValidationError
from reflector.db import database
from reflector.db.search import SearchParameters, search_controller
from reflector.db.transcripts import transcripts
from reflector.db.utils import is_postgresql
@pytest.mark.asyncio
async def test_search_postgresql_only():
await database.connect()
try:
params = SearchParameters(query_text="any query here")
results, total = await search_controller.search_transcripts(params)
assert results == []
assert total == 0
try:
SearchParameters(query_text="")
assert False, "Should have raised validation error"
except ValidationError:
pass # Expected
# Test that whitespace query raises validation error
try:
SearchParameters(query_text=" ")
assert False, "Should have raised validation error"
except ValidationError:
pass # Expected
finally:
await database.disconnect()
@pytest.mark.asyncio
async def test_search_input_validation():
await database.connect()
try:
try:
SearchParameters(query_text="")
assert False, "Should have raised ValidationError"
except ValidationError:
pass # Expected
# Test that whitespace query raises validation error
try:
SearchParameters(query_text=" \t\n ")
assert False, "Should have raised ValidationError"
except ValidationError:
pass # Expected
finally:
await database.disconnect()
@pytest.mark.asyncio
async def test_postgresql_search_with_data():
"""Test full-text search with actual data in PostgreSQL.
Example how to run: DATABASE_URL=postgresql://reflector:reflector@localhost:5432/reflector_test uv run pytest tests/test_search.py::test_postgresql_search_with_data -v -p no:env
"""
# Skip if not PostgreSQL
if not is_postgresql():
pytest.skip("Test requires PostgreSQL. Set DATABASE_URL=postgresql://...")
await database.connect()
# collision is improbable
test_id = "test-search-e2e-7f3a9b2c"
try:
await database.execute(transcripts.delete().where(transcripts.c.id == test_id))
test_data = {
"id": test_id,
"name": "Test Search Transcript",
"title": "Engineering Planning Meeting Q4 2024",
"status": "completed",
"locked": False,
"duration": 1800.0,
"created_at": datetime.now(),
"short_summary": "Team discussed search implementation",
"long_summary": "The engineering team met to plan the search feature",
"topics": json.dumps([]),
"events": json.dumps([]),
"participants": json.dumps([]),
"source_language": "en",
"target_language": "en",
"reviewed": False,
"audio_location": "local",
"share_mode": "private",
"source_kind": "room",
"webvtt": """WEBVTT
00:00:00.000 --> 00:00:10.000
Welcome to our engineering planning meeting for Q4 2024.
00:00:10.000 --> 00:00:20.000
Today we'll discuss the implementation of full-text search.
00:00:20.000 --> 00:00:30.000
The search feature should support complex queries with ranking.
00:00:30.000 --> 00:00:40.000
We need to implement PostgreSQL tsvector for better performance.""",
}
await database.execute(transcripts.insert().values(**test_data))
# Test 1: Search for a word in title
params = SearchParameters(query_text="planning")
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by title word"
# Test 2: Search for a word in webvtt content
params = SearchParameters(query_text="tsvector")
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by webvtt content"
# Test 3: Search with multiple words
params = SearchParameters(query_text="engineering planning")
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by multiple words"
# Test 4: Verify SearchResult structure
test_result = next((r for r in results if r.id == test_id), None)
if test_result:
assert test_result.title == "Engineering Planning Meeting Q4 2024"
assert test_result.status == "completed"
assert test_result.duration == 1800.0
assert test_result.source_kind == "room"
assert 0 <= test_result.rank <= 1, "Rank should be normalized to 0-1"
# Test 5: Search with OR operator
params = SearchParameters(query_text="tsvector OR nosuchword")
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript with OR query"
# Test 6: Quoted phrase search
params = SearchParameters(query_text='"full-text search"')
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by exact phrase"
finally:
await database.execute(transcripts.delete().where(transcripts.c.id == test_id))
await database.disconnect()


@@ -0,0 +1,198 @@
"""Unit tests for search snippet generation."""
from reflector.db.search import SearchController
class TestExtractWebVTT:
"""Test WebVTT text extraction."""
def test_extract_webvtt_with_speakers(self):
"""Test extraction removes speaker tags and timestamps."""
webvtt = """WEBVTT
00:00:00.000 --> 00:00:10.000
<v Speaker0>Hello world, this is a test.
00:00:10.000 --> 00:00:20.000
<v Speaker1>Indeed it is a test of WebVTT parsing.
"""
result = SearchController._extract_webvtt_text(webvtt)
assert "Hello world, this is a test" in result
assert "Indeed it is a test" in result
assert "<v Speaker" not in result
assert "00:00" not in result
assert "-->" not in result
def test_extract_empty_webvtt(self):
"""Test empty WebVTT returns empty string."""
assert SearchController._extract_webvtt_text("") == ""
assert SearchController._extract_webvtt_text(None) == ""
def test_extract_malformed_webvtt(self):
"""Test malformed WebVTT returns empty string."""
result = SearchController._extract_webvtt_text("Not a valid WebVTT")
assert result == ""
class TestGenerateSnippets:
"""Test snippet generation from plain text."""
def test_multiple_matches(self):
"""Test finding multiple occurrences of search term in long text."""
# Create text with Python mentions far apart to get separate snippets
separator = " This is filler text. " * 20 # ~400 chars of padding
text = (
"Python is great for machine learning."
+ separator
+ "Many companies use Python for data science."
+ separator
+ "Python has excellent libraries for analysis."
+ separator
+ "The Python community is very supportive."
)
snippets = SearchController._generate_snippets(text, "Python")
# With enough separation, we should get multiple snippets
assert len(snippets) >= 2 # At least 2 distinct snippets
# Each snippet should contain "Python"
for snippet in snippets:
assert "python" in snippet.lower()
def test_single_match(self):
"""Test single occurrence returns one snippet."""
text = "This document discusses artificial intelligence and its applications."
snippets = SearchController._generate_snippets(text, "artificial intelligence")
assert len(snippets) == 1
assert "artificial intelligence" in snippets[0].lower()
def test_no_matches(self):
"""Test no matches returns empty list."""
text = "This is some random text without the search term."
snippets = SearchController._generate_snippets(text, "machine learning")
assert snippets == []
def test_case_insensitive_search(self):
"""Test search is case insensitive."""
# Add enough text between matches to get separate snippets
text = (
"MACHINE LEARNING is important for modern applications. "
+ "It requires lots of data and computational resources. " * 5 # Padding
+ "Machine Learning rocks and transforms industries. "
+ "Deep learning is a subset of it. " * 5 # More padding
+ "Finally, machine learning will shape our future."
)
snippets = SearchController._generate_snippets(text, "machine learning")
# Should find at least 2 (might be 3 if text is long enough)
assert len(snippets) >= 2
for snippet in snippets:
assert "machine learning" in snippet.lower()
def test_partial_match_fallback(self):
"""Test fallback to first word when exact phrase not found."""
text = "We use machine intelligence for processing."
snippets = SearchController._generate_snippets(text, "machine learning")
# Should fall back to finding "machine"
assert len(snippets) == 1
assert "machine" in snippets[0].lower()
def test_snippet_ellipsis(self):
"""Test ellipsis added for truncated snippets."""
# Long text where match is in the middle
text = "a " * 100 + "TARGET_WORD special content here" + " b" * 100
snippets = SearchController._generate_snippets(text, "TARGET_WORD")
assert len(snippets) == 1
assert "..." in snippets[0] # Should have ellipsis
assert "TARGET_WORD" in snippets[0]
def test_overlapping_snippets_deduplicated(self):
"""Test overlapping matches don't create duplicate snippets."""
text = "test test test word" * 10 # Repeated pattern
snippets = SearchController._generate_snippets(text, "test")
# Should get unique snippets, not duplicates
assert len(snippets) <= 3
assert len(snippets) == len(set(snippets)) # All unique
def test_empty_inputs(self):
"""Test empty text or search term returns empty list."""
assert SearchController._generate_snippets("", "search") == []
assert SearchController._generate_snippets("text", "") == []
assert SearchController._generate_snippets("", "") == []
def test_max_snippets_limit(self):
"""Test respects max_snippets parameter."""
# Create text with well-separated occurrences
separator = " filler " * 50 # Ensure snippets don't overlap
text = ("Python is amazing" + separator) * 10 # 10 occurrences
# Test with different limits
snippets_1 = SearchController._generate_snippets(text, "Python", max_snippets=1)
assert len(snippets_1) == 1
snippets_2 = SearchController._generate_snippets(text, "Python", max_snippets=2)
assert len(snippets_2) == 2
snippets_5 = SearchController._generate_snippets(text, "Python", max_snippets=5)
assert len(snippets_5) == 5 # Should get exactly 5 with enough separation
def test_snippet_length(self):
"""Test snippet length is reasonable."""
text = "word " * 200 # Long text
snippets = SearchController._generate_snippets(text, "word")
for snippet in snippets:
# Default max_length is 150 + some context
assert len(snippet) <= 200 # Some buffer for ellipsis
class TestFullPipeline:
"""Test the complete WebVTT to snippets pipeline."""
def test_webvtt_to_snippets_integration(self):
"""Test full pipeline from WebVTT to search snippets."""
# Create WebVTT with well-separated content for multiple snippets
webvtt = (
"""WEBVTT
00:00:00.000 --> 00:00:10.000
<v Speaker0>Let's discuss machine learning applications in modern technology.
00:00:10.000 --> 00:00:20.000
<v Speaker1>"""
+ "Various industries are adopting new technologies. " * 10
+ """
00:00:20.000 --> 00:00:30.000
<v Speaker2>Machine learning is revolutionizing healthcare and diagnostics.
00:00:30.000 --> 00:00:40.000
<v Speaker3>"""
+ "Financial markets show interesting patterns. " * 10
+ """
00:00:40.000 --> 00:00:50.000
<v Speaker0>Machine learning in education provides personalized experiences.
"""
)
# Extract and generate snippets
plain_text = SearchController._extract_webvtt_text(webvtt)
snippets = SearchController._generate_snippets(plain_text, "machine learning")
# Should find at least 2 snippets (text might still be close together)
assert len(snippets) >= 1 # At minimum one snippet containing matches
assert len(snippets) <= 3 # At most 3 by default
# No WebVTT artifacts in snippets
for snippet in snippets:
assert "machine learning" in snippet.lower()
assert "<v Speaker" not in snippet
assert "00:00" not in snippet
assert "-->" not in snippet


@@ -1,4 +1,5 @@
import asyncio import asyncio
import time
import pytest import pytest
from httpx import AsyncClient from httpx import AsyncClient
@@ -39,14 +40,18 @@ async def test_transcript_process(
    assert response.status_code == 200
    assert response.json()["status"] == "ok"

-    # wait for processing to finish
-    while True:
+    # wait for processing to finish (max 10 minutes)
+    timeout_seconds = 600  # 10 minutes
+    start_time = time.monotonic()
+    while (time.monotonic() - start_time) < timeout_seconds:
        # fetch the transcript and check if it is ended
        resp = await ac.get(f"/transcripts/{tid}")
        assert resp.status_code == 200
        if resp.json()["status"] in ("ended", "error"):
            break
        await asyncio.sleep(1)
+    else:
+        pytest.fail(f"Initial processing timed out after {timeout_seconds} seconds")

    # restart the processing
    response = await ac.post(
@@ -55,14 +60,18 @@ async def test_transcript_process(
    assert response.status_code == 200
    assert response.json()["status"] == "ok"

-    # wait for processing to finish
-    while True:
+    # wait for processing to finish (max 10 minutes)
+    timeout_seconds = 600  # 10 minutes
+    start_time = time.monotonic()
+    while (time.monotonic() - start_time) < timeout_seconds:
        # fetch the transcript and check if it is ended
        resp = await ac.get(f"/transcripts/{tid}")
        assert resp.status_code == 200
        if resp.json()["status"] in ("ended", "error"):
            break
        await asyncio.sleep(1)
+    else:
+        pytest.fail(f"Restart processing timed out after {timeout_seconds} seconds")

    # check the transcript is ended
    transcript = resp.json()
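
A side note on the pattern used throughout these timeout fixes: Python's while ... else runs the else block only when the loop condition becomes false without hitting break, which is exactly the timed-out case. A minimal self-contained illustration (the 5-second budget and the readiness check are made up for the example):

import time

def work_is_done() -> bool:
    return time.monotonic() >= start + 0.5  # stand-in readiness check

start = time.monotonic()
deadline_seconds = 5  # illustrative budget, not the suite's 600s
while (time.monotonic() - start) < deadline_seconds:
    if work_is_done():
        break
    time.sleep(0.1)
else:
    # reached only when the condition went false without a break
    raise TimeoutError(f"not done after {deadline_seconds} seconds")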

View File

@@ -6,6 +6,7 @@
import asyncio
import json
import threading
+import time
from pathlib import Path

import pytest
@@ -21,14 +22,31 @@ class ThreadedUvicorn:
    async def start(self):
        self.thread.start()
-        while not self.server.started:
+        timeout_seconds = 600  # 10 minutes
+        start_time = time.monotonic()
+        while (
+            not self.server.started
+            and (time.monotonic() - start_time) < timeout_seconds
+        ):
            await asyncio.sleep(0.1)
+        if not self.server.started:
+            raise TimeoutError(
+                f"Server failed to start after {timeout_seconds} seconds"
+            )

    def stop(self):
        if self.thread.is_alive():
            self.server.should_exit = True
-            while self.thread.is_alive():
-                continue
+            timeout_seconds = 600  # 10 minutes
+            start_time = time.time()
+            while (
+                self.thread.is_alive() and (time.time() - start_time) < timeout_seconds
+            ):
+                time.sleep(0.1)
+            if self.thread.is_alive():
+                raise TimeoutError(
+                    f"Thread failed to stop after {timeout_seconds} seconds"
+                )


@pytest.fixture
@@ -92,12 +110,16 @@ async def test_transcript_rtc_and_websocket(
    async with aconnect_ws(f"{base_url}/transcripts/{tid}/events") as ws:
        print("Test websocket: CONNECTED")
        try:
-            while True:
+            timeout_seconds = 600  # 10 minutes
+            start_time = time.monotonic()
+            while (time.monotonic() - start_time) < timeout_seconds:
                msg = await ws.receive_json()
                print(f"Test websocket: JSON {msg}")
                if msg is None:
                    break
                events.append(msg)
+            else:
+                print(f"Test websocket: TIMEOUT after {timeout_seconds} seconds")
        except Exception as e:
            print(f"Test websocket: EXCEPTION {e}")
        finally:
@@ -145,9 +167,12 @@ async def test_transcript_rtc_and_websocket(
if resp.json()["status"] in ("ended", "error"): if resp.json()["status"] in ("ended", "error"):
break break
await asyncio.sleep(1) await asyncio.sleep(1)
timeout -= 1
if timeout < 0:
raise TimeoutError("Timeout while waiting for transcript to be ended")
if resp.json()["status"] != "ended": if resp.json()["status"] != "ended":
raise TimeoutError("Timeout while waiting for transcript to be ended") raise TimeoutError("Transcript processing failed")
# stop websocket task # stop websocket task
websocket_task.cancel() websocket_task.cancel()
@@ -253,12 +278,16 @@ async def test_transcript_rtc_and_websocket_and_fr(
    async with aconnect_ws(f"{base_url}/transcripts/{tid}/events") as ws:
        print("Test websocket: CONNECTED")
        try:
-            while True:
+            timeout_seconds = 600  # 10 minutes
+            start_time = time.monotonic()
+            while (time.monotonic() - start_time) < timeout_seconds:
                msg = await ws.receive_json()
                print(f"Test websocket: JSON {msg}")
                if msg is None:
                    break
                events.append(msg)
+            else:
+                print(f"Test websocket: TIMEOUT after {timeout_seconds} seconds")
        except Exception as e:
            print(f"Test websocket: EXCEPTION {e}")
        finally:
@@ -310,9 +339,12 @@ async def test_transcript_rtc_and_websocket_and_fr(
if resp.json()["status"] == "ended": if resp.json()["status"] == "ended":
break break
await asyncio.sleep(1) await asyncio.sleep(1)
timeout -= 1
if timeout < 0:
raise TimeoutError("Timeout while waiting for transcript to be ended")
if resp.json()["status"] != "ended": if resp.json()["status"] != "ended":
raise TimeoutError("Timeout while waiting for transcript to be ended") raise TimeoutError("Transcript processing failed")
await asyncio.sleep(2) await asyncio.sleep(2)

View File

@@ -1,4 +1,5 @@
import asyncio
+import time

import pytest
from httpx import AsyncClient
@@ -39,14 +40,18 @@ async def test_transcript_upload_file(
    assert response.status_code == 200
    assert response.json()["status"] == "ok"

-    # wait the processing to finish
-    while True:
+    # wait the processing to finish (max 10 minutes)
+    timeout_seconds = 600  # 10 minutes
+    start_time = time.monotonic()
+    while (time.monotonic() - start_time) < timeout_seconds:
        # fetch the transcript and check if it is ended
        resp = await ac.get(f"/transcripts/{tid}")
        assert resp.status_code == 200
        if resp.json()["status"] in ("ended", "error"):
            break
        await asyncio.sleep(1)
+    else:
+        pytest.fail(f"Processing timed out after {timeout_seconds} seconds")

    # check the transcript is ended
    transcript = resp.json()

151
server/tests/test_webvtt.py Normal file
View File

@@ -0,0 +1,151 @@
"""Tests for WebVTT utilities."""
import pytest
from reflector.processors.types import Transcript, Word, words_to_segments
from reflector.utils.webvtt import topics_to_webvtt, words_to_webvtt
class TestWordsToWebVTT:
"""Test words_to_webvtt function with TDD approach."""
def test_empty_words_returns_empty_webvtt(self):
"""Should return empty WebVTT structure for empty words list."""
result = words_to_webvtt([])
assert "WEBVTT" in result
assert result.strip() == "WEBVTT"
def test_single_word_creates_single_caption(self):
"""Should create one caption for a single word."""
words = [Word(text="Hello", start=0.0, end=1.0, speaker=0)]
result = words_to_webvtt(words)
assert "WEBVTT" in result
assert "00:00:00.000 --> 00:00:01.000" in result
assert "Hello" in result
assert "<v Speaker0>" in result
def test_multiple_words_same_speaker_groups_properly(self):
"""Should group consecutive words from same speaker."""
words = [
Word(text="Hello", start=0.0, end=0.5, speaker=0),
Word(text=" world", start=0.5, end=1.0, speaker=0),
]
result = words_to_webvtt(words)
assert "WEBVTT" in result
assert "Hello world" in result
assert "<v Speaker0>" in result
def test_speaker_change_creates_new_caption(self):
"""Should create new caption when speaker changes."""
words = [
Word(text="Hello", start=0.0, end=0.5, speaker=0),
Word(text="Hi", start=0.6, end=1.0, speaker=1),
]
result = words_to_webvtt(words)
lines = result.split("\n")
assert "<v Speaker0>" in result
assert "<v Speaker1>" in result
assert "Hello" in result
assert "Hi" in result
def test_punctuation_creates_segment_boundary(self):
"""Should respect punctuation boundaries from segmentation logic."""
words = [
Word(text="Hello.", start=0.0, end=0.5, speaker=0),
Word(text=" How", start=0.6, end=1.0, speaker=0),
Word(text=" are", start=1.0, end=1.3, speaker=0),
Word(text=" you?", start=1.3, end=1.8, speaker=0),
]
result = words_to_webvtt(words)
assert "WEBVTT" in result
assert "<v Speaker0>" in result
class TestTopicsToWebVTT:
"""Test topics_to_webvtt function."""
def test_empty_topics_returns_empty_webvtt(self):
"""Should handle empty topics list."""
result = topics_to_webvtt([])
assert "WEBVTT" in result
assert result.strip() == "WEBVTT"
def test_extracts_words_from_topics(self):
"""Should extract all words from topics in sequence."""
class MockTopic:
def __init__(self, words):
self.words = words
topics = [
MockTopic(
[
Word(text="First", start=0.0, end=0.5, speaker=1),
Word(text="Second", start=1.0, end=1.5, speaker=0),
]
)
]
result = topics_to_webvtt(topics)
assert "WEBVTT" in result
first_pos = result.find("First")
second_pos = result.find("Second")
assert first_pos < second_pos
def test_non_sequential_topics_raises_assertion(self):
"""Should raise assertion error when words are not in chronological sequence."""
class MockTopic:
def __init__(self, words):
self.words = words
topics = [
MockTopic(
[
Word(text="Second", start=1.0, end=1.5, speaker=0),
Word(text="First", start=0.0, end=0.5, speaker=1),
]
)
]
with pytest.raises(AssertionError) as exc_info:
topics_to_webvtt(topics)
assert "Words are not in sequence" in str(exc_info.value)
assert "Second and First" in str(exc_info.value)
class TestTranscriptWordsToSegments:
"""Test static words_to_segments method (TDD for making it static)."""
def test_static_method_exists(self):
"""Should have static words_to_segments method."""
words = [Word(text="Hello", start=0.0, end=1.0, speaker=0)]
segments = words_to_segments(words)
assert isinstance(segments, list)
assert len(segments) == 1
assert segments[0].text == "Hello"
assert segments[0].speaker == 0
def test_backward_compatibility(self):
"""Should maintain backward compatibility with instance method."""
words = [Word(text="Hello", start=0.0, end=1.0, speaker=0)]
transcript = Transcript(words=words)
segments = transcript.as_segments()
assert isinstance(segments, list)
assert len(segments) == 1
assert segments[0].text == "Hello"

View File

@@ -0,0 +1,34 @@
"""Test WebVTT auto-update functionality and edge cases."""
import pytest
from reflector.db.transcripts import (
TranscriptController,
)
@pytest.mark.asyncio
class TestWebVTTAutoUpdateImplementation:
async def test_handle_topics_update_handles_dict_conversion(self):
"""
Verify that _handle_topics_update() properly converts dict data to TranscriptTopic objects.
"""
values = {
"topics": [
{
"id": "topic1",
"title": "Test",
"summary": "Test",
"timestamp": 0.0,
"words": [
{"text": "Hello", "start": 0.0, "end": 1.0, "speaker": 0}
],
}
]
}
updated_values = TranscriptController._handle_topics_update(values)
assert "webvtt" in updated_values
assert updated_values["webvtt"] is not None
assert "WEBVTT" in updated_values["webvtt"]

View File

@@ -0,0 +1,234 @@
"""Integration tests for WebVTT auto-update functionality in Transcript model."""
import pytest
from reflector.db import database
from reflector.db.transcripts import (
SourceKind,
TranscriptController,
TranscriptTopic,
transcripts,
)
from reflector.processors.types import Word
@pytest.mark.asyncio
class TestWebVTTAutoUpdate:
"""Test that WebVTT field auto-updates when Transcript is created or modified."""
async def test_webvtt_not_updated_on_transcript_creation_without_topics(self):
"""WebVTT should be None when creating transcript without topics."""
controller = TranscriptController()
transcript = await controller.add(
name="Test Transcript",
source_kind=SourceKind.FILE,
)
try:
result = await database.fetch_one(
transcripts.select().where(transcripts.c.id == transcript.id)
)
assert result is not None
assert result["webvtt"] is None
finally:
await controller.remove_by_id(transcript.id)
async def test_webvtt_updated_on_upsert_topic(self):
"""WebVTT should update when upserting topics via upsert_topic method."""
controller = TranscriptController()
transcript = await controller.add(
name="Test Transcript",
source_kind=SourceKind.FILE,
)
try:
topic = TranscriptTopic(
id="topic1",
title="Test Topic",
summary="Test summary",
timestamp=0.0,
words=[
Word(text="Hello", start=0.0, end=0.5, speaker=0),
Word(text=" world", start=0.5, end=1.0, speaker=0),
],
)
await controller.upsert_topic(transcript, topic)
result = await database.fetch_one(
transcripts.select().where(transcripts.c.id == transcript.id)
)
assert result is not None
webvtt = result["webvtt"]
assert webvtt is not None
assert "WEBVTT" in webvtt
assert "Hello world" in webvtt
assert "<v Speaker0>" in webvtt
finally:
await controller.remove_by_id(transcript.id)
async def test_webvtt_updated_on_direct_topics_update(self):
"""WebVTT should update when updating topics field directly."""
controller = TranscriptController()
transcript = await controller.add(
name="Test Transcript",
source_kind=SourceKind.FILE,
)
try:
topics_data = [
{
"id": "topic1",
"title": "First Topic",
"summary": "First sentence test",
"timestamp": 0.0,
"words": [
{"text": "First", "start": 0.0, "end": 0.5, "speaker": 0},
{"text": " sentence", "start": 0.5, "end": 1.0, "speaker": 0},
],
}
]
await controller.update(transcript, {"topics": topics_data})
# Fetch from DB
result = await database.fetch_one(
transcripts.select().where(transcripts.c.id == transcript.id)
)
assert result is not None
webvtt = result["webvtt"]
assert webvtt is not None
assert "WEBVTT" in webvtt
assert "First sentence" in webvtt
finally:
await controller.remove_by_id(transcript.id)
async def test_webvtt_updated_manually_with_handle_topics_update(self):
"""Test that _handle_topics_update works when called manually."""
controller = TranscriptController()
transcript = await controller.add(
name="Test Transcript",
source_kind=SourceKind.FILE,
)
try:
topic1 = TranscriptTopic(
id="topic1",
title="Topic 1",
summary="Manual test",
timestamp=0.0,
words=[
Word(text="Manual", start=0.0, end=0.5, speaker=0),
Word(text=" test", start=0.5, end=1.0, speaker=0),
],
)
transcript.upsert_topic(topic1)
values = {"topics": transcript.topics_dump()}
await controller.update(transcript, values)
# Fetch from DB
result = await database.fetch_one(
transcripts.select().where(transcripts.c.id == transcript.id)
)
assert result is not None
webvtt = result["webvtt"]
assert webvtt is not None
assert "WEBVTT" in webvtt
assert "Manual test" in webvtt
assert "<v Speaker0>" in webvtt
finally:
await controller.remove_by_id(transcript.id)
async def test_webvtt_update_with_non_sequential_topics_fails(self):
"""Test that non-sequential topics raise assertion error."""
controller = TranscriptController()
transcript = await controller.add(
name="Test Transcript",
source_kind=SourceKind.FILE,
)
try:
topic1 = TranscriptTopic(
id="topic1",
title="Bad Topic",
summary="Bad order test",
timestamp=1.0,
words=[
Word(text="Second", start=2.0, end=2.5, speaker=0),
Word(text="First", start=1.0, end=1.5, speaker=0),
],
)
transcript.upsert_topic(topic1)
values = {"topics": transcript.topics_dump()}
with pytest.raises(AssertionError) as exc_info:
TranscriptController._handle_topics_update(values)
assert "Words are not in sequence" in str(exc_info.value)
finally:
await controller.remove_by_id(transcript.id)
async def test_multiple_speakers_in_webvtt(self):
"""Test WebVTT generation with multiple speakers."""
controller = TranscriptController()
transcript = await controller.add(
name="Test Transcript",
source_kind=SourceKind.FILE,
)
try:
topic = TranscriptTopic(
id="topic1",
title="Multi Speaker",
summary="Multi speaker test",
timestamp=0.0,
words=[
Word(text="Hello", start=0.0, end=0.5, speaker=0),
Word(text="Hi", start=1.0, end=1.5, speaker=1),
Word(text="Goodbye", start=2.0, end=2.5, speaker=0),
],
)
transcript.upsert_topic(topic)
values = {"topics": transcript.topics_dump()}
await controller.update(transcript, values)
# Fetch from DB
result = await database.fetch_one(
transcripts.select().where(transcripts.c.id == transcript.id)
)
assert result is not None
webvtt = result["webvtt"]
assert webvtt is not None
assert "<v Speaker0>" in webvtt
assert "<v Speaker1>" in webvtt
assert "Hello" in webvtt
assert "Hi" in webvtt
assert "Goodbye" in webvtt
finally:
await controller.remove_by_id(transcript.id)

2717
server/uv.lock generated

File diff suppressed because it is too large

View File

@@ -209,7 +209,6 @@ export const $GetTranscript = {
    },
    created_at: {
      type: "string",
-      format: "date-time",
      title: "Created At",
    },
    share_mode: {
@@ -395,7 +394,6 @@ export const $GetTranscriptMinimal = {
    },
    created_at: {
      type: "string",
-      format: "date-time",
      title: "Created At",
    },
    share_mode: {
@@ -758,8 +756,15 @@ export const $Page_GetTranscriptMinimal_ = {
title: "Items", title: "Items",
}, },
total: { total: {
anyOf: [
{
type: "integer", type: "integer",
minimum: 0, minimum: 0,
},
{
type: "null",
},
],
title: "Total", title: "Total",
}, },
page: { page: {
@@ -800,7 +805,7 @@ export const $Page_GetTranscriptMinimal_ = {
    },
  },
  type: "object",
-  required: ["items", "total", "page", "size"],
+  required: ["items", "page", "size"],
  title: "Page[GetTranscriptMinimal]",
} as const;
@@ -814,8 +819,15 @@ export const $Page_Room_ = {
title: "Items", title: "Items",
}, },
total: { total: {
anyOf: [
{
type: "integer", type: "integer",
minimum: 0, minimum: 0,
},
{
type: "null",
},
],
title: "Total", title: "Total",
}, },
page: { page: {
@@ -856,7 +868,7 @@ export const $Page_Room_ = {
    },
  },
  type: "object",
-  required: ["items", "total", "page", "size"],
+  required: ["items", "page", "size"],
  title: "Page[Room]",
} as const;
@@ -973,6 +985,136 @@ export const $RtcOffer = {
title: "RtcOffer", title: "RtcOffer",
} as const; } as const;
export const $SearchResponse = {
properties: {
results: {
items: {
$ref: "#/components/schemas/SearchResult",
},
type: "array",
title: "Results",
},
total: {
type: "integer",
minimum: 0,
title: "Total",
description: "Total number of search results",
},
query: {
type: "string",
minLength: 1,
title: "Query",
description: "Search query text",
},
limit: {
type: "integer",
maximum: 100,
minimum: 1,
title: "Limit",
description: "Results per page",
},
offset: {
type: "integer",
minimum: 0,
title: "Offset",
description: "Number of results to skip",
},
},
type: "object",
required: ["results", "total", "query", "limit", "offset"],
title: "SearchResponse",
} as const;
export const $SearchResult = {
properties: {
id: {
type: "string",
minLength: 1,
title: "Id",
},
title: {
anyOf: [
{
type: "string",
},
{
type: "null",
},
],
title: "Title",
},
user_id: {
anyOf: [
{
type: "string",
},
{
type: "null",
},
],
title: "User Id",
},
room_id: {
anyOf: [
{
type: "string",
},
{
type: "null",
},
],
title: "Room Id",
},
created_at: {
type: "string",
title: "Created At",
},
status: {
type: "string",
minLength: 1,
title: "Status",
},
rank: {
type: "number",
maximum: 1,
minimum: 0,
title: "Rank",
},
duration: {
anyOf: [
{
type: "number",
minimum: 0,
},
{
type: "null",
},
],
title: "Duration",
description: "Duration in seconds",
},
search_snippets: {
items: {
type: "string",
},
type: "array",
title: "Search Snippets",
description: "Text snippets around search matches",
},
},
type: "object",
required: [
"id",
"created_at",
"status",
"rank",
"duration",
"search_snippets",
],
title: "SearchResult",
description: "Public search result model with computed fields.",
} as const;
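
Put together, a payload that validates against $SearchResponse and $SearchResult would look roughly like this (all values are illustrative):

search_response = {
    "results": [
        {
            "id": "tr_123",  # illustrative id
            "title": "Weekly sync",
            "user_id": None,
            "room_id": None,
            "created_at": "2025-08-07T11:27:38Z",
            "status": "ended",
            "rank": 0.42,  # bounded to [0, 1]
            "duration": 1800.0,  # seconds, may be None
            "search_snippets": ["...machine learning in education..."],
        }
    ],
    "total": 1,
    "query": "machine learning",
    "limit": 20,
    "offset": 0,
}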
export const $SourceKind = {
  type: "string",
  enum: ["room", "live", "file"],
@@ -1397,6 +1539,7 @@ export const $WherebyWebhookEvent = {
title: "Type", title: "Type",
}, },
data: { data: {
additionalProperties: true,
type: "object", type: "object",
title: "Data", title: "Data",
}, },
@@ -1414,11 +1557,15 @@ export const $Word = {
    },
    start: {
      type: "number",
+      minimum: 0,
      title: "Start",
+      description: "Time in seconds with float part",
    },
    end: {
      type: "number",
+      minimum: 0,
      title: "End",
+      description: "Time in seconds with float part",
    },
    speaker: {
      type: "integer",

View File

@@ -20,6 +20,8 @@ import type {
  V1TranscriptsListResponse,
  V1TranscriptsCreateData,
  V1TranscriptsCreateResponse,
+  V1TranscriptsSearchData,
+  V1TranscriptsSearchResponse,
  V1TranscriptGetData,
  V1TranscriptGetResponse,
  V1TranscriptUpdateData,
@@ -276,6 +278,35 @@ export class DefaultService {
    });
  }
/**
* Transcripts Search
* Full-text search across transcript titles and content.
* @param data The data for the request.
* @param data.q Search query text
* @param data.limit Results per page
* @param data.offset Number of results to skip
* @param data.roomId
* @returns SearchResponse Successful Response
* @throws ApiError
*/
public v1TranscriptsSearch(
data: V1TranscriptsSearchData,
): CancelablePromise<V1TranscriptsSearchResponse> {
return this.httpRequest.request({
method: "GET",
url: "/v1/transcripts/search",
query: {
q: data.q,
limit: data.limit,
offset: data.offset,
room_id: data.roomId,
},
errors: {
422: "Validation Error",
},
});
}
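
The generated method above boils down to a plain GET against /v1/transcripts/search. A rough equivalent with raw httpx, for readers outside the generated client (base URL and defaults are placeholders):

import httpx

async def search_transcripts(q: str, limit: int = 20, offset: int = 0):
    # Base URL and any auth headers are placeholders; adjust per deployment
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        resp = await client.get(
            "/v1/transcripts/search",
            params={"q": q, "limit": limit, "offset": offset},
        )
        resp.raise_for_status()
        data = resp.json()  # SearchResponse: results/total/query/limit/offset
        return data["results"]  # each item carries rank and search_snippets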
  /**
   * Transcript Get
   * @param data The data for the request.

View File

@@ -141,7 +141,7 @@ export type MeetingConsentRequest = {
export type Page_GetTranscriptMinimal_ = {
  items: Array<GetTranscriptMinimal>;
-  total: number;
+  total?: number | null;
  page: number | null;
  size: number | null;
  pages?: number | null;
@@ -149,7 +149,7 @@ export type Page_GetTranscriptMinimal_ = {
export type Page_Room_ = {
  items: Array<Room>;
-  total: number;
+  total?: number | null;
  page: number | null;
  size: number | null;
  pages?: number | null;
@@ -181,6 +181,47 @@ export type RtcOffer = {
  type: string;
};
export type SearchResponse = {
results: Array<SearchResult>;
/**
* Total number of search results
*/
total: number;
/**
* Search query text
*/
query: string;
/**
* Results per page
*/
limit: number;
/**
* Number of results to skip
*/
offset: number;
};
/**
* Public search result model with computed fields.
*/
export type SearchResult = {
id: string;
title?: string | null;
user_id?: string | null;
room_id?: string | null;
created_at: string;
status: string;
rank: number;
/**
* Duration in seconds
*/
duration: number | null;
/**
* Text snippets around search matches
*/
search_snippets: Array<string>;
};
export type SourceKind = "room" | "live" | "file";

export type SpeakerAssignment = {
@@ -272,7 +313,13 @@ export type WherebyWebhookEvent = {
export type Word = {
  text: string;
+  /**
+   * Time in seconds with float part
+   */
  start: number;
+  /**
+   * Time in seconds with float part
+   */
  end: number;
  speaker?: number;
};
@@ -346,6 +393,24 @@ export type V1TranscriptsCreateData = {
export type V1TranscriptsCreateResponse = GetTranscript;
export type V1TranscriptsSearchData = {
/**
* Results per page
*/
limit?: number;
/**
* Number of results to skip
*/
offset?: number;
/**
* Search query text
*/
q: string;
roomId?: string | null;
};
export type V1TranscriptsSearchResponse = SearchResponse;
export type V1TranscriptGetData = {
  transcriptId: string;
};
@@ -633,6 +698,21 @@ export type $OpenApiTs = {
      };
    };
  };
"/v1/transcripts/search": {
get: {
req: V1TranscriptsSearchData;
res: {
/**
* Successful Response
*/
200: SearchResponse;
/**
* Validation Error
*/
422: HTTPValidationError;
};
};
};
"/v1/transcripts/{transcript_id}": { "/v1/transcripts/{transcript_id}": {
get: { get: {
req: V1TranscriptGetData; req: V1TranscriptGetData;