mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
server: add topic duration, and endpoint for getting words group per speaker on a topic
This commit is contained in:
@@ -57,6 +57,7 @@ class Word(BaseModel):
|
|||||||
class TranscriptSegment(BaseModel):
|
class TranscriptSegment(BaseModel):
|
||||||
text: str
|
text: str
|
||||||
start: float
|
start: float
|
||||||
|
end: float
|
||||||
speaker: int = 0
|
speaker: int = 0
|
||||||
|
|
||||||
|
|
||||||
@@ -127,6 +128,7 @@ class Transcript(BaseModel):
|
|||||||
current_segment = TranscriptSegment(
|
current_segment = TranscriptSegment(
|
||||||
text=word.text,
|
text=word.text,
|
||||||
start=word.start,
|
start=word.start,
|
||||||
|
end=word.end,
|
||||||
speaker=word.speaker,
|
speaker=word.speaker,
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
@@ -138,6 +140,7 @@ class Transcript(BaseModel):
|
|||||||
current_segment = TranscriptSegment(
|
current_segment = TranscriptSegment(
|
||||||
text=word.text,
|
text=word.text,
|
||||||
start=word.start,
|
start=word.start,
|
||||||
|
end=word.end,
|
||||||
speaker=word.speaker,
|
speaker=word.speaker,
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
@@ -145,6 +148,7 @@ class Transcript(BaseModel):
|
|||||||
# if the word is the end of a sentence, and we have enough content,
|
# if the word is the end of a sentence, and we have enough content,
|
||||||
# add the word to the current segment and push it
|
# add the word to the current segment and push it
|
||||||
current_segment.text += word.text
|
current_segment.text += word.text
|
||||||
|
current_segment.end = word.end
|
||||||
|
|
||||||
have_punc = PUNC_RE.search(word.text)
|
have_punc = PUNC_RE.search(word.text)
|
||||||
if have_punc and (len(current_segment.text) > MAX_SEGMENT_LENGTH):
|
if have_punc and (len(current_segment.text) > MAX_SEGMENT_LENGTH):
|
||||||
|
|||||||
@@ -122,6 +122,7 @@ class GetTranscriptTopic(BaseModel):
|
|||||||
title: str
|
title: str
|
||||||
summary: str
|
summary: str
|
||||||
timestamp: float
|
timestamp: float
|
||||||
|
duration: float | None
|
||||||
transcript: str
|
transcript: str
|
||||||
segments: list[GetTranscriptSegmentTopic] = []
|
segments: list[GetTranscriptSegmentTopic] = []
|
||||||
|
|
||||||
@@ -131,6 +132,7 @@ class GetTranscriptTopic(BaseModel):
|
|||||||
# In previous version, words were missing
|
# In previous version, words were missing
|
||||||
# Just output a segment with speaker 0
|
# Just output a segment with speaker 0
|
||||||
text = topic.transcript
|
text = topic.transcript
|
||||||
|
duration = None
|
||||||
segments = [
|
segments = [
|
||||||
GetTranscriptSegmentTopic(
|
GetTranscriptSegmentTopic(
|
||||||
text=topic.transcript,
|
text=topic.transcript,
|
||||||
@@ -142,6 +144,7 @@ class GetTranscriptTopic(BaseModel):
|
|||||||
# New versions include words
|
# New versions include words
|
||||||
transcript = ProcessorTranscript(words=topic.words)
|
transcript = ProcessorTranscript(words=topic.words)
|
||||||
text = transcript.text
|
text = transcript.text
|
||||||
|
duration = transcript.duration
|
||||||
segments = [
|
segments = [
|
||||||
GetTranscriptSegmentTopic(
|
GetTranscriptSegmentTopic(
|
||||||
text=segment.text,
|
text=segment.text,
|
||||||
@@ -157,6 +160,7 @@ class GetTranscriptTopic(BaseModel):
|
|||||||
timestamp=topic.timestamp,
|
timestamp=topic.timestamp,
|
||||||
transcript=text,
|
transcript=text,
|
||||||
segments=segments,
|
segments=segments,
|
||||||
|
duration=duration,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -171,6 +175,44 @@ class GetTranscriptTopicWithWords(GetTranscriptTopic):
|
|||||||
return instance
|
return instance
|
||||||
|
|
||||||
|
|
||||||
|
class SpeakerWords(BaseModel):
|
||||||
|
speaker: int
|
||||||
|
words: list[Word]
|
||||||
|
|
||||||
|
|
||||||
|
class GetTranscriptTopicWithWordsPerSpeaker(GetTranscriptTopic):
|
||||||
|
words_per_speaker: list[SpeakerWords] = []
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_transcript_topic(cls, topic: TranscriptTopic):
|
||||||
|
instance = super().from_transcript_topic(topic)
|
||||||
|
if topic.words:
|
||||||
|
words_per_speakers = []
|
||||||
|
# group words by speaker
|
||||||
|
words = []
|
||||||
|
for word in topic.words:
|
||||||
|
if words and words[-1].speaker != word.speaker:
|
||||||
|
words_per_speakers.append(
|
||||||
|
SpeakerWords(
|
||||||
|
speaker=words[-1].speaker,
|
||||||
|
words=words,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
words = []
|
||||||
|
words.append(word)
|
||||||
|
if words:
|
||||||
|
words_per_speakers.append(
|
||||||
|
SpeakerWords(
|
||||||
|
speaker=words[-1].speaker,
|
||||||
|
words=words,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
instance.words_per_speaker = words_per_speakers
|
||||||
|
|
||||||
|
return instance
|
||||||
|
|
||||||
|
|
||||||
@router.get("/transcripts/{transcript_id}", response_model=GetTranscript)
|
@router.get("/transcripts/{transcript_id}", response_model=GetTranscript)
|
||||||
async def transcript_get(
|
async def transcript_get(
|
||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
@@ -247,3 +289,26 @@ async def transcript_get_topics_with_words(
|
|||||||
GetTranscriptTopicWithWords.from_transcript_topic(topic)
|
GetTranscriptTopicWithWords.from_transcript_topic(topic)
|
||||||
for topic in transcript.topics
|
for topic in transcript.topics
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/transcripts/{transcript_id}/topics/{topic_id}/words-per-speaker",
|
||||||
|
response_model=GetTranscriptTopicWithWordsPerSpeaker,
|
||||||
|
)
|
||||||
|
async def transcript_get_topics_with_words_per_speaker(
|
||||||
|
transcript_id: str,
|
||||||
|
topic_id: str,
|
||||||
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
|
):
|
||||||
|
user_id = user["sub"] if user else None
|
||||||
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
|
transcript_id, user_id=user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# get the topic from the transcript
|
||||||
|
topic = next((t for t in transcript.topics if t.id == topic_id), None)
|
||||||
|
if not topic:
|
||||||
|
raise HTTPException(status_code=404, detail="Topic not found")
|
||||||
|
|
||||||
|
# convert to GetTranscriptTopicWithWordsPerSpeaker
|
||||||
|
return GetTranscriptTopicWithWordsPerSpeaker.from_transcript_topic(topic)
|
||||||
|
|||||||
26
server/tests/test_transcripts_topics.py
Normal file
26
server/tests/test_transcripts_topics.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
import pytest
|
||||||
|
from httpx import AsyncClient
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_transcript_topics(fake_transcript_with_topics):
|
||||||
|
from reflector.app import app
|
||||||
|
|
||||||
|
transcript_id = fake_transcript_with_topics.id
|
||||||
|
|
||||||
|
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
||||||
|
# check the transcript exists
|
||||||
|
response = await ac.get(f"/transcripts/{transcript_id}/topics")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert len(response.json()) == 2
|
||||||
|
topic_id = response.json()[0]["id"]
|
||||||
|
|
||||||
|
# get words per speakers
|
||||||
|
response = await ac.get(
|
||||||
|
f"/transcripts/{transcript_id}/topics/{topic_id}/words-per-speaker"
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert len(data["words_per_speaker"]) == 1
|
||||||
|
assert data["words_per_speaker"][0]["speaker"] == 0
|
||||||
|
assert len(data["words_per_speaker"][0]["words"]) == 2
|
||||||
Reference in New Issue
Block a user