mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
server: allow reassign speaker range using participant_id
This commit is contained in:
@@ -248,6 +248,23 @@ class Transcript(BaseModel):
|
|||||||
url += f"?token={token}"
|
url += f"?token={token}"
|
||||||
return url
|
return url
|
||||||
|
|
||||||
|
def find_empty_speaker(self) -> int:
|
||||||
|
"""
|
||||||
|
Find an empty speaker seat
|
||||||
|
"""
|
||||||
|
speakers = set(
|
||||||
|
word.speaker
|
||||||
|
for topic in self.topics
|
||||||
|
for word in topic.words
|
||||||
|
if word.speaker is not None
|
||||||
|
)
|
||||||
|
i = 0
|
||||||
|
while True:
|
||||||
|
if i not in speakers:
|
||||||
|
return i
|
||||||
|
i += 1
|
||||||
|
raise Exception("No empty speaker found")
|
||||||
|
|
||||||
|
|
||||||
class TranscriptController:
|
class TranscriptController:
|
||||||
async def get_all(
|
async def get_all(
|
||||||
|
|||||||
@@ -59,12 +59,13 @@ async def transcript_add_participant(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# ensure the speaker is unique
|
# ensure the speaker is unique
|
||||||
for p in transcript.participants:
|
if participant.speaker is not None:
|
||||||
if p.speaker == participant.speaker:
|
for p in transcript.participants:
|
||||||
raise HTTPException(
|
if p.speaker == participant.speaker:
|
||||||
status_code=400,
|
raise HTTPException(
|
||||||
detail="Speaker already assigned",
|
status_code=400,
|
||||||
)
|
detail="Speaker already assigned",
|
||||||
|
)
|
||||||
|
|
||||||
obj = await transcripts_controller.upsert_participant(
|
obj = await transcripts_controller.upsert_participant(
|
||||||
transcript, TranscriptParticipant(**participant.dict())
|
transcript, TranscriptParticipant(**participant.dict())
|
||||||
|
|||||||
@@ -7,14 +7,15 @@ from typing import Annotated, Optional
|
|||||||
|
|
||||||
import reflector.auth as auth
|
import reflector.auth as auth
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
from reflector.db.transcripts import transcripts_controller
|
from reflector.db.transcripts import transcripts_controller
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
class SpeakerAssignment(BaseModel):
|
class SpeakerAssignment(BaseModel):
|
||||||
speaker: int
|
speaker: Optional[int] = Field(None, ge=0)
|
||||||
|
participant: Optional[str] = Field(None)
|
||||||
timestamp_from: float
|
timestamp_from: float
|
||||||
timestamp_to: float
|
timestamp_to: float
|
||||||
|
|
||||||
@@ -42,6 +43,44 @@ async def transcript_assign_speaker(
|
|||||||
if not transcript:
|
if not transcript:
|
||||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||||
|
|
||||||
|
if assignment.speaker is None and assignment.participant is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Either speaker or participant must be provided",
|
||||||
|
)
|
||||||
|
|
||||||
|
if assignment.speaker is not None and assignment.participant is not None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Only one of speaker or participant must be provided",
|
||||||
|
)
|
||||||
|
|
||||||
|
# if it's a participant, search for it
|
||||||
|
if assignment.speaker is not None:
|
||||||
|
speaker = assignment.speaker
|
||||||
|
|
||||||
|
elif assignment.participant is not None:
|
||||||
|
participant = next(
|
||||||
|
(
|
||||||
|
participant
|
||||||
|
for participant in transcript.participants
|
||||||
|
if participant.id == assignment.participant
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if not participant:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail="Participant not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
# if the participant does not have a speaker, create one
|
||||||
|
if participant.speaker is None:
|
||||||
|
participant.speaker = transcript.find_empty_speaker()
|
||||||
|
await transcripts_controller.upsert_participant(transcript, participant)
|
||||||
|
|
||||||
|
speaker = participant.speaker
|
||||||
|
|
||||||
# reassign speakers from words in the transcript
|
# reassign speakers from words in the transcript
|
||||||
ts_from = assignment.timestamp_from
|
ts_from = assignment.timestamp_from
|
||||||
ts_to = assignment.timestamp_to
|
ts_to = assignment.timestamp_to
|
||||||
@@ -50,7 +89,7 @@ async def transcript_assign_speaker(
|
|||||||
changed = False
|
changed = False
|
||||||
for word in topic.words:
|
for word in topic.words:
|
||||||
if ts_from <= word.start <= ts_to:
|
if ts_from <= word.start <= ts_to:
|
||||||
word.speaker = assignment.speaker
|
word.speaker = speaker
|
||||||
changed = True
|
changed = True
|
||||||
if changed:
|
if changed:
|
||||||
changed_topics.append(topic)
|
changed_topics.append(topic)
|
||||||
|
|||||||
@@ -184,3 +184,218 @@ async def test_transcript_merge_speaker(fake_transcript_with_topics):
|
|||||||
assert topics[0]["words"][1]["speaker"] == 0
|
assert topics[0]["words"][1]["speaker"] == 0
|
||||||
assert topics[1]["words"][0]["speaker"] == 0
|
assert topics[1]["words"][0]["speaker"] == 0
|
||||||
assert topics[1]["words"][1]["speaker"] == 0
|
assert topics[1]["words"][1]["speaker"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_transcript_reassign_with_participant(fake_transcript_with_topics):
|
||||||
|
from reflector.app import app
|
||||||
|
|
||||||
|
transcript_id = fake_transcript_with_topics.id
|
||||||
|
|
||||||
|
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
||||||
|
# check the transcript exists
|
||||||
|
response = await ac.get(f"/transcripts/{transcript_id}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
transcript = response.json()
|
||||||
|
assert len(transcript["participants"]) == 0
|
||||||
|
|
||||||
|
# create 2 participants
|
||||||
|
response = await ac.post(
|
||||||
|
f"/transcripts/{transcript_id}/participants",
|
||||||
|
json={
|
||||||
|
"name": "Participant 1",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
participant1_id = response.json()["id"]
|
||||||
|
|
||||||
|
response = await ac.post(
|
||||||
|
f"/transcripts/{transcript_id}/participants",
|
||||||
|
json={
|
||||||
|
"name": "Participant 2",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
participant2_id = response.json()["id"]
|
||||||
|
|
||||||
|
# check participants speakers
|
||||||
|
response = await ac.get(f"/transcripts/{transcript_id}/participants")
|
||||||
|
assert response.status_code == 200
|
||||||
|
participants = response.json()
|
||||||
|
assert len(participants) == 2
|
||||||
|
assert participants[0]["name"] == "Participant 1"
|
||||||
|
assert participants[0]["speaker"] is None
|
||||||
|
assert participants[1]["name"] == "Participant 2"
|
||||||
|
assert participants[1]["speaker"] is None
|
||||||
|
|
||||||
|
# check initial topics of the transcript
|
||||||
|
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
|
||||||
|
assert response.status_code == 200
|
||||||
|
topics = response.json()
|
||||||
|
assert len(topics) == 2
|
||||||
|
|
||||||
|
# check through words
|
||||||
|
assert topics[0]["words"][0]["speaker"] == 0
|
||||||
|
assert topics[0]["words"][1]["speaker"] == 0
|
||||||
|
assert topics[1]["words"][0]["speaker"] == 0
|
||||||
|
assert topics[1]["words"][1]["speaker"] == 0
|
||||||
|
# check through segments
|
||||||
|
assert len(topics[0]["segments"]) == 1
|
||||||
|
assert topics[0]["segments"][0]["speaker"] == 0
|
||||||
|
assert len(topics[1]["segments"]) == 1
|
||||||
|
assert topics[1]["segments"][0]["speaker"] == 0
|
||||||
|
|
||||||
|
# reassign speaker from a participant
|
||||||
|
response = await ac.patch(
|
||||||
|
f"/transcripts/{transcript_id}/speaker/assign",
|
||||||
|
json={
|
||||||
|
"participant": participant1_id,
|
||||||
|
"timestamp_from": 0,
|
||||||
|
"timestamp_to": 1,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# check participants if speaker has been assigned
|
||||||
|
# first participant should have 1, because it's not used yet.
|
||||||
|
response = await ac.get(f"/transcripts/{transcript_id}/participants")
|
||||||
|
assert response.status_code == 200
|
||||||
|
participants = response.json()
|
||||||
|
assert len(participants) == 2
|
||||||
|
assert participants[0]["name"] == "Participant 1"
|
||||||
|
assert participants[0]["speaker"] == 1
|
||||||
|
assert participants[1]["name"] == "Participant 2"
|
||||||
|
assert participants[1]["speaker"] is None
|
||||||
|
|
||||||
|
# check topics again
|
||||||
|
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
|
||||||
|
assert response.status_code == 200
|
||||||
|
topics = response.json()
|
||||||
|
assert len(topics) == 2
|
||||||
|
|
||||||
|
# check through words
|
||||||
|
assert topics[0]["words"][0]["speaker"] == 1
|
||||||
|
assert topics[0]["words"][1]["speaker"] == 1
|
||||||
|
assert topics[1]["words"][0]["speaker"] == 0
|
||||||
|
assert topics[1]["words"][1]["speaker"] == 0
|
||||||
|
# check segments
|
||||||
|
assert len(topics[0]["segments"]) == 1
|
||||||
|
assert topics[0]["segments"][0]["speaker"] == 1
|
||||||
|
assert len(topics[1]["segments"]) == 1
|
||||||
|
assert topics[1]["segments"][0]["speaker"] == 0
|
||||||
|
|
||||||
|
# reassign participant, middle of 2 topics
|
||||||
|
response = await ac.patch(
|
||||||
|
f"/transcripts/{transcript_id}/speaker/assign",
|
||||||
|
json={
|
||||||
|
"participant": participant2_id,
|
||||||
|
"timestamp_from": 1,
|
||||||
|
"timestamp_to": 2.5,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# check participants if speaker has been assigned
|
||||||
|
# first participant should have 1, because it's not used yet.
|
||||||
|
response = await ac.get(f"/transcripts/{transcript_id}/participants")
|
||||||
|
assert response.status_code == 200
|
||||||
|
participants = response.json()
|
||||||
|
assert len(participants) == 2
|
||||||
|
assert participants[0]["name"] == "Participant 1"
|
||||||
|
assert participants[0]["speaker"] == 1
|
||||||
|
assert participants[1]["name"] == "Participant 2"
|
||||||
|
assert participants[1]["speaker"] == 2
|
||||||
|
|
||||||
|
# check topics again
|
||||||
|
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
|
||||||
|
assert response.status_code == 200
|
||||||
|
topics = response.json()
|
||||||
|
assert len(topics) == 2
|
||||||
|
|
||||||
|
# check through words
|
||||||
|
assert topics[0]["words"][0]["speaker"] == 1
|
||||||
|
assert topics[0]["words"][1]["speaker"] == 2
|
||||||
|
assert topics[1]["words"][0]["speaker"] == 2
|
||||||
|
assert topics[1]["words"][1]["speaker"] == 0
|
||||||
|
# check segments
|
||||||
|
assert len(topics[0]["segments"]) == 2
|
||||||
|
assert topics[0]["segments"][0]["speaker"] == 1
|
||||||
|
assert topics[0]["segments"][1]["speaker"] == 2
|
||||||
|
assert len(topics[1]["segments"]) == 2
|
||||||
|
assert topics[1]["segments"][0]["speaker"] == 2
|
||||||
|
assert topics[1]["segments"][1]["speaker"] == 0
|
||||||
|
|
||||||
|
# reassign speaker, everything
|
||||||
|
response = await ac.patch(
|
||||||
|
f"/transcripts/{transcript_id}/speaker/assign",
|
||||||
|
json={
|
||||||
|
"participant": participant1_id,
|
||||||
|
"timestamp_from": 0,
|
||||||
|
"timestamp_to": 100,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# check topics again
|
||||||
|
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
|
||||||
|
assert response.status_code == 200
|
||||||
|
topics = response.json()
|
||||||
|
assert len(topics) == 2
|
||||||
|
|
||||||
|
# check through words
|
||||||
|
assert topics[0]["words"][0]["speaker"] == 1
|
||||||
|
assert topics[0]["words"][1]["speaker"] == 1
|
||||||
|
assert topics[1]["words"][0]["speaker"] == 1
|
||||||
|
assert topics[1]["words"][1]["speaker"] == 1
|
||||||
|
# check segments
|
||||||
|
assert len(topics[0]["segments"]) == 1
|
||||||
|
assert topics[0]["segments"][0]["speaker"] == 1
|
||||||
|
assert len(topics[1]["segments"]) == 1
|
||||||
|
assert topics[1]["segments"][0]["speaker"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_transcript_reassign_edge_cases(fake_transcript_with_topics):
|
||||||
|
from reflector.app import app
|
||||||
|
|
||||||
|
transcript_id = fake_transcript_with_topics.id
|
||||||
|
|
||||||
|
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
||||||
|
# check the transcript exists
|
||||||
|
response = await ac.get(f"/transcripts/{transcript_id}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
transcript = response.json()
|
||||||
|
assert len(transcript["participants"]) == 0
|
||||||
|
|
||||||
|
# try reassign without any participant_id or speaker
|
||||||
|
response = await ac.patch(
|
||||||
|
f"/transcripts/{transcript_id}/speaker/assign",
|
||||||
|
json={
|
||||||
|
"timestamp_from": 0,
|
||||||
|
"timestamp_to": 1,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 400
|
||||||
|
|
||||||
|
# try reassing with both participant_id and speaker
|
||||||
|
response = await ac.patch(
|
||||||
|
f"/transcripts/{transcript_id}/speaker/assign",
|
||||||
|
json={
|
||||||
|
"participant": "123",
|
||||||
|
"speaker": 1,
|
||||||
|
"timestamp_from": 0,
|
||||||
|
"timestamp_to": 1,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 400
|
||||||
|
|
||||||
|
# try reassing with non-existing participant_id
|
||||||
|
response = await ac.patch(
|
||||||
|
f"/transcripts/{transcript_id}/speaker/assign",
|
||||||
|
json={
|
||||||
|
"participant": "123",
|
||||||
|
"timestamp_from": 0,
|
||||||
|
"timestamp_to": 1,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 404
|
||||||
|
|||||||
Reference in New Issue
Block a user