From 37b11fdcb8fc3571a2ddb58bbaed85c7b84d3466 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Tue, 12 Dec 2023 10:57:21 +0100 Subject: [PATCH] server: allow reassign speaker range using participant_id --- server/reflector/db/transcripts.py | 17 ++ .../views/transcripts_participants.py | 13 +- server/reflector/views/transcripts_speaker.py | 45 +++- server/tests/test_transcripts_speaker.py | 215 ++++++++++++++++++ 4 files changed, 281 insertions(+), 9 deletions(-) diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py index 6550af05..779d6137 100644 --- a/server/reflector/db/transcripts.py +++ b/server/reflector/db/transcripts.py @@ -248,6 +248,23 @@ class Transcript(BaseModel): url += f"?token={token}" return url + def find_empty_speaker(self) -> int: + """ + Find an empty speaker seat + """ + speakers = set( + word.speaker + for topic in self.topics + for word in topic.words + if word.speaker is not None + ) + i = 0 + while True: + if i not in speakers: + return i + i += 1 + raise Exception("No empty speaker found") + class TranscriptController: async def get_all( diff --git a/server/reflector/views/transcripts_participants.py b/server/reflector/views/transcripts_participants.py index 318d6018..fd08405c 100644 --- a/server/reflector/views/transcripts_participants.py +++ b/server/reflector/views/transcripts_participants.py @@ -59,12 +59,13 @@ async def transcript_add_participant( ) # ensure the speaker is unique - for p in transcript.participants: - if p.speaker == participant.speaker: - raise HTTPException( - status_code=400, - detail="Speaker already assigned", - ) + if participant.speaker is not None: + for p in transcript.participants: + if p.speaker == participant.speaker: + raise HTTPException( + status_code=400, + detail="Speaker already assigned", + ) obj = await transcripts_controller.upsert_participant( transcript, TranscriptParticipant(**participant.dict()) diff --git a/server/reflector/views/transcripts_speaker.py 
b/server/reflector/views/transcripts_speaker.py index 9a4a03bc..0bddad5e 100644 --- a/server/reflector/views/transcripts_speaker.py +++ b/server/reflector/views/transcripts_speaker.py @@ -7,14 +7,15 @@ from typing import Annotated, Optional import reflector.auth as auth from fastapi import APIRouter, Depends, HTTPException -from pydantic import BaseModel +from pydantic import BaseModel, Field from reflector.db.transcripts import transcripts_controller router = APIRouter() class SpeakerAssignment(BaseModel): - speaker: int + speaker: Optional[int] = Field(None, ge=0) + participant: Optional[str] = Field(None) timestamp_from: float timestamp_to: float @@ -42,6 +43,44 @@ async def transcript_assign_speaker( if not transcript: raise HTTPException(status_code=404, detail="Transcript not found") + if assignment.speaker is None and assignment.participant is None: + raise HTTPException( + status_code=400, + detail="Either speaker or participant must be provided", + ) + + if assignment.speaker is not None and assignment.participant is not None: + raise HTTPException( + status_code=400, + detail="Only one of speaker or participant must be provided", + ) + + # if it's a participant, search for it + if assignment.speaker is not None: + speaker = assignment.speaker + + elif assignment.participant is not None: + participant = next( + ( + participant + for participant in transcript.participants + if participant.id == assignment.participant + ), + None, + ) + if not participant: + raise HTTPException( + status_code=404, + detail="Participant not found", + ) + + # if the participant does not have a speaker, create one + if participant.speaker is None: + participant.speaker = transcript.find_empty_speaker() + await transcripts_controller.upsert_participant(transcript, participant) + + speaker = participant.speaker + # reassign speakers from words in the transcript ts_from = assignment.timestamp_from ts_to = assignment.timestamp_to @@ -50,7 +89,7 @@ async def 
transcript_assign_speaker( changed = False for word in topic.words: if ts_from <= word.start <= ts_to: - word.speaker = assignment.speaker + word.speaker = speaker changed = True if changed: changed_topics.append(topic) diff --git a/server/tests/test_transcripts_speaker.py b/server/tests/test_transcripts_speaker.py index 0d98ac66..e3e8034a 100644 --- a/server/tests/test_transcripts_speaker.py +++ b/server/tests/test_transcripts_speaker.py @@ -184,3 +184,218 @@ async def test_transcript_merge_speaker(fake_transcript_with_topics): assert topics[0]["words"][1]["speaker"] == 0 assert topics[1]["words"][0]["speaker"] == 0 assert topics[1]["words"][1]["speaker"] == 0 + + +@pytest.mark.asyncio +async def test_transcript_reassign_with_participant(fake_transcript_with_topics): + from reflector.app import app + + transcript_id = fake_transcript_with_topics.id + + async with AsyncClient(app=app, base_url="http://test/v1") as ac: + # check the transcript exists + response = await ac.get(f"/transcripts/{transcript_id}") + assert response.status_code == 200 + transcript = response.json() + assert len(transcript["participants"]) == 0 + + # create 2 participants + response = await ac.post( + f"/transcripts/{transcript_id}/participants", + json={ + "name": "Participant 1", + }, + ) + assert response.status_code == 200 + participant1_id = response.json()["id"] + + response = await ac.post( + f"/transcripts/{transcript_id}/participants", + json={ + "name": "Participant 2", + }, + ) + assert response.status_code == 200 + participant2_id = response.json()["id"] + + # check participants speakers + response = await ac.get(f"/transcripts/{transcript_id}/participants") + assert response.status_code == 200 + participants = response.json() + assert len(participants) == 2 + assert participants[0]["name"] == "Participant 1" + assert participants[0]["speaker"] is None + assert participants[1]["name"] == "Participant 2" + assert participants[1]["speaker"] is None + + # check initial topics of 
the transcript + response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words") + assert response.status_code == 200 + topics = response.json() + assert len(topics) == 2 + + # check through words + assert topics[0]["words"][0]["speaker"] == 0 + assert topics[0]["words"][1]["speaker"] == 0 + assert topics[1]["words"][0]["speaker"] == 0 + assert topics[1]["words"][1]["speaker"] == 0 + # check through segments + assert len(topics[0]["segments"]) == 1 + assert topics[0]["segments"][0]["speaker"] == 0 + assert len(topics[1]["segments"]) == 1 + assert topics[1]["segments"][0]["speaker"] == 0 + + # reassign speaker from a participant + response = await ac.patch( + f"/transcripts/{transcript_id}/speaker/assign", + json={ + "participant": participant1_id, + "timestamp_from": 0, + "timestamp_to": 1, + }, + ) + assert response.status_code == 200 + + # check participants if speaker has been assigned + # first participant should have 1, because it's not used yet. + response = await ac.get(f"/transcripts/{transcript_id}/participants") + assert response.status_code == 200 + participants = response.json() + assert len(participants) == 2 + assert participants[0]["name"] == "Participant 1" + assert participants[0]["speaker"] == 1 + assert participants[1]["name"] == "Participant 2" + assert participants[1]["speaker"] is None + + # check topics again + response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words") + assert response.status_code == 200 + topics = response.json() + assert len(topics) == 2 + + # check through words + assert topics[0]["words"][0]["speaker"] == 1 + assert topics[0]["words"][1]["speaker"] == 1 + assert topics[1]["words"][0]["speaker"] == 0 + assert topics[1]["words"][1]["speaker"] == 0 + # check segments + assert len(topics[0]["segments"]) == 1 + assert topics[0]["segments"][0]["speaker"] == 1 + assert len(topics[1]["segments"]) == 1 + assert topics[1]["segments"][0]["speaker"] == 0 + + # reassign participant, middle of 2 topics + 
response = await ac.patch( + f"/transcripts/{transcript_id}/speaker/assign", + json={ + "participant": participant2_id, + "timestamp_from": 1, + "timestamp_to": 2.5, + }, + ) + assert response.status_code == 200 + + # check participants if speaker has been assigned + # second participant should have 2, because 0 and 1 are already used. + response = await ac.get(f"/transcripts/{transcript_id}/participants") + assert response.status_code == 200 + participants = response.json() + assert len(participants) == 2 + assert participants[0]["name"] == "Participant 1" + assert participants[0]["speaker"] == 1 + assert participants[1]["name"] == "Participant 2" + assert participants[1]["speaker"] == 2 + + # check topics again + response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words") + assert response.status_code == 200 + topics = response.json() + assert len(topics) == 2 + + # check through words + assert topics[0]["words"][0]["speaker"] == 1 + assert topics[0]["words"][1]["speaker"] == 2 + assert topics[1]["words"][0]["speaker"] == 2 + assert topics[1]["words"][1]["speaker"] == 0 + # check segments + assert len(topics[0]["segments"]) == 2 + assert topics[0]["segments"][0]["speaker"] == 1 + assert topics[0]["segments"][1]["speaker"] == 2 + assert len(topics[1]["segments"]) == 2 + assert topics[1]["segments"][0]["speaker"] == 2 + assert topics[1]["segments"][1]["speaker"] == 0 + + # reassign speaker, everything + response = await ac.patch( + f"/transcripts/{transcript_id}/speaker/assign", + json={ + "participant": participant1_id, + "timestamp_from": 0, + "timestamp_to": 100, + }, + ) + assert response.status_code == 200 + + # check topics again + response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words") + assert response.status_code == 200 + topics = response.json() + assert len(topics) == 2 + + # check through words + assert topics[0]["words"][0]["speaker"] == 1 + assert topics[0]["words"][1]["speaker"] == 1 + assert topics[1]["words"][0]["speaker"] 
== 1 + assert topics[1]["words"][1]["speaker"] == 1 + # check segments + assert len(topics[0]["segments"]) == 1 + assert topics[0]["segments"][0]["speaker"] == 1 + assert len(topics[1]["segments"]) == 1 + assert topics[1]["segments"][0]["speaker"] == 1 + + +@pytest.mark.asyncio +async def test_transcript_reassign_edge_cases(fake_transcript_with_topics): + from reflector.app import app + + transcript_id = fake_transcript_with_topics.id + + async with AsyncClient(app=app, base_url="http://test/v1") as ac: + # check the transcript exists + response = await ac.get(f"/transcripts/{transcript_id}") + assert response.status_code == 200 + transcript = response.json() + assert len(transcript["participants"]) == 0 + + # try reassign without any participant_id or speaker + response = await ac.patch( + f"/transcripts/{transcript_id}/speaker/assign", + json={ + "timestamp_from": 0, + "timestamp_to": 1, + }, + ) + assert response.status_code == 400 + + # try reassign with both participant_id and speaker + response = await ac.patch( + f"/transcripts/{transcript_id}/speaker/assign", + json={ + "participant": "123", + "speaker": 1, + "timestamp_from": 0, + "timestamp_to": 1, + }, + ) + assert response.status_code == 400 + + # try reassign with non-existing participant_id + response = await ac.patch( + f"/transcripts/{transcript_id}/speaker/assign", + json={ + "participant": "123", + "timestamp_from": 0, + "timestamp_to": 1, + }, + ) + assert response.status_code == 404