mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 12:19:06 +00:00
feat: add transcript format parameter to GET endpoint (#709)
* feat: add transcript format parameter to GET endpoint
Add transcript_format query parameter to /v1/transcripts/{id} endpoint
with support for multiple output formats using discriminated unions.
Formats supported:
- text: Plain speaker dialogue (default)
- text-timestamped: Dialogue with [MM:SS] timestamps
- webvtt-named: WebVTT subtitles with participant names
- json: Structured segments with full metadata
Response models use Pydantic discriminated unions with transcript_format
as discriminator field. POST/PATCH endpoints return GetTranscriptWithParticipants
for minimal responses. GET endpoint returns format-specific models.
* Copy transcript format
* Regenerate types
* Fix transcript formats
* Don't throw inside try
* Remove any type
* Toast share copy errors
* transcript_format exhaustiveness and python idiomatic assert_never
* format_timestamp_mmss clear type definition
* Rename seconds_to_timestamp
* Test transcript format with overlapping speakers
* exact match for vtt multispeaker test
---------
Co-authored-by: Sergey Mankovsky <sergey@monadical.com>
Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
This commit is contained in:
17
server/reflector/schemas/transcript_formats.py
Normal file
17
server/reflector/schemas/transcript_formats.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Schema definitions for transcript format types and segments."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
TranscriptFormat = Literal["text", "text-timestamped", "webvtt-named", "json"]
|
||||
|
||||
|
||||
class TranscriptSegment(BaseModel):
|
||||
"""A single transcript segment with speaker and timing information."""
|
||||
|
||||
speaker: int
|
||||
speaker_name: str
|
||||
text: str
|
||||
start: float
|
||||
end: float
|
||||
@@ -7,7 +7,7 @@ This module provides result-based error handling that works in both contexts:
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Union
|
||||
from typing import Literal, Union, assert_never
|
||||
|
||||
import celery
|
||||
from celery.result import AsyncResult
|
||||
@@ -18,7 +18,6 @@ from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
|
||||
from reflector.pipelines.main_multitrack_pipeline import (
|
||||
task_pipeline_multitrack_process,
|
||||
)
|
||||
from reflector.utils.match import absurd
|
||||
from reflector.utils.string import NonEmptyString
|
||||
|
||||
|
||||
@@ -155,7 +154,7 @@ def dispatch_transcript_processing(config: ProcessingConfig) -> AsyncResult:
|
||||
elif isinstance(config, FileProcessingConfig):
|
||||
return task_pipeline_file_process.delay(transcript_id=config.transcript_id)
|
||||
else:
|
||||
absurd(config)
|
||||
assert_never(config)
|
||||
|
||||
|
||||
def task_is_scheduled_or_active(task_name: str, **kwargs):
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
from typing import NoReturn
|
||||
|
||||
|
||||
def assert_exhaustiveness(x: NoReturn) -> NoReturn:
|
||||
"""Provide an assertion at type-check time that this function is never called."""
|
||||
raise AssertionError(f"Invalid value: {x!r}")
|
||||
|
||||
|
||||
def absurd(x: NoReturn) -> NoReturn:
|
||||
return assert_exhaustiveness(x)
|
||||
125
server/reflector/utils/transcript_formats.py
Normal file
125
server/reflector/utils/transcript_formats.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""Utilities for converting transcript data to various output formats."""
|
||||
|
||||
import webvtt
|
||||
|
||||
from reflector.db.transcripts import TranscriptParticipant, TranscriptTopic
|
||||
from reflector.processors.types import (
|
||||
Transcript as ProcessorTranscript,
|
||||
)
|
||||
from reflector.processors.types import (
|
||||
words_to_segments,
|
||||
)
|
||||
from reflector.schemas.transcript_formats import TranscriptSegment
|
||||
from reflector.utils.webvtt import seconds_to_timestamp
|
||||
|
||||
|
||||
def get_speaker_name(
|
||||
speaker: int, participants: list[TranscriptParticipant] | None
|
||||
) -> str:
|
||||
"""Get participant name for speaker or default to 'Speaker N'."""
|
||||
if participants:
|
||||
for participant in participants:
|
||||
if participant.speaker == speaker:
|
||||
return participant.name
|
||||
return f"Speaker {speaker}"
|
||||
|
||||
|
||||
def format_timestamp_mmss(seconds: float | int) -> str:
|
||||
"""Format seconds as MM:SS timestamp."""
|
||||
minutes = int(seconds // 60)
|
||||
secs = int(seconds % 60)
|
||||
return f"{minutes:02d}:{secs:02d}"
|
||||
|
||||
|
||||
def transcript_to_text(
|
||||
topics: list[TranscriptTopic], participants: list[TranscriptParticipant] | None
|
||||
) -> str:
|
||||
"""Convert transcript topics to plain text with speaker names."""
|
||||
lines = []
|
||||
for topic in topics:
|
||||
if not topic.words:
|
||||
continue
|
||||
|
||||
transcript = ProcessorTranscript(words=topic.words)
|
||||
segments = transcript.as_segments()
|
||||
|
||||
for segment in segments:
|
||||
speaker_name = get_speaker_name(segment.speaker, participants)
|
||||
text = segment.text.strip()
|
||||
lines.append(f"{speaker_name}: {text}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def transcript_to_text_timestamped(
|
||||
topics: list[TranscriptTopic], participants: list[TranscriptParticipant] | None
|
||||
) -> str:
|
||||
"""Convert transcript topics to timestamped text with speaker names."""
|
||||
lines = []
|
||||
for topic in topics:
|
||||
if not topic.words:
|
||||
continue
|
||||
|
||||
transcript = ProcessorTranscript(words=topic.words)
|
||||
segments = transcript.as_segments()
|
||||
|
||||
for segment in segments:
|
||||
speaker_name = get_speaker_name(segment.speaker, participants)
|
||||
timestamp = format_timestamp_mmss(segment.start)
|
||||
text = segment.text.strip()
|
||||
lines.append(f"[{timestamp}] {speaker_name}: {text}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def topics_to_webvtt_named(
|
||||
topics: list[TranscriptTopic], participants: list[TranscriptParticipant] | None
|
||||
) -> str:
|
||||
"""Convert transcript topics to WebVTT format with participant names."""
|
||||
vtt = webvtt.WebVTT()
|
||||
|
||||
for topic in topics:
|
||||
if not topic.words:
|
||||
continue
|
||||
|
||||
segments = words_to_segments(topic.words)
|
||||
|
||||
for segment in segments:
|
||||
speaker_name = get_speaker_name(segment.speaker, participants)
|
||||
text = segment.text.strip()
|
||||
text = f"<v {speaker_name}>{text}"
|
||||
|
||||
caption = webvtt.Caption(
|
||||
start=seconds_to_timestamp(segment.start),
|
||||
end=seconds_to_timestamp(segment.end),
|
||||
text=text,
|
||||
)
|
||||
vtt.captions.append(caption)
|
||||
|
||||
return vtt.content
|
||||
|
||||
|
||||
def transcript_to_json_segments(
|
||||
topics: list[TranscriptTopic], participants: list[TranscriptParticipant] | None
|
||||
) -> list[TranscriptSegment]:
|
||||
"""Convert transcript topics to a flat list of JSON segments."""
|
||||
segments = []
|
||||
|
||||
for topic in topics:
|
||||
if not topic.words:
|
||||
continue
|
||||
|
||||
transcript = ProcessorTranscript(words=topic.words)
|
||||
for segment in transcript.as_segments():
|
||||
speaker_name = get_speaker_name(segment.speaker, participants)
|
||||
segments.append(
|
||||
TranscriptSegment(
|
||||
speaker=segment.speaker,
|
||||
speaker_name=speaker_name,
|
||||
text=segment.text.strip(),
|
||||
start=segment.start,
|
||||
end=segment.end,
|
||||
)
|
||||
)
|
||||
|
||||
return segments
|
||||
@@ -13,7 +13,7 @@ VttTimestamp = Annotated[str, "vtt_timestamp"]
|
||||
WebVTTStr = Annotated[str, "webvtt_str"]
|
||||
|
||||
|
||||
def _seconds_to_timestamp(seconds: Seconds) -> VttTimestamp:
|
||||
def seconds_to_timestamp(seconds: Seconds) -> VttTimestamp:
|
||||
# lib doesn't do that
|
||||
hours = int(seconds // 3600)
|
||||
minutes = int((seconds % 3600) // 60)
|
||||
@@ -37,8 +37,8 @@ def words_to_webvtt(words: list[Word]) -> WebVTTStr:
|
||||
text = f"<v Speaker{segment.speaker}>{text}"
|
||||
|
||||
caption = webvtt.Caption(
|
||||
start=_seconds_to_timestamp(segment.start),
|
||||
end=_seconds_to_timestamp(segment.end),
|
||||
start=seconds_to_timestamp(segment.start),
|
||||
end=seconds_to_timestamp(segment.end),
|
||||
text=text,
|
||||
)
|
||||
vtt.captions.append(caption)
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Annotated, Literal, Optional
|
||||
from typing import Annotated, Literal, Optional, assert_never
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi_pagination import Page
|
||||
from fastapi_pagination.ext.databases import apaginate
|
||||
from jose import jwt
|
||||
from pydantic import AwareDatetime, BaseModel, Field, constr, field_serializer
|
||||
from pydantic import (
|
||||
AwareDatetime,
|
||||
BaseModel,
|
||||
Discriminator,
|
||||
Field,
|
||||
constr,
|
||||
field_serializer,
|
||||
)
|
||||
|
||||
import reflector.auth as auth
|
||||
from reflector.db import get_database
|
||||
@@ -31,7 +38,14 @@ from reflector.db.transcripts import (
|
||||
)
|
||||
from reflector.processors.types import Transcript as ProcessorTranscript
|
||||
from reflector.processors.types import Word
|
||||
from reflector.schemas.transcript_formats import TranscriptFormat, TranscriptSegment
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.transcript_formats import (
|
||||
topics_to_webvtt_named,
|
||||
transcript_to_json_segments,
|
||||
transcript_to_text,
|
||||
transcript_to_text_timestamped,
|
||||
)
|
||||
from reflector.ws_manager import get_ws_manager
|
||||
from reflector.zulip import (
|
||||
InvalidMessageError,
|
||||
@@ -88,10 +102,84 @@ class GetTranscriptMinimal(BaseModel):
|
||||
audio_deleted: bool | None = None
|
||||
|
||||
|
||||
class GetTranscript(GetTranscriptMinimal):
|
||||
class GetTranscriptWithParticipants(GetTranscriptMinimal):
|
||||
participants: list[TranscriptParticipant] | None
|
||||
|
||||
|
||||
class GetTranscriptWithText(GetTranscriptWithParticipants):
|
||||
"""
|
||||
Transcript response with plain text format.
|
||||
|
||||
Format: Speaker names followed by their dialogue, one line per segment.
|
||||
Example:
|
||||
John Smith: Hello everyone
|
||||
Jane Doe: Hi there
|
||||
"""
|
||||
|
||||
transcript_format: Literal["text"] = "text"
|
||||
transcript: str
|
||||
|
||||
|
||||
class GetTranscriptWithTextTimestamped(GetTranscriptWithParticipants):
|
||||
"""
|
||||
Transcript response with timestamped text format.
|
||||
|
||||
Format: [MM:SS] timestamp prefix before each speaker and dialogue.
|
||||
Example:
|
||||
[00:00] John Smith: Hello everyone
|
||||
[00:05] Jane Doe: Hi there
|
||||
"""
|
||||
|
||||
transcript_format: Literal["text-timestamped"] = "text-timestamped"
|
||||
transcript: str
|
||||
|
||||
|
||||
class GetTranscriptWithWebVTTNamed(GetTranscriptWithParticipants):
|
||||
"""
|
||||
Transcript response in WebVTT subtitle format with participant names.
|
||||
|
||||
Format: Standard WebVTT with voice tags using participant names.
|
||||
Example:
|
||||
WEBVTT
|
||||
|
||||
00:00:00.000 --> 00:00:05.000
|
||||
<v John Smith>Hello everyone
|
||||
"""
|
||||
|
||||
transcript_format: Literal["webvtt-named"] = "webvtt-named"
|
||||
transcript: str
|
||||
|
||||
|
||||
class GetTranscriptWithJSON(GetTranscriptWithParticipants):
|
||||
"""
|
||||
Transcript response as structured JSON segments.
|
||||
|
||||
Format: Array of segment objects with speaker info, text, and timing.
|
||||
Example:
|
||||
[
|
||||
{
|
||||
"speaker": 0,
|
||||
"speaker_name": "John Smith",
|
||||
"text": "Hello everyone",
|
||||
"start": 0.0,
|
||||
"end": 5.0
|
||||
}
|
||||
]
|
||||
"""
|
||||
|
||||
transcript_format: Literal["json"] = "json"
|
||||
transcript: list[TranscriptSegment]
|
||||
|
||||
|
||||
GetTranscript = Annotated[
|
||||
GetTranscriptWithText
|
||||
| GetTranscriptWithTextTimestamped
|
||||
| GetTranscriptWithWebVTTNamed
|
||||
| GetTranscriptWithJSON,
|
||||
Discriminator("transcript_format"),
|
||||
]
|
||||
|
||||
|
||||
class CreateTranscript(BaseModel):
|
||||
name: str
|
||||
source_language: str = Field("en")
|
||||
@@ -228,7 +316,7 @@ async def transcripts_search(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/transcripts", response_model=GetTranscript)
|
||||
@router.post("/transcripts", response_model=GetTranscriptWithParticipants)
|
||||
async def transcripts_create(
|
||||
info: CreateTranscript,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
@@ -362,14 +450,72 @@ class GetTranscriptTopicWithWordsPerSpeaker(GetTranscriptTopic):
|
||||
async def transcript_get(
|
||||
transcript_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
transcript_format: TranscriptFormat = "text",
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
return await transcripts_controller.get_by_id_for_http(
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
base_data = {
|
||||
"id": transcript.id,
|
||||
"user_id": transcript.user_id,
|
||||
"name": transcript.name,
|
||||
"status": transcript.status,
|
||||
"locked": transcript.locked,
|
||||
"duration": transcript.duration,
|
||||
"title": transcript.title,
|
||||
"short_summary": transcript.short_summary,
|
||||
"long_summary": transcript.long_summary,
|
||||
"created_at": transcript.created_at,
|
||||
"share_mode": transcript.share_mode,
|
||||
"source_language": transcript.source_language,
|
||||
"target_language": transcript.target_language,
|
||||
"reviewed": transcript.reviewed,
|
||||
"meeting_id": transcript.meeting_id,
|
||||
"source_kind": transcript.source_kind,
|
||||
"room_id": transcript.room_id,
|
||||
"audio_deleted": transcript.audio_deleted,
|
||||
"participants": transcript.participants,
|
||||
}
|
||||
|
||||
@router.patch("/transcripts/{transcript_id}", response_model=GetTranscript)
|
||||
if transcript_format == "text":
|
||||
return GetTranscriptWithText(
|
||||
**base_data,
|
||||
transcript_format="text",
|
||||
transcript=transcript_to_text(transcript.topics, transcript.participants),
|
||||
)
|
||||
elif transcript_format == "text-timestamped":
|
||||
return GetTranscriptWithTextTimestamped(
|
||||
**base_data,
|
||||
transcript_format="text-timestamped",
|
||||
transcript=transcript_to_text_timestamped(
|
||||
transcript.topics, transcript.participants
|
||||
),
|
||||
)
|
||||
elif transcript_format == "webvtt-named":
|
||||
return GetTranscriptWithWebVTTNamed(
|
||||
**base_data,
|
||||
transcript_format="webvtt-named",
|
||||
transcript=topics_to_webvtt_named(
|
||||
transcript.topics, transcript.participants
|
||||
),
|
||||
)
|
||||
elif transcript_format == "json":
|
||||
return GetTranscriptWithJSON(
|
||||
**base_data,
|
||||
transcript_format="json",
|
||||
transcript=transcript_to_json_segments(
|
||||
transcript.topics, transcript.participants
|
||||
),
|
||||
)
|
||||
else:
|
||||
assert_never(transcript_format)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/transcripts/{transcript_id}", response_model=GetTranscriptWithParticipants
|
||||
)
|
||||
async def transcript_update(
|
||||
transcript_id: str,
|
||||
info: UpdateTranscript,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Annotated, Optional
|
||||
from typing import Annotated, Optional, assert_never
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
@@ -15,7 +15,6 @@ from reflector.services.transcript_process import (
|
||||
prepare_transcript_processing,
|
||||
validate_transcript_for_processing,
|
||||
)
|
||||
from reflector.utils.match import absurd
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -44,7 +43,7 @@ async def transcript_process(
|
||||
elif isinstance(validation, ValidationOk):
|
||||
pass
|
||||
else:
|
||||
absurd(validation)
|
||||
assert_never(validation)
|
||||
|
||||
config = await prepare_transcript_processing(validation)
|
||||
|
||||
|
||||
575
server/tests/test_transcript_formats.py
Normal file
575
server/tests/test_transcript_formats.py
Normal file
@@ -0,0 +1,575 @@
|
||||
"""Tests for transcript format conversion functionality."""
|
||||
|
||||
import pytest
|
||||
|
||||
from reflector.db.transcripts import TranscriptParticipant, TranscriptTopic
|
||||
from reflector.processors.types import Word
|
||||
from reflector.utils.transcript_formats import (
|
||||
format_timestamp_mmss,
|
||||
get_speaker_name,
|
||||
topics_to_webvtt_named,
|
||||
transcript_to_json_segments,
|
||||
transcript_to_text,
|
||||
transcript_to_text_timestamped,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_speaker_name_with_participants():
|
||||
"""Test speaker name resolution with participants list."""
|
||||
participants = [
|
||||
TranscriptParticipant(id="1", speaker=0, name="John Smith"),
|
||||
TranscriptParticipant(id="2", speaker=1, name="Jane Doe"),
|
||||
]
|
||||
|
||||
assert get_speaker_name(0, participants) == "John Smith"
|
||||
assert get_speaker_name(1, participants) == "Jane Doe"
|
||||
assert get_speaker_name(2, participants) == "Speaker 2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_speaker_name_without_participants():
|
||||
"""Test speaker name resolution without participants list."""
|
||||
assert get_speaker_name(0, None) == "Speaker 0"
|
||||
assert get_speaker_name(1, None) == "Speaker 1"
|
||||
assert get_speaker_name(5, []) == "Speaker 5"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_format_timestamp_mmss():
|
||||
"""Test timestamp formatting to MM:SS."""
|
||||
assert format_timestamp_mmss(0) == "00:00"
|
||||
assert format_timestamp_mmss(5) == "00:05"
|
||||
assert format_timestamp_mmss(65) == "01:05"
|
||||
assert format_timestamp_mmss(125.7) == "02:05"
|
||||
assert format_timestamp_mmss(3661) == "61:01"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_to_text():
|
||||
"""Test plain text format conversion."""
|
||||
topics = [
|
||||
TranscriptTopic(
|
||||
id="1",
|
||||
title="Topic 1",
|
||||
summary="Summary 1",
|
||||
timestamp=0.0,
|
||||
words=[
|
||||
Word(text="Hello", start=0.0, end=1.0, speaker=0),
|
||||
Word(text=" world.", start=1.0, end=2.0, speaker=0),
|
||||
],
|
||||
),
|
||||
TranscriptTopic(
|
||||
id="2",
|
||||
title="Topic 2",
|
||||
summary="Summary 2",
|
||||
timestamp=2.0,
|
||||
words=[
|
||||
Word(text="How", start=2.0, end=3.0, speaker=1),
|
||||
Word(text=" are", start=3.0, end=4.0, speaker=1),
|
||||
Word(text=" you?", start=4.0, end=5.0, speaker=1),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
participants = [
|
||||
TranscriptParticipant(id="1", speaker=0, name="John Smith"),
|
||||
TranscriptParticipant(id="2", speaker=1, name="Jane Doe"),
|
||||
]
|
||||
|
||||
result = transcript_to_text(topics, participants)
|
||||
lines = result.split("\n")
|
||||
|
||||
assert len(lines) == 2
|
||||
assert lines[0] == "John Smith: Hello world."
|
||||
assert lines[1] == "Jane Doe: How are you?"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_to_text_timestamped():
|
||||
"""Test timestamped text format conversion."""
|
||||
topics = [
|
||||
TranscriptTopic(
|
||||
id="1",
|
||||
title="Topic 1",
|
||||
summary="Summary 1",
|
||||
timestamp=0.0,
|
||||
words=[
|
||||
Word(text="Hello", start=0.0, end=1.0, speaker=0),
|
||||
Word(text=" world.", start=1.0, end=2.0, speaker=0),
|
||||
],
|
||||
),
|
||||
TranscriptTopic(
|
||||
id="2",
|
||||
title="Topic 2",
|
||||
summary="Summary 2",
|
||||
timestamp=65.0,
|
||||
words=[
|
||||
Word(text="How", start=65.0, end=66.0, speaker=1),
|
||||
Word(text=" are", start=66.0, end=67.0, speaker=1),
|
||||
Word(text=" you?", start=67.0, end=68.0, speaker=1),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
participants = [
|
||||
TranscriptParticipant(id="1", speaker=0, name="John Smith"),
|
||||
TranscriptParticipant(id="2", speaker=1, name="Jane Doe"),
|
||||
]
|
||||
|
||||
result = transcript_to_text_timestamped(topics, participants)
|
||||
lines = result.split("\n")
|
||||
|
||||
assert len(lines) == 2
|
||||
assert lines[0] == "[00:00] John Smith: Hello world."
|
||||
assert lines[1] == "[01:05] Jane Doe: How are you?"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_topics_to_webvtt_named():
|
||||
"""Test WebVTT format conversion with participant names."""
|
||||
topics = [
|
||||
TranscriptTopic(
|
||||
id="1",
|
||||
title="Topic 1",
|
||||
summary="Summary 1",
|
||||
timestamp=0.0,
|
||||
words=[
|
||||
Word(text="Hello", start=0.0, end=1.0, speaker=0),
|
||||
Word(text=" world.", start=1.0, end=2.0, speaker=0),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
participants = [
|
||||
TranscriptParticipant(id="1", speaker=0, name="John Smith"),
|
||||
]
|
||||
|
||||
result = topics_to_webvtt_named(topics, participants)
|
||||
|
||||
assert result.startswith("WEBVTT")
|
||||
assert "<v John Smith>" in result
|
||||
assert "00:00:00.000 --> 00:00:02.000" in result
|
||||
assert "Hello world." in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_to_json_segments():
|
||||
"""Test JSON segments format conversion."""
|
||||
topics = [
|
||||
TranscriptTopic(
|
||||
id="1",
|
||||
title="Topic 1",
|
||||
summary="Summary 1",
|
||||
timestamp=0.0,
|
||||
words=[
|
||||
Word(text="Hello", start=0.0, end=1.0, speaker=0),
|
||||
Word(text=" world.", start=1.0, end=2.0, speaker=0),
|
||||
],
|
||||
),
|
||||
TranscriptTopic(
|
||||
id="2",
|
||||
title="Topic 2",
|
||||
summary="Summary 2",
|
||||
timestamp=2.0,
|
||||
words=[
|
||||
Word(text="How", start=2.0, end=3.0, speaker=1),
|
||||
Word(text=" are", start=3.0, end=4.0, speaker=1),
|
||||
Word(text=" you?", start=4.0, end=5.0, speaker=1),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
participants = [
|
||||
TranscriptParticipant(id="1", speaker=0, name="John Smith"),
|
||||
TranscriptParticipant(id="2", speaker=1, name="Jane Doe"),
|
||||
]
|
||||
|
||||
result = transcript_to_json_segments(topics, participants)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0].speaker == 0
|
||||
assert result[0].speaker_name == "John Smith"
|
||||
assert result[0].text == "Hello world."
|
||||
assert result[0].start == 0.0
|
||||
assert result[0].end == 2.0
|
||||
|
||||
assert result[1].speaker == 1
|
||||
assert result[1].speaker_name == "Jane Doe"
|
||||
assert result[1].text == "How are you?"
|
||||
assert result[1].start == 2.0
|
||||
assert result[1].end == 5.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_formats_with_empty_topics():
|
||||
"""Test format conversion with empty topics list."""
|
||||
topics = []
|
||||
participants = []
|
||||
|
||||
assert transcript_to_text(topics, participants) == ""
|
||||
assert transcript_to_text_timestamped(topics, participants) == ""
|
||||
assert "WEBVTT" in topics_to_webvtt_named(topics, participants)
|
||||
assert transcript_to_json_segments(topics, participants) == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_formats_with_empty_words():
|
||||
"""Test format conversion with topics containing no words."""
|
||||
topics = [
|
||||
TranscriptTopic(
|
||||
id="1",
|
||||
title="Topic 1",
|
||||
summary="Summary 1",
|
||||
timestamp=0.0,
|
||||
words=[],
|
||||
),
|
||||
]
|
||||
participants = []
|
||||
|
||||
assert transcript_to_text(topics, participants) == ""
|
||||
assert transcript_to_text_timestamped(topics, participants) == ""
|
||||
assert "WEBVTT" in topics_to_webvtt_named(topics, participants)
|
||||
assert transcript_to_json_segments(topics, participants) == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_formats_with_multiple_speakers():
|
||||
"""Test format conversion with multiple speaker changes."""
|
||||
topics = [
|
||||
TranscriptTopic(
|
||||
id="1",
|
||||
title="Topic 1",
|
||||
summary="Summary 1",
|
||||
timestamp=0.0,
|
||||
words=[
|
||||
Word(text="Hello", start=0.0, end=1.0, speaker=0),
|
||||
Word(text=" there.", start=1.0, end=2.0, speaker=0),
|
||||
Word(text="Hi", start=2.0, end=3.0, speaker=1),
|
||||
Word(text=" back.", start=3.0, end=4.0, speaker=1),
|
||||
Word(text="Good", start=4.0, end=5.0, speaker=0),
|
||||
Word(text=" morning.", start=5.0, end=6.0, speaker=0),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
participants = [
|
||||
TranscriptParticipant(id="1", speaker=0, name="Alice"),
|
||||
TranscriptParticipant(id="2", speaker=1, name="Bob"),
|
||||
]
|
||||
|
||||
text_result = transcript_to_text(topics, participants)
|
||||
lines = text_result.split("\n")
|
||||
assert len(lines) == 3
|
||||
assert "Alice: Hello there." in lines[0]
|
||||
assert "Bob: Hi back." in lines[1]
|
||||
assert "Alice: Good morning." in lines[2]
|
||||
|
||||
json_result = transcript_to_json_segments(topics, participants)
|
||||
assert len(json_result) == 3
|
||||
assert json_result[0].speaker_name == "Alice"
|
||||
assert json_result[1].speaker_name == "Bob"
|
||||
assert json_result[2].speaker_name == "Alice"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_formats_with_overlapping_speakers():
|
||||
"""Test format conversion when multiple speakers speak at the same time (overlapping timestamps)."""
|
||||
topics = [
|
||||
TranscriptTopic(
|
||||
id="1",
|
||||
title="Topic 1",
|
||||
summary="Summary 1",
|
||||
timestamp=0.0,
|
||||
words=[
|
||||
Word(text="Hello", start=0.0, end=0.5, speaker=0),
|
||||
Word(text=" there.", start=0.5, end=1.0, speaker=0),
|
||||
# Speaker 1 overlaps with speaker 0 at 0.5-1.0
|
||||
Word(text="I'm", start=0.5, end=1.0, speaker=1),
|
||||
Word(text=" good.", start=1.0, end=1.5, speaker=1),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
participants = [
|
||||
TranscriptParticipant(id="1", speaker=0, name="Alice"),
|
||||
TranscriptParticipant(id="2", speaker=1, name="Bob"),
|
||||
]
|
||||
|
||||
text_result = transcript_to_text(topics, participants)
|
||||
lines = text_result.split("\n")
|
||||
assert len(lines) >= 2
|
||||
assert any("Alice:" in line for line in lines)
|
||||
assert any("Bob:" in line for line in lines)
|
||||
|
||||
timestamped_result = transcript_to_text_timestamped(topics, participants)
|
||||
timestamped_lines = timestamped_result.split("\n")
|
||||
assert len(timestamped_lines) >= 2
|
||||
assert any("Alice:" in line for line in timestamped_lines)
|
||||
assert any("Bob:" in line for line in timestamped_lines)
|
||||
assert any("[00:00]" in line for line in timestamped_lines)
|
||||
|
||||
webvtt_result = topics_to_webvtt_named(topics, participants)
|
||||
expected_webvtt = """WEBVTT
|
||||
|
||||
00:00:00.000 --> 00:00:01.000
|
||||
<v Alice>Hello there.
|
||||
|
||||
00:00:00.500 --> 00:00:01.500
|
||||
<v Bob>I'm good.
|
||||
"""
|
||||
assert webvtt_result == expected_webvtt
|
||||
|
||||
segments = transcript_to_json_segments(topics, participants)
|
||||
assert len(segments) >= 2
|
||||
speakers = {seg.speaker for seg in segments}
|
||||
assert 0 in speakers and 1 in speakers
|
||||
|
||||
alice_seg = next(seg for seg in segments if seg.speaker == 0)
|
||||
bob_seg = next(seg for seg in segments if seg.speaker == 1)
|
||||
|
||||
# Verify timestamps overlap: Alice (0.0-1.0) and Bob (0.5-1.5) overlap at 0.5-1.0
|
||||
assert alice_seg.start < bob_seg.end, "Alice segment should start before Bob ends"
|
||||
assert bob_seg.start < alice_seg.end, "Bob segment should start before Alice ends"
|
||||
|
||||
overlap_start = max(alice_seg.start, bob_seg.start)
|
||||
overlap_end = min(alice_seg.end, bob_seg.end)
|
||||
assert (
|
||||
overlap_start < overlap_end
|
||||
), f"Segments should overlap between {overlap_start} and {overlap_end}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_transcript_format_text(client):
|
||||
"""Test GET /transcripts/{id} with transcript_format=text."""
|
||||
response = await client.post("/transcripts", json={"name": "Test transcript"})
|
||||
assert response.status_code == 200
|
||||
tid = response.json()["id"]
|
||||
|
||||
from reflector.db.transcripts import (
|
||||
TranscriptParticipant,
|
||||
TranscriptTopic,
|
||||
transcripts_controller,
|
||||
)
|
||||
from reflector.processors.types import Word
|
||||
|
||||
transcript = await transcripts_controller.get_by_id(tid)
|
||||
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"participants": [
|
||||
TranscriptParticipant(
|
||||
id="1", speaker=0, name="John Smith"
|
||||
).model_dump(),
|
||||
TranscriptParticipant(id="2", speaker=1, name="Jane Doe").model_dump(),
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
await transcripts_controller.upsert_topic(
|
||||
transcript,
|
||||
TranscriptTopic(
|
||||
title="Topic 1",
|
||||
summary="Summary 1",
|
||||
timestamp=0,
|
||||
words=[
|
||||
Word(text="Hello", start=0, end=1, speaker=0),
|
||||
Word(text=" world.", start=1, end=2, speaker=0),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
response = await client.get(f"/transcripts/{tid}?transcript_format=text")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert data["transcript_format"] == "text"
|
||||
assert "transcript" in data
|
||||
assert "John Smith: Hello world." in data["transcript"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_transcript_format_text_timestamped(client):
|
||||
"""Test GET /transcripts/{id} with transcript_format=text-timestamped."""
|
||||
response = await client.post("/transcripts", json={"name": "Test transcript"})
|
||||
assert response.status_code == 200
|
||||
tid = response.json()["id"]
|
||||
|
||||
from reflector.db.transcripts import (
|
||||
TranscriptParticipant,
|
||||
TranscriptTopic,
|
||||
transcripts_controller,
|
||||
)
|
||||
from reflector.processors.types import Word
|
||||
|
||||
transcript = await transcripts_controller.get_by_id(tid)
|
||||
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"participants": [
|
||||
TranscriptParticipant(
|
||||
id="1", speaker=0, name="John Smith"
|
||||
).model_dump(),
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
await transcripts_controller.upsert_topic(
|
||||
transcript,
|
||||
TranscriptTopic(
|
||||
title="Topic 1",
|
||||
summary="Summary 1",
|
||||
timestamp=0,
|
||||
words=[
|
||||
Word(text="Hello", start=65, end=66, speaker=0),
|
||||
Word(text=" world.", start=66, end=67, speaker=0),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
response = await client.get(
|
||||
f"/transcripts/{tid}?transcript_format=text-timestamped"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert data["transcript_format"] == "text-timestamped"
|
||||
assert "transcript" in data
|
||||
assert "[01:05] John Smith: Hello world." in data["transcript"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_transcript_format_webvtt_named(client):
|
||||
"""Test GET /transcripts/{id} with transcript_format=webvtt-named."""
|
||||
response = await client.post("/transcripts", json={"name": "Test transcript"})
|
||||
assert response.status_code == 200
|
||||
tid = response.json()["id"]
|
||||
|
||||
from reflector.db.transcripts import (
|
||||
TranscriptParticipant,
|
||||
TranscriptTopic,
|
||||
transcripts_controller,
|
||||
)
|
||||
from reflector.processors.types import Word
|
||||
|
||||
transcript = await transcripts_controller.get_by_id(tid)
|
||||
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"participants": [
|
||||
TranscriptParticipant(
|
||||
id="1", speaker=0, name="John Smith"
|
||||
).model_dump(),
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
await transcripts_controller.upsert_topic(
|
||||
transcript,
|
||||
TranscriptTopic(
|
||||
title="Topic 1",
|
||||
summary="Summary 1",
|
||||
timestamp=0,
|
||||
words=[
|
||||
Word(text="Hello", start=0, end=1, speaker=0),
|
||||
Word(text=" world.", start=1, end=2, speaker=0),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
response = await client.get(f"/transcripts/{tid}?transcript_format=webvtt-named")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert data["transcript_format"] == "webvtt-named"
|
||||
assert "transcript" in data
|
||||
assert "WEBVTT" in data["transcript"]
|
||||
assert "<v John Smith>" in data["transcript"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_transcript_format_json(client):
|
||||
"""Test GET /transcripts/{id} with transcript_format=json."""
|
||||
response = await client.post("/transcripts", json={"name": "Test transcript"})
|
||||
assert response.status_code == 200
|
||||
tid = response.json()["id"]
|
||||
|
||||
from reflector.db.transcripts import (
|
||||
TranscriptParticipant,
|
||||
TranscriptTopic,
|
||||
transcripts_controller,
|
||||
)
|
||||
from reflector.processors.types import Word
|
||||
|
||||
transcript = await transcripts_controller.get_by_id(tid)
|
||||
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"participants": [
|
||||
TranscriptParticipant(
|
||||
id="1", speaker=0, name="John Smith"
|
||||
).model_dump(),
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
await transcripts_controller.upsert_topic(
|
||||
transcript,
|
||||
TranscriptTopic(
|
||||
title="Topic 1",
|
||||
summary="Summary 1",
|
||||
timestamp=0,
|
||||
words=[
|
||||
Word(text="Hello", start=0, end=1, speaker=0),
|
||||
Word(text=" world.", start=1, end=2, speaker=0),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
response = await client.get(f"/transcripts/{tid}?transcript_format=json")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert data["transcript_format"] == "json"
|
||||
assert "transcript" in data
|
||||
assert isinstance(data["transcript"], list)
|
||||
assert len(data["transcript"]) == 1
|
||||
assert data["transcript"][0]["speaker"] == 0
|
||||
assert data["transcript"][0]["speaker_name"] == "John Smith"
|
||||
assert data["transcript"][0]["text"] == "Hello world."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_transcript_format_default_is_text(client):
|
||||
"""Test GET /transcripts/{id} defaults to text format."""
|
||||
response = await client.post("/transcripts", json={"name": "Test transcript"})
|
||||
assert response.status_code == 200
|
||||
tid = response.json()["id"]
|
||||
|
||||
from reflector.db.transcripts import TranscriptTopic, transcripts_controller
|
||||
from reflector.processors.types import Word
|
||||
|
||||
transcript = await transcripts_controller.get_by_id(tid)
|
||||
|
||||
await transcripts_controller.upsert_topic(
|
||||
transcript,
|
||||
TranscriptTopic(
|
||||
title="Topic 1",
|
||||
summary="Summary 1",
|
||||
timestamp=0,
|
||||
words=[
|
||||
Word(text="Hello", start=0, end=1, speaker=0),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
response = await client.get(f"/transcripts/{tid}")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert data["transcript_format"] == "text"
|
||||
assert "transcript" in data
|
||||
Reference in New Issue
Block a user