feat: add transcript format parameter to GET endpoint (#709)

* feat: add transcript format parameter to GET endpoint

Add transcript_format query parameter to /v1/transcripts/{id} endpoint
with support for multiple output formats using discriminated unions.

Formats supported:
- text: Plain speaker dialogue (default)
- text-timestamped: Dialogue with [MM:SS] timestamps
- webvtt-named: WebVTT subtitles with participant names
- json: Structured segments with full metadata

Response models use Pydantic discriminated unions with transcript_format
as discriminator field. POST/PATCH endpoints return GetTranscriptWithParticipants
for minimal responses. GET endpoint returns format-specific models.

* Copy transcript format

* Regenerate types

* Fix transcript formats

* Don't throw inside try

* Remove any type

* Toast share copy errors

* transcript_format exhaustiveness and python idiomatic assert_never

* format_timestamp_mmss clear type definition

* Rename seconds_to_timestamp

* Test transcript format with overlapping speakers

* exact match for vtt multispeaker test

---------

Co-authored-by: Sergey Mankovsky <sergey@monadical.com>
Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
This commit is contained in:
2025-11-26 11:51:14 -06:00
committed by GitHub
parent 3aef926203
commit f6ca07505f
12 changed files with 1625 additions and 138 deletions

View File

@@ -0,0 +1,17 @@
"""Schema definitions for transcript format types and segments."""
from typing import Literal
from pydantic import BaseModel
TranscriptFormat = Literal["text", "text-timestamped", "webvtt-named", "json"]
class TranscriptSegment(BaseModel):
"""A single transcript segment with speaker and timing information."""
speaker: int
speaker_name: str
text: str
start: float
end: float

View File

@@ -7,7 +7,7 @@ This module provides result-based error handling that works in both contexts:
"""
from dataclasses import dataclass
from typing import Literal, Union
from typing import Literal, Union, assert_never
import celery
from celery.result import AsyncResult
@@ -18,7 +18,6 @@ from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
from reflector.pipelines.main_multitrack_pipeline import (
task_pipeline_multitrack_process,
)
from reflector.utils.match import absurd
from reflector.utils.string import NonEmptyString
@@ -155,7 +154,7 @@ def dispatch_transcript_processing(config: ProcessingConfig) -> AsyncResult:
elif isinstance(config, FileProcessingConfig):
return task_pipeline_file_process.delay(transcript_id=config.transcript_id)
else:
absurd(config)
assert_never(config)
def task_is_scheduled_or_active(task_name: str, **kwargs):

View File

@@ -1,10 +0,0 @@
from typing import NoReturn
def assert_exhaustiveness(x: NoReturn) -> NoReturn:
"""Provide an assertion at type-check time that this function is never called."""
raise AssertionError(f"Invalid value: {x!r}")
def absurd(x: NoReturn) -> NoReturn:
return assert_exhaustiveness(x)

View File

@@ -0,0 +1,125 @@
"""Utilities for converting transcript data to various output formats."""
import webvtt
from reflector.db.transcripts import TranscriptParticipant, TranscriptTopic
from reflector.processors.types import (
Transcript as ProcessorTranscript,
)
from reflector.processors.types import (
words_to_segments,
)
from reflector.schemas.transcript_formats import TranscriptSegment
from reflector.utils.webvtt import seconds_to_timestamp
def get_speaker_name(
speaker: int, participants: list[TranscriptParticipant] | None
) -> str:
"""Get participant name for speaker or default to 'Speaker N'."""
if participants:
for participant in participants:
if participant.speaker == speaker:
return participant.name
return f"Speaker {speaker}"
def format_timestamp_mmss(seconds: float | int) -> str:
"""Format seconds as MM:SS timestamp."""
minutes = int(seconds // 60)
secs = int(seconds % 60)
return f"{minutes:02d}:{secs:02d}"
def transcript_to_text(
topics: list[TranscriptTopic], participants: list[TranscriptParticipant] | None
) -> str:
"""Convert transcript topics to plain text with speaker names."""
lines = []
for topic in topics:
if not topic.words:
continue
transcript = ProcessorTranscript(words=topic.words)
segments = transcript.as_segments()
for segment in segments:
speaker_name = get_speaker_name(segment.speaker, participants)
text = segment.text.strip()
lines.append(f"{speaker_name}: {text}")
return "\n".join(lines)
def transcript_to_text_timestamped(
topics: list[TranscriptTopic], participants: list[TranscriptParticipant] | None
) -> str:
"""Convert transcript topics to timestamped text with speaker names."""
lines = []
for topic in topics:
if not topic.words:
continue
transcript = ProcessorTranscript(words=topic.words)
segments = transcript.as_segments()
for segment in segments:
speaker_name = get_speaker_name(segment.speaker, participants)
timestamp = format_timestamp_mmss(segment.start)
text = segment.text.strip()
lines.append(f"[{timestamp}] {speaker_name}: {text}")
return "\n".join(lines)
def topics_to_webvtt_named(
topics: list[TranscriptTopic], participants: list[TranscriptParticipant] | None
) -> str:
"""Convert transcript topics to WebVTT format with participant names."""
vtt = webvtt.WebVTT()
for topic in topics:
if not topic.words:
continue
segments = words_to_segments(topic.words)
for segment in segments:
speaker_name = get_speaker_name(segment.speaker, participants)
text = segment.text.strip()
text = f"<v {speaker_name}>{text}"
caption = webvtt.Caption(
start=seconds_to_timestamp(segment.start),
end=seconds_to_timestamp(segment.end),
text=text,
)
vtt.captions.append(caption)
return vtt.content
def transcript_to_json_segments(
topics: list[TranscriptTopic], participants: list[TranscriptParticipant] | None
) -> list[TranscriptSegment]:
"""Convert transcript topics to a flat list of JSON segments."""
segments = []
for topic in topics:
if not topic.words:
continue
transcript = ProcessorTranscript(words=topic.words)
for segment in transcript.as_segments():
speaker_name = get_speaker_name(segment.speaker, participants)
segments.append(
TranscriptSegment(
speaker=segment.speaker,
speaker_name=speaker_name,
text=segment.text.strip(),
start=segment.start,
end=segment.end,
)
)
return segments

View File

@@ -13,7 +13,7 @@ VttTimestamp = Annotated[str, "vtt_timestamp"]
WebVTTStr = Annotated[str, "webvtt_str"]
def _seconds_to_timestamp(seconds: Seconds) -> VttTimestamp:
def seconds_to_timestamp(seconds: Seconds) -> VttTimestamp:
# lib doesn't do that
hours = int(seconds // 3600)
minutes = int((seconds % 3600) // 60)
@@ -37,8 +37,8 @@ def words_to_webvtt(words: list[Word]) -> WebVTTStr:
text = f"<v Speaker{segment.speaker}>{text}"
caption = webvtt.Caption(
start=_seconds_to_timestamp(segment.start),
end=_seconds_to_timestamp(segment.end),
start=seconds_to_timestamp(segment.start),
end=seconds_to_timestamp(segment.end),
text=text,
)
vtt.captions.append(caption)

View File

@@ -1,11 +1,18 @@
from datetime import datetime, timedelta, timezone
from typing import Annotated, Literal, Optional
from typing import Annotated, Literal, Optional, assert_never
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi_pagination import Page
from fastapi_pagination.ext.databases import apaginate
from jose import jwt
from pydantic import AwareDatetime, BaseModel, Field, constr, field_serializer
from pydantic import (
AwareDatetime,
BaseModel,
Discriminator,
Field,
constr,
field_serializer,
)
import reflector.auth as auth
from reflector.db import get_database
@@ -31,7 +38,14 @@ from reflector.db.transcripts import (
)
from reflector.processors.types import Transcript as ProcessorTranscript
from reflector.processors.types import Word
from reflector.schemas.transcript_formats import TranscriptFormat, TranscriptSegment
from reflector.settings import settings
from reflector.utils.transcript_formats import (
topics_to_webvtt_named,
transcript_to_json_segments,
transcript_to_text,
transcript_to_text_timestamped,
)
from reflector.ws_manager import get_ws_manager
from reflector.zulip import (
InvalidMessageError,
@@ -88,10 +102,84 @@ class GetTranscriptMinimal(BaseModel):
audio_deleted: bool | None = None
class GetTranscript(GetTranscriptMinimal):
class GetTranscriptWithParticipants(GetTranscriptMinimal):
participants: list[TranscriptParticipant] | None
class GetTranscriptWithText(GetTranscriptWithParticipants):
"""
Transcript response with plain text format.
Format: Speaker names followed by their dialogue, one line per segment.
Example:
John Smith: Hello everyone
Jane Doe: Hi there
"""
transcript_format: Literal["text"] = "text"
transcript: str
class GetTranscriptWithTextTimestamped(GetTranscriptWithParticipants):
"""
Transcript response with timestamped text format.
Format: [MM:SS] timestamp prefix before each speaker and dialogue.
Example:
[00:00] John Smith: Hello everyone
[00:05] Jane Doe: Hi there
"""
transcript_format: Literal["text-timestamped"] = "text-timestamped"
transcript: str
class GetTranscriptWithWebVTTNamed(GetTranscriptWithParticipants):
"""
Transcript response in WebVTT subtitle format with participant names.
Format: Standard WebVTT with voice tags using participant names.
Example:
WEBVTT
00:00:00.000 --> 00:00:05.000
<v John Smith>Hello everyone
"""
transcript_format: Literal["webvtt-named"] = "webvtt-named"
transcript: str
class GetTranscriptWithJSON(GetTranscriptWithParticipants):
"""
Transcript response as structured JSON segments.
Format: Array of segment objects with speaker info, text, and timing.
Example:
[
{
"speaker": 0,
"speaker_name": "John Smith",
"text": "Hello everyone",
"start": 0.0,
"end": 5.0
}
]
"""
transcript_format: Literal["json"] = "json"
transcript: list[TranscriptSegment]
GetTranscript = Annotated[
GetTranscriptWithText
| GetTranscriptWithTextTimestamped
| GetTranscriptWithWebVTTNamed
| GetTranscriptWithJSON,
Discriminator("transcript_format"),
]
class CreateTranscript(BaseModel):
name: str
source_language: str = Field("en")
@@ -228,7 +316,7 @@ async def transcripts_search(
)
@router.post("/transcripts", response_model=GetTranscript)
@router.post("/transcripts", response_model=GetTranscriptWithParticipants)
async def transcripts_create(
info: CreateTranscript,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
@@ -362,14 +450,72 @@ class GetTranscriptTopicWithWordsPerSpeaker(GetTranscriptTopic):
async def transcript_get(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
transcript_format: TranscriptFormat = "text",
):
user_id = user["sub"] if user else None
return await transcripts_controller.get_by_id_for_http(
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
base_data = {
"id": transcript.id,
"user_id": transcript.user_id,
"name": transcript.name,
"status": transcript.status,
"locked": transcript.locked,
"duration": transcript.duration,
"title": transcript.title,
"short_summary": transcript.short_summary,
"long_summary": transcript.long_summary,
"created_at": transcript.created_at,
"share_mode": transcript.share_mode,
"source_language": transcript.source_language,
"target_language": transcript.target_language,
"reviewed": transcript.reviewed,
"meeting_id": transcript.meeting_id,
"source_kind": transcript.source_kind,
"room_id": transcript.room_id,
"audio_deleted": transcript.audio_deleted,
"participants": transcript.participants,
}
@router.patch("/transcripts/{transcript_id}", response_model=GetTranscript)
if transcript_format == "text":
return GetTranscriptWithText(
**base_data,
transcript_format="text",
transcript=transcript_to_text(transcript.topics, transcript.participants),
)
elif transcript_format == "text-timestamped":
return GetTranscriptWithTextTimestamped(
**base_data,
transcript_format="text-timestamped",
transcript=transcript_to_text_timestamped(
transcript.topics, transcript.participants
),
)
elif transcript_format == "webvtt-named":
return GetTranscriptWithWebVTTNamed(
**base_data,
transcript_format="webvtt-named",
transcript=topics_to_webvtt_named(
transcript.topics, transcript.participants
),
)
elif transcript_format == "json":
return GetTranscriptWithJSON(
**base_data,
transcript_format="json",
transcript=transcript_to_json_segments(
transcript.topics, transcript.participants
),
)
else:
assert_never(transcript_format)
@router.patch(
"/transcripts/{transcript_id}", response_model=GetTranscriptWithParticipants
)
async def transcript_update(
transcript_id: str,
info: UpdateTranscript,

View File

@@ -1,4 +1,4 @@
from typing import Annotated, Optional
from typing import Annotated, Optional, assert_never
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
@@ -15,7 +15,6 @@ from reflector.services.transcript_process import (
prepare_transcript_processing,
validate_transcript_for_processing,
)
from reflector.utils.match import absurd
router = APIRouter()
@@ -44,7 +43,7 @@ async def transcript_process(
elif isinstance(validation, ValidationOk):
pass
else:
absurd(validation)
assert_never(validation)
config = await prepare_transcript_processing(validation)