server: add API to reassign speakers, and get topics with words

This commit is contained in:
2023-12-06 16:41:18 +01:00
parent 84a1350df7
commit 6f3d7df507
6 changed files with 306 additions and 8 deletions

View File

@@ -17,6 +17,7 @@ from reflector.views.transcripts_audio import router as transcripts_audio_router
from reflector.views.transcripts_participants import ( from reflector.views.transcripts_participants import (
router as transcripts_participants_router, router as transcripts_participants_router,
) )
from reflector.views.transcripts_speaker import router as transcripts_speaker_router
from reflector.views.transcripts_webrtc import router as transcripts_webrtc_router from reflector.views.transcripts_webrtc import router as transcripts_webrtc_router
from reflector.views.transcripts_websocket import router as transcripts_websocket_router from reflector.views.transcripts_websocket import router as transcripts_websocket_router
from reflector.views.user import router as user_router from reflector.views.user import router as user_router
@@ -68,6 +69,7 @@ app.include_router(rtc_offer_router)
app.include_router(transcripts_router, prefix="/v1") app.include_router(transcripts_router, prefix="/v1")
app.include_router(transcripts_audio_router, prefix="/v1") app.include_router(transcripts_audio_router, prefix="/v1")
app.include_router(transcripts_participants_router, prefix="/v1") app.include_router(transcripts_participants_router, prefix="/v1")
app.include_router(transcripts_speaker_router, prefix="/v1")
app.include_router(transcripts_websocket_router, prefix="/v1") app.include_router(transcripts_websocket_router, prefix="/v1")
app.include_router(transcripts_webrtc_router, prefix="/v1") app.include_router(transcripts_webrtc_router, prefix="/v1")
app.include_router(user_router, prefix="/v1") app.include_router(user_router, prefix="/v1")

View File

@@ -362,7 +362,7 @@ class TranscriptController:
await database.execute(query) await database.execute(query)
return transcript return transcript
async def update(self, transcript: Transcript, values: dict): async def update(self, transcript: Transcript, values: dict, mutate=True):
""" """
Update a transcript fields with key/values in values Update a transcript fields with key/values in values
""" """
@@ -372,6 +372,7 @@ class TranscriptController:
.values(**values) .values(**values)
) )
await database.execute(query) await database.execute(query)
if mutate:
for key, value in values.items(): for key, value in values.items():
setattr(transcript, key, value) setattr(transcript, key, value)
@@ -410,7 +411,11 @@ class TranscriptController:
Append an event to a transcript Append an event to a transcript
""" """
resp = transcript.add_event(event=event, data=data) resp = transcript.add_event(event=event, data=data)
await self.update(transcript, {"events": transcript.events_dump()}) await self.update(
transcript,
{"events": transcript.events_dump()},
mutate=False,
)
return resp return resp
async def upsert_topic( async def upsert_topic(
@@ -422,7 +427,11 @@ class TranscriptController:
Append an event to a transcript Append an event to a transcript
""" """
transcript.upsert_topic(topic) transcript.upsert_topic(topic)
await self.update(transcript, {"topics": transcript.topics_dump()}) await self.update(
transcript,
{"topics": transcript.topics_dump()},
mutate=False,
)
async def move_mp3_to_storage(self, transcript: Transcript): async def move_mp3_to_storage(self, transcript: Transcript):
""" """
@@ -450,7 +459,11 @@ class TranscriptController:
Add/update a participant to a transcript Add/update a participant to a transcript
""" """
result = transcript.upsert_participant(participant) result = transcript.upsert_participant(participant)
await self.update(transcript, {"participants": transcript.participants_dump()}) await self.update(
transcript,
{"participants": transcript.participants_dump()},
mutate=False,
)
return result return result
async def delete_participant( async def delete_participant(
@@ -462,7 +475,11 @@ class TranscriptController:
Delete a participant from a transcript Delete a participant from a transcript
""" """
transcript.delete_participant(participant_id) transcript.delete_participant(participant_id)
await self.update(transcript, {"participants": transcript.participants_dump()}) await self.update(
transcript,
{"participants": transcript.participants_dump()},
mutate=False,
)
transcripts_controller = TranscriptController() transcripts_controller = TranscriptController()

View File

@@ -13,6 +13,7 @@ from reflector.db.transcripts import (
transcripts_controller, transcripts_controller,
) )
from reflector.processors.types import Transcript as ProcessorTranscript from reflector.processors.types import Transcript as ProcessorTranscript
from reflector.processors.types import Word
from reflector.settings import settings from reflector.settings import settings
router = APIRouter() router = APIRouter()
@@ -159,6 +160,17 @@ class GetTranscriptTopic(BaseModel):
) )
class GetTranscriptTopicWithWords(GetTranscriptTopic):
    """
    Topic representation that additionally carries the words of the topic
    """

    words: list[Word] = []

    @classmethod
    def from_transcript_topic(cls, topic: TranscriptTopic):
        """
        Build the instance through the parent constructor, then attach the
        words of the topic when it has any
        """
        result = super().from_transcript_topic(topic)
        if not topic.words:
            # keep the default empty list
            return result
        result.words = topic.words
        return result
@router.get("/transcripts/{transcript_id}", response_model=GetTranscript) @router.get("/transcripts/{transcript_id}", response_model=GetTranscript)
async def transcript_get( async def transcript_get(
transcript_id: str, transcript_id: str,
@@ -215,3 +227,23 @@ async def transcript_get_topics(
return [ return [
GetTranscriptTopic.from_transcript_topic(topic) for topic in transcript.topics GetTranscriptTopic.from_transcript_topic(topic) for topic in transcript.topics
] ]
@router.get(
    "/transcripts/{transcript_id}/topics/with-words",
    response_model=list[GetTranscriptTopicWithWords],
)
async def transcript_get_topics_with_words(
    transcript_id: str,
    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
    """
    Return every topic of a transcript, each one including its words
    """
    # anonymous requests are looked up without a user id
    user_id = None
    if user:
        user_id = user["sub"]

    transcript = await transcripts_controller.get_by_id_for_http(
        transcript_id, user_id=user_id
    )

    # convert to GetTranscriptTopicWithWords
    topics_with_words = []
    for topic in transcript.topics:
        topics_with_words.append(
            GetTranscriptTopicWithWords.from_transcript_topic(topic)
        )
    return topics_with_words

View File

@@ -0,0 +1,63 @@
"""
Reassign speakers in a transcript
=================================
"""
from typing import Annotated, Optional
import reflector.auth as auth
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from reflector.db.transcripts import transcripts_controller
router = APIRouter()
class SpeakerAssignment(BaseModel):
    """
    Request body: assign one speaker to the words inside a time window
    """

    # speaker value written on every matching word
    speaker: int
    # window bounds, compared (inclusively) against the start of each word
    timestamp_from: float
    timestamp_to: float
class SpeakerAssignmentStatus(BaseModel):
    """
    Response body of the speaker assignment endpoint
    """

    # the assign endpoint of this module answers with "ok"
    status: str
@router.patch("/transcripts/{transcript_id}/speaker/assign")
async def transcript_assign_speaker(
    transcript_id: str,
    assignment: SpeakerAssignment,
    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> SpeakerAssignmentStatus:
    """
    Reassign the speaker of every word that starts inside a time window

    A word is selected when its ``start`` lies within
    ``[assignment.timestamp_from, assignment.timestamp_to]``, both bounds
    included. Topics holding at least one selected word are upserted on the
    transcript, and the topics are then persisted in a single update.

    Raises a 404 HTTPException when no transcript is returned for the id.
    Returns a status of "ok", including when no word was selected.
    """
    user_id = user["sub"] if user else None
    transcript = await transcripts_controller.get_by_id_for_http(
        transcript_id, user_id=user_id
    )

    if not transcript:
        raise HTTPException(status_code=404, detail="Transcript not found")

    # reassign speakers from words in the transcript
    ts_from = assignment.timestamp_from
    ts_to = assignment.timestamp_to
    changed_topics = []
    for topic in transcript.topics:
        changed = False
        for word in topic.words:
            if ts_from <= word.start <= ts_to:
                word.speaker = assignment.speaker
                changed = True
        if changed:
            changed_topics.append(topic)

    # nothing matched the window: the stored topics are already up to date
    if not changed_topics:
        return SpeakerAssignmentStatus(status="ok")

    # batch changes
    for topic in changed_topics:
        transcript.upsert_topic(topic)

    # mutate=False, like the other callers persisting a dump: the default
    # would replace transcript.topics with the serialized data
    await transcripts_controller.update(
        transcript,
        {
            "topics": transcript.topics_dump(),
        },
        mutate=False,
    )

    return SpeakerAssignmentStatus(status="ok")

View File

@@ -36,7 +36,13 @@ def dummy_processors():
mock_long_summary.return_value = "LLM LONG SUMMARY" mock_long_summary.return_value = "LLM LONG SUMMARY"
mock_short_summary.return_value = {"short_summary": "LLM SHORT SUMMARY"} mock_short_summary.return_value = {"short_summary": "LLM SHORT SUMMARY"}
mock_translate.return_value = "Bonjour le monde" mock_translate.return_value = "Bonjour le monde"
yield mock_translate, mock_topic, mock_title, mock_long_summary, mock_short_summary # noqa yield (
mock_translate,
mock_topic,
mock_title,
mock_long_summary,
mock_short_summary,
) # noqa
@pytest.fixture @pytest.fixture
@@ -166,3 +172,64 @@ def fake_mp3_upload():
) as mock_move: ) as mock_move:
mock_move.return_value = True mock_move.return_value = True
yield yield
@pytest.fixture
async def fake_transcript_with_topics(tmpdir):
    """
    Yield a finished transcript that has an audio file and two topics

    The transcript is created through the HTTP API, marked as finished, given
    a copy of the sample mp3 at its expected location, then populated with
    two topics of two words each, all assigned to speaker 0.
    """
    from reflector.settings import settings
    from reflector.app import app
    from reflector.views.transcripts import transcripts_controller
    from reflector.db.transcripts import TranscriptTopic
    from reflector.processors.types import Word
    from pathlib import Path
    from httpx import AsyncClient
    import shutil

    settings.DATA_DIR = Path(tmpdir)

    # create a transcript; the client is only needed for this request, so it
    # is closed right away instead of being left open
    async with AsyncClient(app=app, base_url="http://test/v1") as ac:
        response = await ac.post("/transcripts", json={"name": "Test audio download"})
    assert response.status_code == 200
    tid = response.json()["id"]

    transcript = await transcripts_controller.get_by_id(tid)
    assert transcript is not None

    await transcripts_controller.update(transcript, {"status": "finished"})

    # manually copy a file at the expected location
    audio_filename = transcript.audio_mp3_filename
    path = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
    audio_filename.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy(path, audio_filename)

    # create some topics
    await transcripts_controller.upsert_topic(
        transcript,
        TranscriptTopic(
            title="Topic 1",
            summary="Topic 1 summary",
            timestamp=0,
            transcript="Hello world",
            words=[
                Word(text="Hello", start=0, end=1, speaker=0),
                Word(text="world", start=1, end=2, speaker=0),
            ],
        ),
    )
    await transcripts_controller.upsert_topic(
        transcript,
        TranscriptTopic(
            title="Topic 2",
            summary="Topic 2 summary",
            timestamp=2,
            transcript="Hello world",
            words=[
                Word(text="Hello", start=2, end=3, speaker=0),
                Word(text="world", start=3, end=4, speaker=0),
            ],
        ),
    )

    yield transcript

View File

@@ -0,0 +1,117 @@
import pytest
from httpx import AsyncClient
@pytest.mark.asyncio
async def test_transcript_reassign_speaker(fake_transcript_with_topics):
    """
    Reassign speakers over several time windows and check, after each step,
    the speaker of the words and of the segments of both topics
    """
    from reflector.app import app

    transcript_id = fake_transcript_with_topics.id
    topics_url = f"/transcripts/{transcript_id}/topics/with-words"
    assign_url = f"/transcripts/{transcript_id}/speaker/assign"

    # each step: (assignment payload or None, expected speakers of the first
    # two words per topic, expected speakers of the segments per topic)
    steps = [
        # initial topics of the transcript
        (None, ([0, 0], [0, 0]), ([0], [0])),
        # reassign speaker
        (
            {"speaker": 1, "timestamp_from": 0, "timestamp_to": 1},
            ([1, 1], [0, 0]),
            ([1], [0]),
        ),
        # reassign speaker, middle of 2 topics
        (
            {"speaker": 2, "timestamp_from": 1, "timestamp_to": 2.5},
            ([1, 2], [2, 0]),
            ([1, 2], [2, 0]),
        ),
        # reassign speaker, everything
        (
            {"speaker": 4, "timestamp_from": 0, "timestamp_to": 100},
            ([4, 4], [4, 4]),
            ([4], [4]),
        ),
    ]

    async with AsyncClient(app=app, base_url="http://test/v1") as ac:
        # check the transcript exists
        response = await ac.get(f"/transcripts/{transcript_id}")
        assert response.status_code == 200

        for payload, expected_words, expected_segments in steps:
            if payload is not None:
                response = await ac.patch(assign_url, json=payload)
                assert response.status_code == 200

            # check topics (again)
            response = await ac.get(topics_url)
            assert response.status_code == 200
            topics = response.json()
            assert len(topics) == 2

            # check through words (only the first two of each topic)
            for topic, speakers in zip(topics, expected_words):
                found = [topic["words"][index]["speaker"] for index in (0, 1)]
                assert found == speakers

            # check through segments (count and speaker of each one)
            for topic, speakers in zip(topics, expected_segments):
                found = [segment["speaker"] for segment in topic["segments"]]
                assert found == speakers