server: add an endpoint to merge speakers

This commit is contained in:
2023-12-11 19:56:24 +01:00
parent 07b29d42a7
commit d790308ec7
2 changed files with 137 additions and 0 deletions

View File

@@ -23,6 +23,11 @@ class SpeakerAssignmentStatus(BaseModel):
status: str status: str
# Request body for the speaker merge endpoint: names the speaker whose words
# are relabelled and the speaker they are relabelled to.
class SpeakerMerge(BaseModel):
    # Speaker number currently carried by the words to relabel.
    speaker_from: int
    # Speaker number those words receive.
    speaker_to: int
@router.patch("/transcripts/{transcript_id}/speaker/assign") @router.patch("/transcripts/{transcript_id}/speaker/assign")
async def transcript_assign_speaker( async def transcript_assign_speaker(
transcript_id: str, transcript_id: str,
@@ -61,3 +66,66 @@ async def transcript_assign_speaker(
) )
return SpeakerAssignmentStatus(status="ok") return SpeakerAssignmentStatus(status="ok")
@router.patch("/transcripts/{transcript_id}/speaker/merge")
async def transcript_merge_speaker(
    transcript_id: str,
    merge: SpeakerMerge,
    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> SpeakerAssignmentStatus:
    """
    Merge one speaker into another across all words of a transcript.

    Every word labelled `speaker_from` is relabelled `speaker_to`. Responds
    with 404 when the transcript is not found, and with 400 when both
    speakers are assigned to participants. Merging a speaker into itself is
    accepted and changes nothing.
    """
    user_id = user["sub"] if user else None
    transcript = await transcripts_controller.get_by_id_for_http(
        transcript_id, user_id=user_id
    )
    if not transcript:
        raise HTTPException(status_code=404, detail="Transcript not found")

    speaker_from = merge.speaker_from
    speaker_to = merge.speaker_to

    # Merging a speaker into itself is a no-op. Without this guard the
    # participant check below matches the same participant for both sides and
    # answers 400 with a misleading message, or the topics are rewritten and
    # saved although no word actually changes.
    if speaker_from == speaker_to:
        return SpeakerAssignmentStatus(status="ok")

    # Ensure the two speakers are not assigned to two different participants.
    # NOTE(review): `or ()` assumes participants may be None/empty on some
    # transcripts - confirm against the transcript model.
    participants = transcript.participants or ()
    from_assigned = any(
        participant.speaker == speaker_from for participant in participants
    )
    to_assigned = any(
        participant.speaker == speaker_to for participant in participants
    )
    if from_assigned and to_assigned:
        raise HTTPException(
            status_code=400,
            detail="Both speakers are assigned to participants",
        )

    # NOTE(review): a participant assigned to `speaker_from` keeps that speaker
    # number after the merge; confirm whether it should follow the words to
    # `speaker_to`.

    # Reassign speakers on the words of the transcript, remembering which
    # topics were actually touched.
    changed_topics = []
    for topic in transcript.topics:
        changed = False
        for word in topic.words:
            if word.speaker == speaker_from:
                word.speaker = speaker_to
                changed = True
        if changed:
            changed_topics.append(topic)

    # Batch changes: upsert every touched topic, then save once.
    for topic in changed_topics:
        transcript.upsert_topic(topic)

    await transcripts_controller.update(
        transcript,
        {
            "topics": transcript.topics_dump(),
        },
    )

    return SpeakerAssignmentStatus(status="ok")

View File

@@ -115,3 +115,72 @@ async def test_transcript_reassign_speaker(fake_transcript_with_topics):
assert topics[0]["segments"][0]["speaker"] == 4 assert topics[0]["segments"][0]["speaker"] == 4
assert len(topics[1]["segments"]) == 1 assert len(topics[1]["segments"]) == 1
assert topics[1]["segments"][0]["speaker"] == 4 assert topics[1]["segments"][0]["speaker"] == 4
@pytest.mark.asyncio
async def test_transcript_merge_speaker(fake_transcript_with_topics):
    """Assign speaker 1 to some words, merge 1 into 0, check words each step."""
    from reflector.app import app

    transcript_id = fake_transcript_with_topics.id
    words_url = f"/transcripts/{transcript_id}/topics/with-words"

    async with AsyncClient(app=app, base_url="http://test/v1") as ac:

        async def leading_word_speakers():
            # Fetch the topics (expecting exactly two) and return the speaker
            # of the first two words of each, in topic order.
            resp = await ac.get(words_url)
            assert resp.status_code == 200
            payload = resp.json()
            assert len(payload) == 2
            return [
                payload[topic_idx]["words"][word_idx]["speaker"]
                for topic_idx in (0, 1)
                for word_idx in (0, 1)
            ]

        # the transcript must be reachable
        resp = await ac.get(f"/transcripts/{transcript_id}")
        assert resp.status_code == 200

        # initially every checked word is speaker 0
        assert await leading_word_speakers() == [0, 0, 0, 0]

        # assign speaker 1 over timestamp_from=0 .. timestamp_to=1
        resp = await ac.patch(
            f"/transcripts/{transcript_id}/speaker/assign",
            json={
                "speaker": 1,
                "timestamp_from": 0,
                "timestamp_to": 1,
            },
        )
        assert resp.status_code == 200

        # only the first topic's words moved to speaker 1
        assert await leading_word_speakers() == [1, 1, 0, 0]

        # merge speaker 1 back into speaker 0
        resp = await ac.patch(
            f"/transcripts/{transcript_id}/speaker/merge",
            json={
                "speaker_from": 1,
                "speaker_to": 0,
            },
        )
        assert resp.status_code == 200

        # every checked word is speaker 0 again
        assert await leading_word_speakers() == [0, 0, 0, 0]