server: add topic duration and an endpoint returning a topic's words grouped by speaker

This commit is contained in:
2023-12-11 19:46:05 +01:00
parent 6f3d7df507
commit 07b29d42a7
3 changed files with 95 additions and 0 deletions

View File

@@ -57,6 +57,7 @@ class Word(BaseModel):
class TranscriptSegment(BaseModel):
text: str
start: float
end: float
speaker: int = 0
@@ -127,6 +128,7 @@ class Transcript(BaseModel):
current_segment = TranscriptSegment(
text=word.text,
start=word.start,
end=word.end,
speaker=word.speaker,
)
continue
@@ -138,6 +140,7 @@ class Transcript(BaseModel):
current_segment = TranscriptSegment(
text=word.text,
start=word.start,
end=word.end,
speaker=word.speaker,
)
continue
@@ -145,6 +148,7 @@ class Transcript(BaseModel):
# if the word is the end of a sentence, and we have enough content,
# add the word to the current segment and push it
current_segment.text += word.text
current_segment.end = word.end
have_punc = PUNC_RE.search(word.text)
if have_punc and (len(current_segment.text) > MAX_SEGMENT_LENGTH):

View File

@@ -122,6 +122,7 @@ class GetTranscriptTopic(BaseModel):
title: str
summary: str
timestamp: float
duration: float | None
transcript: str
segments: list[GetTranscriptSegmentTopic] = []
@@ -131,6 +132,7 @@ class GetTranscriptTopic(BaseModel):
# In previous versions, words were missing
# Just output a segment with speaker 0
text = topic.transcript
duration = None
segments = [
GetTranscriptSegmentTopic(
text=topic.transcript,
@@ -142,6 +144,7 @@ class GetTranscriptTopic(BaseModel):
# New versions include words
transcript = ProcessorTranscript(words=topic.words)
text = transcript.text
duration = transcript.duration
segments = [
GetTranscriptSegmentTopic(
text=segment.text,
@@ -157,6 +160,7 @@ class GetTranscriptTopic(BaseModel):
timestamp=topic.timestamp,
transcript=text,
segments=segments,
duration=duration,
)
@@ -171,6 +175,44 @@ class GetTranscriptTopicWithWords(GetTranscriptTopic):
return instance
class SpeakerWords(BaseModel):
    """A group of words attributed to one speaker.

    GetTranscriptTopicWithWordsPerSpeaker builds one of these for each
    consecutive run of words that share the same speaker.
    """

    # Speaker value carried by the words of this group.
    speaker: int
    # Words of the group, in the order they were collected.
    words: list[Word]
class GetTranscriptTopicWithWordsPerSpeaker(GetTranscriptTopic):
    """Topic response extended with its words grouped into per-speaker runs.

    Consecutive words sharing the same speaker form one ``SpeakerWords``
    entry; a new entry starts each time the speaker changes, so the same
    speaker may appear in several entries.
    """

    words_per_speaker: list[SpeakerWords] = []

    @classmethod
    def from_transcript_topic(cls, topic: TranscriptTopic):
        """Build the response model from ``topic``.

        The parent implementation creates the instance. When the topic has
        no words, ``words_per_speaker`` keeps its default empty list.
        """
        instance = super().from_transcript_topic(topic)
        if not topic.words:
            return instance

        # Split the words into runs: extend the last run while the speaker
        # stays the same, otherwise open a new one.
        runs = []
        for word in topic.words:
            if runs and runs[-1][-1].speaker == word.speaker:
                runs[-1].append(word)
            else:
                runs.append([word])

        instance.words_per_speaker = [
            SpeakerWords(speaker=run[-1].speaker, words=run) for run in runs
        ]
        return instance
@router.get("/transcripts/{transcript_id}", response_model=GetTranscript)
async def transcript_get(
transcript_id: str,
@@ -247,3 +289,26 @@ async def transcript_get_topics_with_words(
GetTranscriptTopicWithWords.from_transcript_topic(topic)
for topic in transcript.topics
]
@router.get(
    "/transcripts/{transcript_id}/topics/{topic_id}/words-per-speaker",
    response_model=GetTranscriptTopicWithWordsPerSpeaker,
)
async def transcript_get_topics_with_words_per_speaker(
    transcript_id: str,
    topic_id: str,
    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
    """Return one topic of a transcript with its words grouped by speaker.

    Raises:
        HTTPException: 404 when no topic of the transcript has ``topic_id``.
    """
    # the user is optional: without one, the lookup runs with user_id=None
    user_id = None
    if user:
        user_id = user["sub"]

    transcript = await transcripts_controller.get_by_id_for_http(
        transcript_id, user_id=user_id
    )

    # pick the first topic whose id matches the requested one
    topic = None
    for candidate in transcript.topics:
        if candidate.id == topic_id:
            topic = candidate
            break
    if not topic:
        raise HTTPException(status_code=404, detail="Topic not found")

    return GetTranscriptTopicWithWordsPerSpeaker.from_transcript_topic(topic)

View File

@@ -0,0 +1,26 @@
import pytest
from httpx import AsyncClient
@pytest.mark.asyncio
async def test_transcript_topics(fake_transcript_with_topics):
    """Check the words-per-speaker endpoint on the first topic of the fixture."""
    from reflector.app import app

    transcript_id = fake_transcript_with_topics.id
    topics_url = f"/transcripts/{transcript_id}/topics"

    async with AsyncClient(app=app, base_url="http://test/v1") as ac:
        # list the topics of the transcript; two are expected
        response = await ac.get(topics_url)
        assert response.status_code == 200
        topics = response.json()
        assert len(topics) == 2

        # fetch the per-speaker word groups of the first topic
        first_topic_id = topics[0]["id"]
        response = await ac.get(f"{topics_url}/{first_topic_id}/words-per-speaker")
        assert response.status_code == 200

        # a single group is expected: speaker 0 with two words
        groups = response.json()["words_per_speaker"]
        assert len(groups) == 1
        only_group = groups[0]
        assert only_group["speaker"] == 0
        assert len(only_group["words"]) == 2