Merge branch 'feat-api-speaker-reassignment' of github.com:Monadical-SAS/reflector into sara/feat-speaker-reassign

This commit is contained in:
Sara
2023-12-12 11:48:20 +01:00
16 changed files with 992 additions and 7 deletions

View File

@@ -251,6 +251,23 @@ class Transcript(BaseModel):
url += f"?token={token}"
return url
def find_empty_speaker(self) -> int:
    """
    Return the lowest speaker index that is not in use yet.

    An index counts as used when a word of any topic carries it, or when a
    participant already holds it. Taking participants into account prevents
    handing out a seat that was given to a participant who has no words
    attributed yet.
    """
    used = {
        word.speaker
        for topic in self.topics
        for word in topic.words
        if word.speaker is not None
    }
    # A participant may own a seat before any word is relabelled to it.
    used.update(
        participant.speaker
        for participant in (self.participants or [])
        if participant.speaker is not None
    )
    # `used` is finite, so the scan always ends on a free index; no
    # "nothing found" fallback is needed.
    candidate = 0
    while candidate in used:
        candidate += 1
    return candidate
class TranscriptController:
async def get_all(

View File

@@ -57,6 +57,7 @@ class Word(BaseModel):
class TranscriptSegment(BaseModel):
text: str
start: float
end: float
speaker: int = 0
@@ -127,6 +128,7 @@ class Transcript(BaseModel):
current_segment = TranscriptSegment(
text=word.text,
start=word.start,
end=word.end,
speaker=word.speaker,
)
continue
@@ -138,6 +140,7 @@ class Transcript(BaseModel):
current_segment = TranscriptSegment(
text=word.text,
start=word.start,
end=word.end,
speaker=word.speaker,
)
continue
@@ -145,6 +148,7 @@ class Transcript(BaseModel):
# if the word is the end of a sentence, and we have enough content,
# add the word to the current segment and push it
current_segment.text += word.text
current_segment.end = word.end
have_punc = PUNC_RE.search(word.text)
if have_punc and (len(current_segment.text) > MAX_SEGMENT_LENGTH):

View File

@@ -122,6 +122,7 @@ class GetTranscriptTopic(BaseModel):
title: str
summary: str
timestamp: float
duration: float | None
transcript: str
segments: list[GetTranscriptSegmentTopic] = []
@@ -131,6 +132,7 @@ class GetTranscriptTopic(BaseModel):
# In previous version, words were missing
# Just output a segment with speaker 0
text = topic.transcript
duration = None
segments = [
GetTranscriptSegmentTopic(
text=topic.transcript,
@@ -142,6 +144,7 @@ class GetTranscriptTopic(BaseModel):
# New versions include words
transcript = ProcessorTranscript(words=topic.words)
text = transcript.text
duration = transcript.duration
segments = [
GetTranscriptSegmentTopic(
text=segment.text,
@@ -157,6 +160,7 @@ class GetTranscriptTopic(BaseModel):
timestamp=topic.timestamp,
transcript=text,
segments=segments,
duration=duration,
)
@@ -171,6 +175,44 @@ class GetTranscriptTopicWithWords(GetTranscriptTopic):
return instance
# A run of consecutive words that all belong to the same speaker.
class SpeakerWords(BaseModel):
# Speaker index shared by every word of the run.
speaker: int
# The words of the run, in the order they appear in the topic.
words: list[Word]
class GetTranscriptTopicWithWordsPerSpeaker(GetTranscriptTopic):
    # Topic payload extended with its words split into consecutive
    # same-speaker runs; stays empty when the topic carries no words.
    words_per_speaker: list[SpeakerWords] = []

    @classmethod
    def from_transcript_topic(cls, topic: TranscriptTopic):
        """
        Build the topic payload and attach the per-speaker word runs.

        A new run starts every time the speaker differs from the speaker of
        the previous word, so one speaker can own several runs.
        """
        result = super().from_transcript_topic(topic)
        if not topic.words:
            return result

        # collect runs of adjacent words sharing one speaker
        runs = []
        for word in topic.words:
            if runs and runs[-1][-1].speaker == word.speaker:
                runs[-1].append(word)
            else:
                runs.append([word])

        result.words_per_speaker = [
            SpeakerWords(speaker=run[-1].speaker, words=run) for run in runs
        ]
        return result
@router.get("/transcripts/{transcript_id}", response_model=GetTranscript)
async def transcript_get(
transcript_id: str,
@@ -247,3 +289,26 @@ async def transcript_get_topics_with_words(
GetTranscriptTopicWithWords.from_transcript_topic(topic)
for topic in transcript.topics
]
@router.get(
    "/transcripts/{transcript_id}/topics/{topic_id}/words-per-speaker",
    response_model=GetTranscriptTopicWithWordsPerSpeaker,
)
async def transcript_get_topics_with_words_per_speaker(
    transcript_id: str,
    topic_id: str,
    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
    # Serve a single topic of a transcript with its words grouped per
    # speaker; answers 404 when the transcript has no topic with this id.
    if user:
        user_id = user["sub"]
    else:
        user_id = None

    transcript = await transcripts_controller.get_by_id_for_http(
        transcript_id, user_id=user_id
    )

    # first topic whose id matches, if any
    topic = None
    for candidate in transcript.topics:
        if candidate.id == topic_id:
            topic = candidate
            break

    if not topic:
        raise HTTPException(status_code=404, detail="Topic not found")

    # shape the topic into the words-per-speaker response model
    return GetTranscriptTopicWithWordsPerSpeaker.from_transcript_topic(topic)

View File

@@ -59,7 +59,7 @@ async def transcript_add_participant(
)
# ensure the speaker is unique
if transcript.participants:
if participant.speaker is not None:
for p in transcript.participants:
if p.speaker == participant.speaker:
raise HTTPException(

View File

@@ -7,14 +7,15 @@ from typing import Annotated, Optional
import reflector.auth as auth
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from pydantic import BaseModel, Field
from reflector.db.transcripts import transcripts_controller
router = APIRouter()
# Request body describing which words to relabel and to whom: a time range
# plus either a raw speaker index or a participant id.
class SpeakerAssignment(BaseModel):
# NOTE(review): two `speaker` declarations are listed here; presumably they are
# the before/after pair of this diff and only the Optional one remains in the
# actual file -- confirm against the source file.
speaker: int
# Speaker index to assign; constrained to be >= 0 when provided.
speaker: Optional[int] = Field(None, ge=0)
# Id of a participant to assign instead of a raw speaker index.
participant: Optional[str] = Field(None)
# Bounds of the time range; the assign handler compares them with word.start.
timestamp_from: float
timestamp_to: float
@@ -23,6 +24,11 @@ class SpeakerAssignmentStatus(BaseModel):
status: str
# Request body of the speaker merge endpoint.
class SpeakerMerge(BaseModel):
# Speaker index whose words get relabelled.
speaker_from: int
# Speaker index those words are relabelled to.
speaker_to: int
@router.patch("/transcripts/{transcript_id}/speaker/assign")
async def transcript_assign_speaker(
transcript_id: str,
@@ -37,6 +43,44 @@ async def transcript_assign_speaker(
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
if assignment.speaker is None and assignment.participant is None:
raise HTTPException(
status_code=400,
detail="Either speaker or participant must be provided",
)
if assignment.speaker is not None and assignment.participant is not None:
raise HTTPException(
status_code=400,
detail="Only one of speaker or participant must be provided",
)
# if it's a participant, search for it
if assignment.speaker is not None:
speaker = assignment.speaker
elif assignment.participant is not None:
participant = next(
(
participant
for participant in transcript.participants
if participant.id == assignment.participant
),
None,
)
if not participant:
raise HTTPException(
status_code=404,
detail="Participant not found",
)
# if the participant does not have a speaker, create one
if participant.speaker is None:
participant.speaker = transcript.find_empty_speaker()
await transcripts_controller.upsert_participant(transcript, participant)
speaker = participant.speaker
# reassign speakers from words in the transcript
ts_from = assignment.timestamp_from
ts_to = assignment.timestamp_to
@@ -45,7 +89,70 @@ async def transcript_assign_speaker(
changed = False
for word in topic.words:
if ts_from <= word.start <= ts_to:
word.speaker = assignment.speaker
word.speaker = speaker
changed = True
if changed:
changed_topics.append(topic)
# batch changes
for topic in changed_topics:
transcript.upsert_topic(topic)
await transcripts_controller.update(
transcript,
{
"topics": transcript.topics_dump(),
},
)
return SpeakerAssignmentStatus(status="ok")
@router.patch("/transcripts/{transcript_id}/speaker/merge")
async def transcript_merge_speaker(
transcript_id: str,
merge: SpeakerMerge,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> SpeakerAssignmentStatus:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
# ensure both speakers are not assigned to two different participants
participant_from = next(
(
participant
for participant in transcript.participants
if participant.speaker == merge.speaker_from
),
None,
)
participant_to = next(
(
participant
for participant in transcript.participants
if participant.speaker == merge.speaker_to
),
None,
)
if participant_from and participant_to:
raise HTTPException(
status_code=400,
detail="Both speakers are assigned to participants",
)
# reassign speakers from words in the transcript
speaker_from = merge.speaker_from
speaker_to = merge.speaker_to
changed_topics = []
for topic in transcript.topics:
changed = False
for word in topic.words:
if word.speaker == speaker_from:
word.speaker = speaker_to
changed = True
if changed:
changed_topics.append(topic)

View File

@@ -115,3 +115,287 @@ async def test_transcript_reassign_speaker(fake_transcript_with_topics):
assert topics[0]["segments"][0]["speaker"] == 4
assert len(topics[1]["segments"]) == 1
assert topics[1]["segments"][0]["speaker"] == 4
@pytest.mark.asyncio
async def test_transcript_merge_speaker(fake_transcript_with_topics):
    """Relabel the first topic to speaker 1, then merge speaker 1 into 0."""
    from reflector.app import app

    transcript_id = fake_transcript_with_topics.id

    async with AsyncClient(app=app, base_url="http://test/v1") as ac:

        async def fetch_topics():
            # every check re-reads the topics and expects exactly two of them
            response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
            assert response.status_code == 200
            topics = response.json()
            assert len(topics) == 2
            return topics

        def assert_word_speakers(topics, expected):
            # expected[t][w] is the speaker of word w in topic t
            for topic, speakers in zip(topics, expected):
                for index, speaker in enumerate(speakers):
                    assert topic["words"][index]["speaker"] == speaker

        # the transcript must exist
        response = await ac.get(f"/transcripts/{transcript_id}")
        assert response.status_code == 200

        # initially every word belongs to speaker 0
        assert_word_speakers(await fetch_topics(), [[0, 0], [0, 0]])

        # hand the [0, 1] time range to speaker 1
        response = await ac.patch(
            f"/transcripts/{transcript_id}/speaker/assign",
            json={
                "speaker": 1,
                "timestamp_from": 0,
                "timestamp_to": 1,
            },
        )
        assert response.status_code == 200

        # only the first topic changed
        assert_word_speakers(await fetch_topics(), [[1, 1], [0, 0]])

        # fold speaker 1 back into speaker 0
        response = await ac.patch(
            f"/transcripts/{transcript_id}/speaker/merge",
            json={
                "speaker_from": 1,
                "speaker_to": 0,
            },
        )
        assert response.status_code == 200

        # everything is speaker 0 again
        assert_word_speakers(await fetch_topics(), [[0, 0], [0, 0]])
@pytest.mark.asyncio
async def test_transcript_reassign_with_participant(fake_transcript_with_topics):
    """Assign word ranges to participants and check the seats they receive."""
    from reflector.app import app

    transcript_id = fake_transcript_with_topics.id
    participant_names = ("Participant 1", "Participant 2")

    async with AsyncClient(app=app, base_url="http://test/v1") as ac:

        async def assert_participant_speakers(expected_speakers):
            # both participants are listed, in creation order
            response = await ac.get(f"/transcripts/{transcript_id}/participants")
            assert response.status_code == 200
            participants = response.json()
            assert len(participants) == 2
            for participant, name, speaker in zip(
                participants, participant_names, expected_speakers
            ):
                assert participant["name"] == name
                if speaker is None:
                    assert participant["speaker"] is None
                else:
                    assert participant["speaker"] == speaker

        async def assert_topics(word_speakers, segment_speakers):
            response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
            assert response.status_code == 200
            topics = response.json()
            assert len(topics) == 2
            # words first ...
            for topic, speakers in zip(topics, word_speakers):
                for index, speaker in enumerate(speakers):
                    assert topic["words"][index]["speaker"] == speaker
            # ... then the segments derived from them
            for topic, speakers in zip(topics, segment_speakers):
                assert len(topic["segments"]) == len(speakers)
                for index, speaker in enumerate(speakers):
                    assert topic["segments"][index]["speaker"] == speaker

        async def assign(payload):
            response = await ac.patch(
                f"/transcripts/{transcript_id}/speaker/assign",
                json=payload,
            )
            assert response.status_code == 200

        # the transcript exists and starts without participants
        response = await ac.get(f"/transcripts/{transcript_id}")
        assert response.status_code == 200
        transcript = response.json()
        assert len(transcript["participants"]) == 0

        # create the two participants, remembering their ids
        participant_ids = []
        for name in participant_names:
            response = await ac.post(
                f"/transcripts/{transcript_id}/participants",
                json={
                    "name": name,
                },
            )
            assert response.status_code == 200
            participant_ids.append(response.json()["id"])
        participant1_id, participant2_id = participant_ids

        # nobody has a speaker seat yet
        await assert_participant_speakers([None, None])

        # every word and segment starts with speaker 0
        await assert_topics(
            word_speakers=[[0, 0], [0, 0]],
            segment_speakers=[[0], [0]],
        )

        # assign the first time range to participant 1
        await assign(
            {
                "participant": participant1_id,
                "timestamp_from": 0,
                "timestamp_to": 1,
            }
        )

        # participant 1 receives seat 1, participant 2 is still unseated
        await assert_participant_speakers([1, None])
        await assert_topics(
            word_speakers=[[1, 1], [0, 0]],
            segment_speakers=[[1], [0]],
        )

        # assign a range straddling both topics to participant 2
        await assign(
            {
                "participant": participant2_id,
                "timestamp_from": 1,
                "timestamp_to": 2.5,
            }
        )

        # participant 1 keeps seat 1, participant 2 receives seat 2
        await assert_participant_speakers([1, 2])
        await assert_topics(
            word_speakers=[[1, 2], [2, 0]],
            segment_speakers=[[1, 2], [2, 0]],
        )

        # finally give the whole transcript to participant 1
        await assign(
            {
                "participant": participant1_id,
                "timestamp_from": 0,
                "timestamp_to": 100,
            }
        )
        await assert_topics(
            word_speakers=[[1, 1], [1, 1]],
            segment_speakers=[[1], [1]],
        )
@pytest.mark.asyncio
async def test_transcript_reassign_edge_cases(fake_transcript_with_topics):
    """Invalid assignment payloads are rejected with the expected status."""
    from reflector.app import app

    transcript_id = fake_transcript_with_topics.id

    # (payload, expected status), exercised in this order
    cases = [
        # neither a participant nor a speaker
        (
            {
                "timestamp_from": 0,
                "timestamp_to": 1,
            },
            400,
        ),
        # both a participant and a speaker
        (
            {
                "participant": "123",
                "speaker": 1,
                "timestamp_from": 0,
                "timestamp_to": 1,
            },
            400,
        ),
        # a participant id that does not exist
        (
            {
                "participant": "123",
                "timestamp_from": 0,
                "timestamp_to": 1,
            },
            404,
        ),
    ]

    async with AsyncClient(app=app, base_url="http://test/v1") as ac:
        # the transcript exists and has no participants
        response = await ac.get(f"/transcripts/{transcript_id}")
        assert response.status_code == 200
        body = response.json()
        assert len(body["participants"]) == 0

        for payload, expected_status in cases:
            response = await ac.patch(
                f"/transcripts/{transcript_id}/speaker/assign",
                json=payload,
            )
            assert response.status_code == expected_status

View File

@@ -0,0 +1,26 @@
import pytest
from httpx import AsyncClient
@pytest.mark.asyncio
async def test_transcript_topics(fake_transcript_with_topics):
    """The words-per-speaker view of the first topic holds a single run."""
    from reflector.app import app

    transcript_id = fake_transcript_with_topics.id

    async with AsyncClient(app=app, base_url="http://test/v1") as ac:
        # list the topics and pick the first one
        listing = await ac.get(f"/transcripts/{transcript_id}/topics")
        assert listing.status_code == 200
        topics = listing.json()
        assert len(topics) == 2
        topic_id = topics[0]["id"]

        # fetch that topic grouped by speaker
        detail = await ac.get(
            f"/transcripts/{transcript_id}/topics/{topic_id}/words-per-speaker"
        )
        assert detail.status_code == 200
        groups = detail.json()["words_per_speaker"]

        # one run, owned by speaker 0, holding both words
        assert len(groups) == 1
        assert groups[0]["speaker"] == 0
        assert len(groups[0]["words"]) == 2