From 6f3d7df507d900d90b8d3b5b8c018c04fd7c27ca Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Wed, 6 Dec 2023 16:41:18 +0100 Subject: [PATCH] server: add API to reassign speakers, and get topics with words --- server/reflector/app.py | 2 + server/reflector/db/transcripts.py | 31 +++-- server/reflector/views/transcripts.py | 32 +++++ server/reflector/views/transcripts_speaker.py | 63 ++++++++++ server/tests/conftest.py | 69 ++++++++++- server/tests/test_transcripts_speaker.py | 117 ++++++++++++++++++ 6 files changed, 306 insertions(+), 8 deletions(-) create mode 100644 server/reflector/views/transcripts_speaker.py create mode 100644 server/tests/test_transcripts_speaker.py diff --git a/server/reflector/app.py b/server/reflector/app.py index 8f45efd5..14c167a2 100644 --- a/server/reflector/app.py +++ b/server/reflector/app.py @@ -17,6 +17,7 @@ from reflector.views.transcripts_audio import router as transcripts_audio_router from reflector.views.transcripts_participants import ( router as transcripts_participants_router, ) +from reflector.views.transcripts_speaker import router as transcripts_speaker_router from reflector.views.transcripts_webrtc import router as transcripts_webrtc_router from reflector.views.transcripts_websocket import router as transcripts_websocket_router from reflector.views.user import router as user_router @@ -68,6 +69,7 @@ app.include_router(rtc_offer_router) app.include_router(transcripts_router, prefix="/v1") app.include_router(transcripts_audio_router, prefix="/v1") app.include_router(transcripts_participants_router, prefix="/v1") +app.include_router(transcripts_speaker_router, prefix="/v1") app.include_router(transcripts_websocket_router, prefix="/v1") app.include_router(transcripts_webrtc_router, prefix="/v1") app.include_router(user_router, prefix="/v1") diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py index 970393d5..6550af05 100644 --- a/server/reflector/db/transcripts.py +++ 
b/server/reflector/db/transcripts.py @@ -362,7 +362,7 @@ class TranscriptController: await database.execute(query) return transcript - async def update(self, transcript: Transcript, values: dict): + async def update(self, transcript: Transcript, values: dict, mutate=True): """ Update a transcript fields with key/values in values """ @@ -372,8 +372,9 @@ class TranscriptController: .values(**values) ) await database.execute(query) - for key, value in values.items(): - setattr(transcript, key, value) + if mutate: + for key, value in values.items(): + setattr(transcript, key, value) async def remove_by_id( self, @@ -410,7 +411,11 @@ class TranscriptController: Append an event to a transcript """ resp = transcript.add_event(event=event, data=data) - await self.update(transcript, {"events": transcript.events_dump()}) + await self.update( + transcript, + {"events": transcript.events_dump()}, + mutate=False, + ) return resp async def upsert_topic( @@ -422,7 +427,11 @@ class TranscriptController: Append an event to a transcript """ transcript.upsert_topic(topic) - await self.update(transcript, {"topics": transcript.topics_dump()}) + await self.update( + transcript, + {"topics": transcript.topics_dump()}, + mutate=False, + ) async def move_mp3_to_storage(self, transcript: Transcript): """ @@ -450,7 +459,11 @@ class TranscriptController: Add/update a participant to a transcript """ result = transcript.upsert_participant(participant) - await self.update(transcript, {"participants": transcript.participants_dump()}) + await self.update( + transcript, + {"participants": transcript.participants_dump()}, + mutate=False, + ) return result async def delete_participant( @@ -462,7 +475,11 @@ class TranscriptController: Delete a participant from a transcript """ transcript.delete_participant(participant_id) - await self.update(transcript, {"participants": transcript.participants_dump()}) + await self.update( + transcript, + {"participants": transcript.participants_dump()}, + 
# Topic response model extended with the words that make up the topic.
class GetTranscriptTopicWithWords(GetTranscriptTopic):
    # Left empty unless the source topic carries a non-empty word list.
    words: list[Word] = []

    @classmethod
    def from_transcript_topic(cls, topic: TranscriptTopic):
        """
        Build the response model from a transcript topic, then attach the
        topic's words when it has any.
        """
        result = super().from_transcript_topic(topic)
        if not topic.words:
            return result
        result.words = topic.words
        return result


@router.get(
    "/transcripts/{transcript_id}/topics/with-words",
    response_model=list[GetTranscriptTopicWithWords],
)
async def transcript_get_topics_with_words(
    transcript_id: str,
    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
    # Anonymous callers are looked up with user_id=None.
    user_id = None
    if user:
        user_id = user["sub"]

    transcript = await transcripts_controller.get_by_id_for_http(
        transcript_id, user_id=user_id
    )

    # convert to GetTranscriptTopicWithWords
    topics_with_words = []
    for topic in transcript.topics:
        topics_with_words.append(
            GetTranscriptTopicWithWords.from_transcript_topic(topic)
        )
    return topics_with_words
# Request body: assign `speaker` to every word whose start time lies in
# [timestamp_from, timestamp_to] (both bounds inclusive).
class SpeakerAssignment(BaseModel):
    speaker: int
    timestamp_from: float
    timestamp_to: float


# Response body of the speaker assignment endpoint.
class SpeakerAssignmentStatus(BaseModel):
    status: str


@router.patch("/transcripts/{transcript_id}/speaker/assign")
async def transcript_assign_speaker(
    transcript_id: str,
    assignment: SpeakerAssignment,
    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> SpeakerAssignmentStatus:
    # Reassign the speaker of every word that starts inside the requested
    # time range, then persist the transcript topics.
    # Raises HTTPException(404) when the transcript lookup returns nothing.
    # (Comments are used instead of a docstring so the generated API
    # description is left untouched.)
    user_id = user["sub"] if user else None
    transcript = await transcripts_controller.get_by_id_for_http(
        transcript_id, user_id=user_id
    )

    if not transcript:
        raise HTTPException(status_code=404, detail="Transcript not found")

    # reassign speakers from words in the transcript
    ts_from = assignment.timestamp_from
    ts_to = assignment.timestamp_to
    changed_topics = []
    for topic in transcript.topics:
        changed = False
        for word in topic.words:
            if ts_from <= word.start <= ts_to:
                word.speaker = assignment.speaker
                changed = True
        if changed:
            changed_topics.append(topic)

    # no word starts in the range: nothing to persist, skip the write
    if not changed_topics:
        return SpeakerAssignmentStatus(status="ok")

    # batch changes
    for topic in changed_topics:
        transcript.upsert_topic(topic)
    await transcripts_controller.update(
        transcript,
        {
            "topics": transcript.topics_dump(),
        },
        # The value is the dumped form of the topics; mutate=False keeps
        # update() from assigning that dump back onto transcript.topics,
        # the same way TranscriptController.upsert_topic calls it.
        mutate=False,
    )

    return SpeakerAssignmentStatus(status="ok")
@pytest.fixture
async def fake_transcript_with_topics(tmpdir):
    """
    Yield a transcript marked "finished" that has an mp3 file in place and
    two topics of two words each, every word assigned to speaker 0.

    settings.DATA_DIR is pointed at `tmpdir` for the transcript's files.
    """
    from reflector.settings import settings
    from reflector.app import app
    from reflector.views.transcripts import transcripts_controller
    from reflector.db.transcripts import TranscriptTopic
    from reflector.processors.types import Word
    from pathlib import Path
    from httpx import AsyncClient
    import shutil

    settings.DATA_DIR = Path(tmpdir)

    # create a transcript; the client is only needed for this request, so
    # close it as soon as the response has been read
    async with AsyncClient(app=app, base_url="http://test/v1") as ac:
        response = await ac.post("/transcripts", json={"name": "Test audio download"})
        assert response.status_code == 200
        tid = response.json()["id"]

    transcript = await transcripts_controller.get_by_id(tid)
    assert transcript is not None

    await transcripts_controller.update(transcript, {"status": "finished"})

    # manually copy a file at the expected location
    audio_filename = transcript.audio_mp3_filename
    path = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
    audio_filename.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy(path, audio_filename)

    # create some topics
    await transcripts_controller.upsert_topic(
        transcript,
        TranscriptTopic(
            title="Topic 1",
            summary="Topic 1 summary",
            timestamp=0,
            transcript="Hello world",
            words=[
                Word(text="Hello", start=0, end=1, speaker=0),
                Word(text="world", start=1, end=2, speaker=0),
            ],
        ),
    )
    await transcripts_controller.upsert_topic(
        transcript,
        TranscriptTopic(
            title="Topic 2",
            summary="Topic 2 summary",
            timestamp=2,
            transcript="Hello world",
            words=[
                Word(text="Hello", start=2, end=3, speaker=0),
                Word(text="world", start=3, end=4, speaker=0),
            ],
        ),
    )

    yield transcript
@pytest.mark.asyncio
async def test_transcript_reassign_speaker(fake_transcript_with_topics):
    """
    Reassign speakers over three time ranges and check, after each call,
    the speaker of every word and segment in both topics.
    """
    from reflector.app import app

    transcript_id = fake_transcript_with_topics.id

    async with AsyncClient(app=app, base_url="http://test/v1") as ac:

        async def fetch_topics():
            # fetch the topics with their words; there must be exactly two
            resp = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
            assert resp.status_code == 200
            payload = resp.json()
            assert len(payload) == 2
            return payload

        async def assign(speaker, ts_from, ts_to):
            # reassign `speaker` over the given time range
            resp = await ac.patch(
                f"/transcripts/{transcript_id}/speaker/assign",
                json={
                    "speaker": speaker,
                    "timestamp_from": ts_from,
                    "timestamp_to": ts_to,
                },
            )
            assert resp.status_code == 200

        def check(topics, word_speakers, segment_speakers):
            # check through words
            for topic, expected in zip(topics, word_speakers):
                for idx, speaker in enumerate(expected):
                    assert topic["words"][idx]["speaker"] == speaker
            # check through segments
            for topic, expected in zip(topics, segment_speakers):
                assert len(topic["segments"]) == len(expected)
                for idx, speaker in enumerate(expected):
                    assert topic["segments"][idx]["speaker"] == speaker

        # check the transcript exists
        response = await ac.get(f"/transcripts/{transcript_id}")
        assert response.status_code == 200

        # check initial topics of the transcript
        check(
            await fetch_topics(),
            word_speakers=[(0, 0), (0, 0)],
            segment_speakers=[(0,), (0,)],
        )

        # reassign speaker
        await assign(1, 0, 1)
        check(
            await fetch_topics(),
            word_speakers=[(1, 1), (0, 0)],
            segment_speakers=[(1,), (0,)],
        )

        # reassign speaker, middle of 2 topics
        await assign(2, 1, 2.5)
        check(
            await fetch_topics(),
            word_speakers=[(1, 2), (2, 0)],
            segment_speakers=[(1, 2), (2, 0)],
        )

        # reassign speaker, everything
        await assign(4, 0, 100)
        check(
            await fetch_topics(),
            word_speakers=[(4, 4), (4, 4)],
            segment_speakers=[(4,), (4,)],
        )