server: include transcript words in database, but keep a backward-compatible API

This commit is contained in:
2023-10-20 16:07:12 +02:00
committed by Mathieu Virbel
parent 00eb9bbf3c
commit 21e408b323

View File

@@ -17,6 +17,8 @@ from fastapi_pagination import Page, paginate
from pydantic import BaseModel, Field
from reflector.db import database, transcripts
from reflector.logger import logger
from reflector.processors.types import Transcript as ProcessorTranscript
from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings
from reflector.utils.audio_waveform import get_audio_waveform
from starlette.concurrency import run_in_threadpool
@@ -60,7 +62,8 @@ class TranscriptTopic(BaseModel):
title: str
summary: str
timestamp: float
segments: list[TranscriptSegmentTopic] = []
text: str | None = None
words: list[ProcessorWord] = []
class TranscriptFinalShortSummary(BaseModel):
@@ -304,6 +307,53 @@ async def transcripts_create(
# ==============================================================
class GetTranscriptSegmentTopic(BaseModel):
    """One segment of a topic as returned by the topics endpoint."""

    # Text of the segment.
    text: str
    # Start value of the segment (taken from ``segment.start`` or, for
    # topics without words, from the topic timestamp).
    start: float
    # Speaker number attached to the segment.
    speaker: int
class GetTranscriptTopic(BaseModel):
    """API view of a transcript topic, exposing text and speaker segments.

    Instances are built from a stored ``TranscriptTopic`` through
    :meth:`from_transcript_topic`, which derives ``text`` and ``segments``
    from the topic's words when they are available.
    """

    title: str
    summary: str
    timestamp: float
    text: str
    segments: list[GetTranscriptSegmentTopic] = []

    @classmethod
    def from_transcript_topic(cls, topic: TranscriptTopic):
        """Convert a stored ``TranscriptTopic`` into its API representation.

        Args:
            topic: The stored topic. Its ``words`` list may be empty and its
                ``text`` may be ``None``.

        Returns:
            A ``GetTranscriptTopic`` carrying the topic's title, summary and
            timestamp, plus the text and segments computed for it.
        """
        if not topic.words:
            # In previous version, words were missing
            # Just output a segment with speaker 0
            # ``topic.text`` is optional on the stored model while ``text`` is
            # a required ``str`` here, so fall back to an empty string rather
            # than failing validation on a topic without text.
            text = topic.text or ""
            segments = [
                GetTranscriptSegmentTopic(
                    text=text,
                    start=topic.timestamp,
                    speaker=0,
                )
            ]
        else:
            # New versions include words: rebuild the text and the segments
            # from them.
            transcript = ProcessorTranscript(words=topic.words)
            text = transcript.text
            segments = [
                GetTranscriptSegmentTopic(
                    text=segment.text,
                    start=segment.start,
                    speaker=segment.speaker,
                )
                for segment in transcript.as_segments()
            ]
        return cls(
            title=topic.title,
            summary=topic.summary,
            timestamp=topic.timestamp,
            text=text,
            segments=segments,
        )
@router.get("/transcripts/{transcript_id}", response_model=GetTranscript)
async def transcript_get(
transcript_id: str,
@@ -412,7 +462,10 @@ async def transcript_get_audio_waveform(
return transcript.audio_waveform
@router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic])
@router.get(
"/transcripts/{transcript_id}/topics",
response_model=list[GetTranscriptTopic],
)
async def transcript_get_topics(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
@@ -421,7 +474,11 @@ async def transcript_get_topics(
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
return transcript.topics
# convert to GetTranscriptTopic
return [
GetTranscriptTopic.from_transcript_topic(topic) for topic in transcript.topics
]
@router.get("/transcripts/{transcript_id}/events")
@@ -498,6 +555,13 @@ async def transcript_events_websocket(
async def handle_rtc_event(event: PipelineEvent, args, data):
    """Run ``handle_rtc_event_once`` and log any exception it raises.

    Returns the value produced by ``handle_rtc_event_once``, or ``None`` when
    it raised: the exception is logged and not propagated to the caller.
    """
    try:
        outcome = await handle_rtc_event_once(event, args, data)
    except Exception:
        logger.exception("Error handling RTC event")
        return None
    return outcome
async def handle_rtc_event_once(event: PipelineEvent, args, data):
# OFC the current implementation is not good,
# but it's just a POC before persistence. It won't query the
# transcript from the database for each event.
@@ -530,14 +594,8 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
title=data.title,
summary=data.summary,
timestamp=data.timestamp,
segments=[
TranscriptSegmentTopic(
speaker=segment.speaker,
text=segment.text,
timestamp=segment.start,
)
for segment in data.transcript.as_segments()
],
text=data.transcript.text,
words=data.transcript.words,
)
resp = transcript.add_event(event=event, data=topic)
transcript.upsert_topic(topic)