From 07b29d42a751a7d978253c7cb7a5fde9e2359d10 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Mon, 11 Dec 2023 19:46:05 +0100 Subject: [PATCH] server: add topic duration, and endpoint for getting words group per speaker on a topic --- server/reflector/processors/types.py | 4 ++ server/reflector/views/transcripts.py | 65 +++++++++++++++++++++++++ server/tests/test_transcripts_topics.py | 26 ++++++++++ 3 files changed, 95 insertions(+) create mode 100644 server/tests/test_transcripts_topics.py diff --git a/server/reflector/processors/types.py b/server/reflector/processors/types.py index 93e565df..cedb23f9 100644 --- a/server/reflector/processors/types.py +++ b/server/reflector/processors/types.py @@ -57,6 +57,7 @@ class Word(BaseModel): class TranscriptSegment(BaseModel): text: str start: float + end: float speaker: int = 0 @@ -127,6 +128,7 @@ class Transcript(BaseModel): current_segment = TranscriptSegment( text=word.text, start=word.start, + end=word.end, speaker=word.speaker, ) continue @@ -138,6 +140,7 @@ class Transcript(BaseModel): current_segment = TranscriptSegment( text=word.text, start=word.start, + end=word.end, speaker=word.speaker, ) continue @@ -145,6 +148,7 @@ class Transcript(BaseModel): # if the word is the end of a sentence, and we have enough content, # add the word to the current segment and push it current_segment.text += word.text + current_segment.end = word.end have_punc = PUNC_RE.search(word.text) if have_punc and (len(current_segment.text) > MAX_SEGMENT_LENGTH): diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 682e1576..abb72af4 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -122,6 +122,7 @@ class GetTranscriptTopic(BaseModel): title: str summary: str timestamp: float + duration: float | None transcript: str segments: list[GetTranscriptSegmentTopic] = [] @@ -131,6 +132,7 @@ class GetTranscriptTopic(BaseModel): # In previous version, words were missing # Just output a segment with speaker 0 text = topic.transcript + duration = None segments = [ GetTranscriptSegmentTopic( text=topic.transcript, @@ -142,6 +144,7 @@ class GetTranscriptTopic(BaseModel): # New versions include words transcript = ProcessorTranscript(words=topic.words) text = transcript.text + duration = transcript.duration segments = [ GetTranscriptSegmentTopic( text=segment.text, @@ -157,6 +160,7 @@ class GetTranscriptTopic(BaseModel): timestamp=topic.timestamp, transcript=text, segments=segments, + duration=duration, ) @@ -171,6 +175,44 @@ class GetTranscriptTopicWithWords(GetTranscriptTopic): return instance +class SpeakerWords(BaseModel): + speaker: int + words: list[Word] + + +class GetTranscriptTopicWithWordsPerSpeaker(GetTranscriptTopic): + words_per_speaker: list[SpeakerWords] = [] + + @classmethod + def from_transcript_topic(cls, topic: TranscriptTopic): + instance = super().from_transcript_topic(topic) + if topic.words: + words_per_speakers = [] + # group words by speaker + words = [] + for word in topic.words: + if words and words[-1].speaker != word.speaker: + words_per_speakers.append( + SpeakerWords( + speaker=words[-1].speaker, + words=words, + ) + ) + words = [] + words.append(word) + if words: + words_per_speakers.append( + SpeakerWords( + speaker=words[-1].speaker, + words=words, + ) + ) + + instance.words_per_speaker = words_per_speakers + + return instance + + @router.get("/transcripts/{transcript_id}", response_model=GetTranscript) async def transcript_get( transcript_id: str, @@ -247,3 +289,26 @@ async def transcript_get_topics_with_words( GetTranscriptTopicWithWords.from_transcript_topic(topic) for topic in transcript.topics ] + + +@router.get( + "/transcripts/{transcript_id}/topics/{topic_id}/words-per-speaker", + response_model=GetTranscriptTopicWithWordsPerSpeaker, +) +async def transcript_get_topics_with_words_per_speaker( + transcript_id: str, + topic_id: str, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +): + user_id = user["sub"] if user else None + transcript = await transcripts_controller.get_by_id_for_http( + transcript_id, user_id=user_id + ) + + # get the topic from the transcript + topic = next((t for t in transcript.topics if t.id == topic_id), None) + if not topic: + raise HTTPException(status_code=404, detail="Topic not found") + + # convert to GetTranscriptTopicWithWordsPerSpeaker + return GetTranscriptTopicWithWordsPerSpeaker.from_transcript_topic(topic) diff --git a/server/tests/test_transcripts_topics.py b/server/tests/test_transcripts_topics.py new file mode 100644 index 00000000..cd845b3f --- /dev/null +++ b/server/tests/test_transcripts_topics.py @@ -0,0 +1,26 @@ +import pytest +from httpx import AsyncClient + + +@pytest.mark.asyncio +async def test_transcript_topics(fake_transcript_with_topics): + from reflector.app import app + + transcript_id = fake_transcript_with_topics.id + + async with AsyncClient(app=app, base_url="http://test/v1") as ac: + # check the transcript exists + response = await ac.get(f"/transcripts/{transcript_id}/topics") + assert response.status_code == 200 + assert len(response.json()) == 2 + topic_id = response.json()[0]["id"] + + # get words per speakers + response = await ac.get( + f"/transcripts/{transcript_id}/topics/{topic_id}/words-per-speaker" + ) + assert response.status_code == 200 + data = response.json() + assert len(data["words_per_speaker"]) == 1 + assert data["words_per_speaker"][0]["speaker"] == 0 + assert len(data["words_per_speaker"][0]["words"]) == 2