server: allow reassign speaker range using participant_id

This commit is contained in:
2023-12-12 10:57:21 +01:00
parent d790308ec7
commit 37b11fdcb8
4 changed files with 281 additions and 9 deletions

View File

@@ -248,6 +248,23 @@ class Transcript(BaseModel):
url += f"?token={token}" url += f"?token={token}"
return url return url
def find_empty_speaker(self) -> int:
    """
    Return the lowest non-negative speaker index that is still free.

    A speaker index counts as taken when it is carried by any word of any
    topic in ``self.topics`` or already held by one of
    ``self.participants``.  Entries whose ``speaker`` is ``None`` are
    ignored.  The result is ``0`` when nothing is taken.
    """
    taken = {
        word.speaker
        for topic in self.topics
        for word in topic.words
        if word.speaker is not None
    }
    # A participant may own a speaker index that no word uses yet (for
    # example after being assigned over a range that contained no words).
    # Handing that index out again would give two participants the same
    # speaker, so participant-held indexes are reserved as well.
    taken.update(
        participant.speaker
        for participant in (self.participants or ())
        if participant.speaker is not None
    )
    # ``taken`` is finite, so the scan always ends on a free index; no
    # fallback error path is needed after the loop.
    candidate = 0
    while candidate in taken:
        candidate += 1
    return candidate
class TranscriptController: class TranscriptController:
async def get_all( async def get_all(

View File

@@ -59,12 +59,13 @@ async def transcript_add_participant(
) )
# ensure the speaker is unique # ensure the speaker is unique
for p in transcript.participants: if participant.speaker is not None:
if p.speaker == participant.speaker: for p in transcript.participants:
raise HTTPException( if p.speaker == participant.speaker:
status_code=400, raise HTTPException(
detail="Speaker already assigned", status_code=400,
) detail="Speaker already assigned",
)
obj = await transcripts_controller.upsert_participant( obj = await transcripts_controller.upsert_participant(
transcript, TranscriptParticipant(**participant.dict()) transcript, TranscriptParticipant(**participant.dict())

View File

@@ -7,14 +7,15 @@ from typing import Annotated, Optional
import reflector.auth as auth import reflector.auth as auth
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel from pydantic import BaseModel, Field
from reflector.db.transcripts import transcripts_controller from reflector.db.transcripts import transcripts_controller
router = APIRouter() router = APIRouter()
class SpeakerAssignment(BaseModel): class SpeakerAssignment(BaseModel):
speaker: int speaker: Optional[int] = Field(None, ge=0)
participant: Optional[str] = Field(None)
timestamp_from: float timestamp_from: float
timestamp_to: float timestamp_to: float
@@ -42,6 +43,44 @@ async def transcript_assign_speaker(
if not transcript: if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found") raise HTTPException(status_code=404, detail="Transcript not found")
if assignment.speaker is None and assignment.participant is None:
raise HTTPException(
status_code=400,
detail="Either speaker or participant must be provided",
)
if assignment.speaker is not None and assignment.participant is not None:
raise HTTPException(
status_code=400,
detail="Only one of speaker or participant must be provided",
)
# if it's a participant, search for it
if assignment.speaker is not None:
speaker = assignment.speaker
elif assignment.participant is not None:
participant = next(
(
participant
for participant in transcript.participants
if participant.id == assignment.participant
),
None,
)
if not participant:
raise HTTPException(
status_code=404,
detail="Participant not found",
)
# if the participant does not have a speaker, create one
if participant.speaker is None:
participant.speaker = transcript.find_empty_speaker()
await transcripts_controller.upsert_participant(transcript, participant)
speaker = participant.speaker
# reassign speakers from words in the transcript # reassign speakers from words in the transcript
ts_from = assignment.timestamp_from ts_from = assignment.timestamp_from
ts_to = assignment.timestamp_to ts_to = assignment.timestamp_to
@@ -50,7 +89,7 @@ async def transcript_assign_speaker(
changed = False changed = False
for word in topic.words: for word in topic.words:
if ts_from <= word.start <= ts_to: if ts_from <= word.start <= ts_to:
word.speaker = assignment.speaker word.speaker = speaker
changed = True changed = True
if changed: if changed:
changed_topics.append(topic) changed_topics.append(topic)

View File

@@ -184,3 +184,218 @@ async def test_transcript_merge_speaker(fake_transcript_with_topics):
assert topics[0]["words"][1]["speaker"] == 0 assert topics[0]["words"][1]["speaker"] == 0
assert topics[1]["words"][0]["speaker"] == 0 assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0 assert topics[1]["words"][1]["speaker"] == 0
@pytest.mark.asyncio
async def test_transcript_reassign_with_participant(fake_transcript_with_topics):
    """
    Assigning a time range to a participant that has no speaker yet must
    give that participant a speaker index and relabel the words (and the
    derived segments) falling inside the range.
    """
    from reflector.app import app

    base = f"/transcripts/{fake_transcript_with_topics.id}"

    async with AsyncClient(app=app, base_url="http://test/v1") as ac:

        async def fetch(path):
            # GET `path`, require a 200 and hand back the decoded body
            reply = await ac.get(path)
            assert reply.status_code == 200
            return reply.json()

        async def add_participant(name):
            # create a participant without a speaker and return its id
            reply = await ac.post(f"{base}/participants", json={"name": name})
            assert reply.status_code == 200
            return reply.json()["id"]

        async def assign(participant_id, ts_from, ts_to):
            # reassign the given time range to the participant
            reply = await ac.patch(
                f"{base}/speaker/assign",
                json={
                    "participant": participant_id,
                    "timestamp_from": ts_from,
                    "timestamp_to": ts_to,
                },
            )
            assert reply.status_code == 200

        async def expect_participants(first_speaker, second_speaker):
            # both participants are listed, in creation order
            listed = await fetch(f"{base}/participants")
            assert [(p["name"], p["speaker"]) for p in listed] == [
                ("Participant 1", first_speaker),
                ("Participant 2", second_speaker),
            ]

        async def expect_topics(word_speakers, segment_speakers):
            # two topics; compare the first two words and all segments of each
            topics = await fetch(f"{base}/topics/with-words")
            assert len(topics) == 2
            assert [
                [topic["words"][i]["speaker"] for i in (0, 1)] for topic in topics
            ] == word_speakers
            assert [
                [segment["speaker"] for segment in topic["segments"]]
                for topic in topics
            ] == segment_speakers

        # the transcript exists and starts without participants
        transcript = await fetch(base)
        assert len(transcript["participants"]) == 0

        # create 2 participants, neither has a speaker yet
        participant1_id = await add_participant("Participant 1")
        participant2_id = await add_participant("Participant 2")
        await expect_participants(None, None)

        # initially every word and segment belongs to speaker 0
        await expect_topics([[0, 0], [0, 0]], [[0], [0]])

        # first participant takes the first topic; it receives speaker 1,
        # the first index not used by any word
        await assign(participant1_id, 0, 1)
        await expect_participants(1, None)
        await expect_topics([[1, 1], [0, 0]], [[1], [0]])

        # second participant takes a range straddling both topics; it
        # receives speaker 2 since 0 and 1 are both in use
        await assign(participant2_id, 1, 2.5)
        await expect_participants(1, 2)
        await expect_topics([[1, 2], [2, 0]], [[1, 2], [2, 0]])

        # first participant takes everything; it keeps speaker 1
        await assign(participant1_id, 0, 100)
        await expect_topics([[1, 1], [1, 1]], [[1], [1]])
@pytest.mark.asyncio
async def test_transcript_reassign_edge_cases(fake_transcript_with_topics):
    """
    The speaker assignment endpoint must reject requests that name neither
    or both of speaker/participant, and requests naming an unknown
    participant.
    """
    from reflector.app import app

    base = f"/transcripts/{fake_transcript_with_topics.id}"

    # (request body, expected status), sent in this order
    rejected = [
        # neither participant nor speaker
        ({"timestamp_from": 0, "timestamp_to": 1}, 400),
        # both participant and speaker
        (
            {
                "participant": "123",
                "speaker": 1,
                "timestamp_from": 0,
                "timestamp_to": 1,
            },
            400,
        ),
        # participant id that does not exist
        ({"participant": "123", "timestamp_from": 0, "timestamp_to": 1}, 404),
    ]

    async with AsyncClient(app=app, base_url="http://test/v1") as ac:
        # the transcript exists and has no participants
        found = await ac.get(base)
        assert found.status_code == 200
        assert len(found.json()["participants"]) == 0

        for payload, expected_status in rejected:
            reply = await ac.patch(f"{base}/speaker/assign", json=payload)
            assert reply.status_code == expected_status