From d790308ec7f5aa5c20db7c0b6665dec1f92288a9 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Mon, 11 Dec 2023 19:56:24 +0100 Subject: [PATCH] server: add an endpoint to merge speaker --- server/reflector/views/transcripts_speaker.py | 68 ++++++++++++++++++ server/tests/test_transcripts_speaker.py | 69 +++++++++++++++++++ 2 files changed, 137 insertions(+) diff --git a/server/reflector/views/transcripts_speaker.py b/server/reflector/views/transcripts_speaker.py index 20489aa0..9a4a03bc 100644 --- a/server/reflector/views/transcripts_speaker.py +++ b/server/reflector/views/transcripts_speaker.py @@ -23,6 +23,11 @@ class SpeakerAssignmentStatus(BaseModel): status: str +class SpeakerMerge(BaseModel): + speaker_from: int + speaker_to: int + + @router.patch("/transcripts/{transcript_id}/speaker/assign") async def transcript_assign_speaker( transcript_id: str, @@ -61,3 +66,66 @@ async def transcript_assign_speaker( ) return SpeakerAssignmentStatus(status="ok") + + +@router.patch("/transcripts/{transcript_id}/speaker/merge") +async def transcript_merge_speaker( + transcript_id: str, + merge: SpeakerMerge, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +) -> SpeakerAssignmentStatus: + user_id = user["sub"] if user else None + transcript = await transcripts_controller.get_by_id_for_http( + transcript_id, user_id=user_id + ) + + if not transcript: + raise HTTPException(status_code=404, detail="Transcript not found") + + # ensure both speaker are not assigned to the 2 differents participants + participant_from = next( + ( + participant + for participant in transcript.participants + if participant.speaker == merge.speaker_from + ), + None, + ) + participant_to = next( + ( + participant + for participant in transcript.participants + if participant.speaker == merge.speaker_to + ), + None, + ) + if participant_from and participant_to: + raise HTTPException( + status_code=400, + detail="Both speakers are assigned to participants", + ) + + # reassign speakers from words in the transcript + speaker_from = merge.speaker_from + speaker_to = merge.speaker_to + changed_topics = [] + for topic in transcript.topics: + changed = False + for word in topic.words: + if word.speaker == speaker_from: + word.speaker = speaker_to + changed = True + if changed: + changed_topics.append(topic) + + # batch changes + for topic in changed_topics: + transcript.upsert_topic(topic) + await transcripts_controller.update( + transcript, + { + "topics": transcript.topics_dump(), + }, + ) + + return SpeakerAssignmentStatus(status="ok") diff --git a/server/tests/test_transcripts_speaker.py b/server/tests/test_transcripts_speaker.py index 2bca0beb..0d98ac66 100644 --- a/server/tests/test_transcripts_speaker.py +++ b/server/tests/test_transcripts_speaker.py @@ -115,3 +115,72 @@ async def test_transcript_reassign_speaker(fake_transcript_with_topics): assert topics[0]["segments"][0]["speaker"] == 4 assert len(topics[1]["segments"]) == 1 assert topics[1]["segments"][0]["speaker"] == 4 + + +@pytest.mark.asyncio +async def test_transcript_merge_speaker(fake_transcript_with_topics): + from reflector.app import app + + transcript_id = fake_transcript_with_topics.id + + async with AsyncClient(app=app, base_url="http://test/v1") as ac: + # check the transcript exists + response = await ac.get(f"/transcripts/{transcript_id}") + assert response.status_code == 200 + + # check initial topics of the transcript + response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words") + assert response.status_code == 200 + topics = response.json() + assert len(topics) == 2 + + # check through words + assert topics[0]["words"][0]["speaker"] == 0 + assert topics[0]["words"][1]["speaker"] == 0 + assert topics[1]["words"][0]["speaker"] == 0 + assert topics[1]["words"][1]["speaker"] == 0 + + # reassign speaker + response = await ac.patch( + f"/transcripts/{transcript_id}/speaker/assign", + json={ + "speaker": 1, + "timestamp_from": 0, + "timestamp_to": 1, + }, + ) + assert response.status_code == 200 + + # check topics again + response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words") + assert response.status_code == 200 + topics = response.json() + assert len(topics) == 2 + + # check through words + assert topics[0]["words"][0]["speaker"] == 1 + assert topics[0]["words"][1]["speaker"] == 1 + assert topics[1]["words"][0]["speaker"] == 0 + assert topics[1]["words"][1]["speaker"] == 0 + + # merge speakers + response = await ac.patch( + f"/transcripts/{transcript_id}/speaker/merge", + json={ + "speaker_from": 1, + "speaker_to": 0, + }, + ) + assert response.status_code == 200 + + # check topics again + response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words") + assert response.status_code == 200 + topics = response.json() + assert len(topics) == 2 + + # check through words + assert topics[0]["words"][0]["speaker"] == 0 + assert topics[0]["words"][1]["speaker"] == 0 + assert topics[1]["words"][0]["speaker"] == 0 + assert topics[1]["words"][1]["speaker"] == 0