mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 12:19:06 +00:00
server: add an endpoint to merge speaker
This commit is contained in:
@@ -23,6 +23,11 @@ class SpeakerAssignmentStatus(BaseModel):
|
|||||||
status: str
|
status: str
|
||||||
|
|
||||||
|
|
||||||
|
class SpeakerMerge(BaseModel):
|
||||||
|
speaker_from: int
|
||||||
|
speaker_to: int
|
||||||
|
|
||||||
|
|
||||||
@router.patch("/transcripts/{transcript_id}/speaker/assign")
|
@router.patch("/transcripts/{transcript_id}/speaker/assign")
|
||||||
async def transcript_assign_speaker(
|
async def transcript_assign_speaker(
|
||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
@@ -61,3 +66,66 @@ async def transcript_assign_speaker(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return SpeakerAssignmentStatus(status="ok")
|
return SpeakerAssignmentStatus(status="ok")
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch("/transcripts/{transcript_id}/speaker/merge")
|
||||||
|
async def transcript_merge_speaker(
|
||||||
|
transcript_id: str,
|
||||||
|
merge: SpeakerMerge,
|
||||||
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
|
) -> SpeakerAssignmentStatus:
|
||||||
|
user_id = user["sub"] if user else None
|
||||||
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
|
transcript_id, user_id=user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if not transcript:
|
||||||
|
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||||
|
|
||||||
|
# ensure both speaker are not assigned to the 2 differents participants
|
||||||
|
participant_from = next(
|
||||||
|
(
|
||||||
|
participant
|
||||||
|
for participant in transcript.participants
|
||||||
|
if participant.speaker == merge.speaker_from
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
participant_to = next(
|
||||||
|
(
|
||||||
|
participant
|
||||||
|
for participant in transcript.participants
|
||||||
|
if participant.speaker == merge.speaker_to
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if participant_from and participant_to:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Both speakers are assigned to participants",
|
||||||
|
)
|
||||||
|
|
||||||
|
# reassign speakers from words in the transcript
|
||||||
|
speaker_from = merge.speaker_from
|
||||||
|
speaker_to = merge.speaker_to
|
||||||
|
changed_topics = []
|
||||||
|
for topic in transcript.topics:
|
||||||
|
changed = False
|
||||||
|
for word in topic.words:
|
||||||
|
if word.speaker == speaker_from:
|
||||||
|
word.speaker = speaker_to
|
||||||
|
changed = True
|
||||||
|
if changed:
|
||||||
|
changed_topics.append(topic)
|
||||||
|
|
||||||
|
# batch changes
|
||||||
|
for topic in changed_topics:
|
||||||
|
transcript.upsert_topic(topic)
|
||||||
|
await transcripts_controller.update(
|
||||||
|
transcript,
|
||||||
|
{
|
||||||
|
"topics": transcript.topics_dump(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return SpeakerAssignmentStatus(status="ok")
|
||||||
|
|||||||
@@ -115,3 +115,72 @@ async def test_transcript_reassign_speaker(fake_transcript_with_topics):
|
|||||||
assert topics[0]["segments"][0]["speaker"] == 4
|
assert topics[0]["segments"][0]["speaker"] == 4
|
||||||
assert len(topics[1]["segments"]) == 1
|
assert len(topics[1]["segments"]) == 1
|
||||||
assert topics[1]["segments"][0]["speaker"] == 4
|
assert topics[1]["segments"][0]["speaker"] == 4
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_transcript_merge_speaker(fake_transcript_with_topics):
|
||||||
|
from reflector.app import app
|
||||||
|
|
||||||
|
transcript_id = fake_transcript_with_topics.id
|
||||||
|
|
||||||
|
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
||||||
|
# check the transcript exists
|
||||||
|
response = await ac.get(f"/transcripts/{transcript_id}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# check initial topics of the transcript
|
||||||
|
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
|
||||||
|
assert response.status_code == 200
|
||||||
|
topics = response.json()
|
||||||
|
assert len(topics) == 2
|
||||||
|
|
||||||
|
# check through words
|
||||||
|
assert topics[0]["words"][0]["speaker"] == 0
|
||||||
|
assert topics[0]["words"][1]["speaker"] == 0
|
||||||
|
assert topics[1]["words"][0]["speaker"] == 0
|
||||||
|
assert topics[1]["words"][1]["speaker"] == 0
|
||||||
|
|
||||||
|
# reassign speaker
|
||||||
|
response = await ac.patch(
|
||||||
|
f"/transcripts/{transcript_id}/speaker/assign",
|
||||||
|
json={
|
||||||
|
"speaker": 1,
|
||||||
|
"timestamp_from": 0,
|
||||||
|
"timestamp_to": 1,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# check topics again
|
||||||
|
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
|
||||||
|
assert response.status_code == 200
|
||||||
|
topics = response.json()
|
||||||
|
assert len(topics) == 2
|
||||||
|
|
||||||
|
# check through words
|
||||||
|
assert topics[0]["words"][0]["speaker"] == 1
|
||||||
|
assert topics[0]["words"][1]["speaker"] == 1
|
||||||
|
assert topics[1]["words"][0]["speaker"] == 0
|
||||||
|
assert topics[1]["words"][1]["speaker"] == 0
|
||||||
|
|
||||||
|
# merge speakers
|
||||||
|
response = await ac.patch(
|
||||||
|
f"/transcripts/{transcript_id}/speaker/merge",
|
||||||
|
json={
|
||||||
|
"speaker_from": 1,
|
||||||
|
"speaker_to": 0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# check topics again
|
||||||
|
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
|
||||||
|
assert response.status_code == 200
|
||||||
|
topics = response.json()
|
||||||
|
assert len(topics) == 2
|
||||||
|
|
||||||
|
# check through words
|
||||||
|
assert topics[0]["words"][0]["speaker"] == 0
|
||||||
|
assert topics[0]["words"][1]["speaker"] == 0
|
||||||
|
assert topics[1]["words"][0]["speaker"] == 0
|
||||||
|
assert topics[1]["words"][1]["speaker"] == 0
|
||||||
|
|||||||
Reference in New Issue
Block a user