mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
server: add API to reassign speakers, and get topics with words
This commit is contained in:
@@ -17,6 +17,7 @@ from reflector.views.transcripts_audio import router as transcripts_audio_router
|
|||||||
from reflector.views.transcripts_participants import (
|
from reflector.views.transcripts_participants import (
|
||||||
router as transcripts_participants_router,
|
router as transcripts_participants_router,
|
||||||
)
|
)
|
||||||
|
from reflector.views.transcripts_speaker import router as transcripts_speaker_router
|
||||||
from reflector.views.transcripts_webrtc import router as transcripts_webrtc_router
|
from reflector.views.transcripts_webrtc import router as transcripts_webrtc_router
|
||||||
from reflector.views.transcripts_websocket import router as transcripts_websocket_router
|
from reflector.views.transcripts_websocket import router as transcripts_websocket_router
|
||||||
from reflector.views.user import router as user_router
|
from reflector.views.user import router as user_router
|
||||||
@@ -68,6 +69,7 @@ app.include_router(rtc_offer_router)
|
|||||||
app.include_router(transcripts_router, prefix="/v1")
|
app.include_router(transcripts_router, prefix="/v1")
|
||||||
app.include_router(transcripts_audio_router, prefix="/v1")
|
app.include_router(transcripts_audio_router, prefix="/v1")
|
||||||
app.include_router(transcripts_participants_router, prefix="/v1")
|
app.include_router(transcripts_participants_router, prefix="/v1")
|
||||||
|
app.include_router(transcripts_speaker_router, prefix="/v1")
|
||||||
app.include_router(transcripts_websocket_router, prefix="/v1")
|
app.include_router(transcripts_websocket_router, prefix="/v1")
|
||||||
app.include_router(transcripts_webrtc_router, prefix="/v1")
|
app.include_router(transcripts_webrtc_router, prefix="/v1")
|
||||||
app.include_router(user_router, prefix="/v1")
|
app.include_router(user_router, prefix="/v1")
|
||||||
|
|||||||
@@ -362,7 +362,7 @@ class TranscriptController:
|
|||||||
await database.execute(query)
|
await database.execute(query)
|
||||||
return transcript
|
return transcript
|
||||||
|
|
||||||
async def update(self, transcript: Transcript, values: dict):
|
async def update(self, transcript: Transcript, values: dict, mutate=True):
|
||||||
"""
|
"""
|
||||||
Update a transcript fields with key/values in values
|
Update a transcript fields with key/values in values
|
||||||
"""
|
"""
|
||||||
@@ -372,6 +372,7 @@ class TranscriptController:
|
|||||||
.values(**values)
|
.values(**values)
|
||||||
)
|
)
|
||||||
await database.execute(query)
|
await database.execute(query)
|
||||||
|
if mutate:
|
||||||
for key, value in values.items():
|
for key, value in values.items():
|
||||||
setattr(transcript, key, value)
|
setattr(transcript, key, value)
|
||||||
|
|
||||||
@@ -410,7 +411,11 @@ class TranscriptController:
|
|||||||
Append an event to a transcript
|
Append an event to a transcript
|
||||||
"""
|
"""
|
||||||
resp = transcript.add_event(event=event, data=data)
|
resp = transcript.add_event(event=event, data=data)
|
||||||
await self.update(transcript, {"events": transcript.events_dump()})
|
await self.update(
|
||||||
|
transcript,
|
||||||
|
{"events": transcript.events_dump()},
|
||||||
|
mutate=False,
|
||||||
|
)
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
async def upsert_topic(
|
async def upsert_topic(
|
||||||
@@ -422,7 +427,11 @@ class TranscriptController:
|
|||||||
Append an event to a transcript
|
Append an event to a transcript
|
||||||
"""
|
"""
|
||||||
transcript.upsert_topic(topic)
|
transcript.upsert_topic(topic)
|
||||||
await self.update(transcript, {"topics": transcript.topics_dump()})
|
await self.update(
|
||||||
|
transcript,
|
||||||
|
{"topics": transcript.topics_dump()},
|
||||||
|
mutate=False,
|
||||||
|
)
|
||||||
|
|
||||||
async def move_mp3_to_storage(self, transcript: Transcript):
|
async def move_mp3_to_storage(self, transcript: Transcript):
|
||||||
"""
|
"""
|
||||||
@@ -450,7 +459,11 @@ class TranscriptController:
|
|||||||
Add/update a participant to a transcript
|
Add/update a participant to a transcript
|
||||||
"""
|
"""
|
||||||
result = transcript.upsert_participant(participant)
|
result = transcript.upsert_participant(participant)
|
||||||
await self.update(transcript, {"participants": transcript.participants_dump()})
|
await self.update(
|
||||||
|
transcript,
|
||||||
|
{"participants": transcript.participants_dump()},
|
||||||
|
mutate=False,
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def delete_participant(
|
async def delete_participant(
|
||||||
@@ -462,7 +475,11 @@ class TranscriptController:
|
|||||||
Delete a participant from a transcript
|
Delete a participant from a transcript
|
||||||
"""
|
"""
|
||||||
transcript.delete_participant(participant_id)
|
transcript.delete_participant(participant_id)
|
||||||
await self.update(transcript, {"participants": transcript.participants_dump()})
|
await self.update(
|
||||||
|
transcript,
|
||||||
|
{"participants": transcript.participants_dump()},
|
||||||
|
mutate=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
transcripts_controller = TranscriptController()
|
transcripts_controller = TranscriptController()
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from reflector.db.transcripts import (
|
|||||||
transcripts_controller,
|
transcripts_controller,
|
||||||
)
|
)
|
||||||
from reflector.processors.types import Transcript as ProcessorTranscript
|
from reflector.processors.types import Transcript as ProcessorTranscript
|
||||||
|
from reflector.processors.types import Word
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@@ -159,6 +160,17 @@ class GetTranscriptTopic(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GetTranscriptTopicWithWords(GetTranscriptTopic):
    """
    Topic representation that also carries the topic's words.

    Extends GetTranscriptTopic with a `words` list, which stays empty
    when the source topic has no words.
    """

    words: list[Word] = []

    @classmethod
    def from_transcript_topic(cls, topic: TranscriptTopic):
        """
        Build the response model from a stored topic, then attach its words.
        """
        # The base fields are filled in by the parent implementation.
        result = super().from_transcript_topic(topic)
        topic_words = topic.words
        if not topic_words:
            # nothing to attach, keep the default empty list
            return result
        result.words = topic_words
        return result
|
||||||
|
|
||||||
|
|
||||||
@router.get("/transcripts/{transcript_id}", response_model=GetTranscript)
|
@router.get("/transcripts/{transcript_id}", response_model=GetTranscript)
|
||||||
async def transcript_get(
|
async def transcript_get(
|
||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
@@ -215,3 +227,23 @@ async def transcript_get_topics(
|
|||||||
return [
|
return [
|
||||||
GetTranscriptTopic.from_transcript_topic(topic) for topic in transcript.topics
|
GetTranscriptTopic.from_transcript_topic(topic) for topic in transcript.topics
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
    "/transcripts/{transcript_id}/topics/with-words",
    response_model=list[GetTranscriptTopicWithWords],
)
async def transcript_get_topics_with_words(
    transcript_id: str,
    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
    """
    Return all topics of a transcript, each one including its words.

    The transcript is looked up on behalf of the current user when one is
    authenticated, and without a user id otherwise.
    """
    user_id = None
    if user:
        user_id = user["sub"]

    transcript = await transcripts_controller.get_by_id_for_http(
        transcript_id, user_id=user_id
    )

    # convert every stored topic to GetTranscriptTopicWithWords
    topics_with_words = []
    for topic in transcript.topics:
        topics_with_words.append(
            GetTranscriptTopicWithWords.from_transcript_topic(topic)
        )
    return topics_with_words
|
||||||
|
|||||||
63
server/reflector/views/transcripts_speaker.py
Normal file
63
server/reflector/views/transcripts_speaker.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
"""
|
||||||
|
Reassign speakers in a transcript
|
||||||
|
=================================
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Annotated, Optional
|
||||||
|
|
||||||
|
import reflector.auth as auth
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from reflector.db.transcripts import transcripts_controller
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
class SpeakerAssignment(BaseModel):
    """
    Request body describing a speaker reassignment over a time range.
    """

    # speaker number written to every word selected by the range
    speaker: int
    # inclusive lower bound, compared against each word's `start`
    timestamp_from: float
    # inclusive upper bound, compared against each word's `start`
    timestamp_to: float
|
||||||
|
|
||||||
|
|
||||||
|
class SpeakerAssignmentStatus(BaseModel):
    """
    Response body returned once a speaker reassignment has been handled.
    """

    # the assign endpoint reports "ok" on success
    status: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch("/transcripts/{transcript_id}/speaker/assign")
async def transcript_assign_speaker(
    transcript_id: str,
    assignment: SpeakerAssignment,
    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> SpeakerAssignmentStatus:
    """
    Reassign the speaker of every word that starts within a time range.

    A word is selected when its `start` lies between
    `assignment.timestamp_from` and `assignment.timestamp_to`, both bounds
    included. Topics containing at least one selected word are upserted and
    the transcript topics are saved once for the whole request.

    Raises an HTTP 404 error when the transcript cannot be found.
    """
    user_id = user["sub"] if user else None
    transcript = await transcripts_controller.get_by_id_for_http(
        transcript_id, user_id=user_id
    )

    if not transcript:
        raise HTTPException(status_code=404, detail="Transcript not found")

    # reassign speakers from words in the transcript
    ts_from = assignment.timestamp_from
    ts_to = assignment.timestamp_to
    changed_topics = []
    for topic in transcript.topics:
        selected_words = [
            word for word in topic.words if ts_from <= word.start <= ts_to
        ]
        if not selected_words:
            continue
        for word in selected_words:
            word.speaker = assignment.speaker
        changed_topics.append(topic)

    # batch changes: one database write for all modified topics, and no
    # write at all when the range selected no word
    if changed_topics:
        for topic in changed_topics:
            transcript.upsert_topic(topic)
        # mutate=False: the in-memory transcript is already up to date, and
        # the default would overwrite `transcript.topics` with the dumped
        # value (same convention as the controller's own upsert helpers)
        await transcripts_controller.update(
            transcript,
            {"topics": transcript.topics_dump()},
            mutate=False,
        )

    return SpeakerAssignmentStatus(status="ok")
|
||||||
@@ -36,7 +36,13 @@ def dummy_processors():
|
|||||||
mock_long_summary.return_value = "LLM LONG SUMMARY"
|
mock_long_summary.return_value = "LLM LONG SUMMARY"
|
||||||
mock_short_summary.return_value = {"short_summary": "LLM SHORT SUMMARY"}
|
mock_short_summary.return_value = {"short_summary": "LLM SHORT SUMMARY"}
|
||||||
mock_translate.return_value = "Bonjour le monde"
|
mock_translate.return_value = "Bonjour le monde"
|
||||||
yield mock_translate, mock_topic, mock_title, mock_long_summary, mock_short_summary # noqa
|
yield (
|
||||||
|
mock_translate,
|
||||||
|
mock_topic,
|
||||||
|
mock_title,
|
||||||
|
mock_long_summary,
|
||||||
|
mock_short_summary,
|
||||||
|
) # noqa
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -166,3 +172,64 @@ def fake_mp3_upload():
|
|||||||
) as mock_move:
|
) as mock_move:
|
||||||
mock_move.return_value = True
|
mock_move.return_value = True
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
async def fake_transcript_with_topics(tmpdir):
    """
    Yield a finished transcript that has an audio file and two topics.

    The transcript is created through the HTTP API, marked as finished,
    given a copy of the sample mp3 at its expected location, and populated
    with two topics of two words each, all attributed to speaker 0.
    """
    from reflector.settings import settings
    from reflector.app import app
    from reflector.views.transcripts import transcripts_controller
    from reflector.db.transcripts import TranscriptTopic
    from reflector.processors.types import Word
    from pathlib import Path
    from httpx import AsyncClient
    import shutil

    settings.DATA_DIR = Path(tmpdir)

    # create a transcript; the client is closed as soon as the request is
    # done so that no connection outlives the fixture
    async with AsyncClient(app=app, base_url="http://test/v1") as ac:
        response = await ac.post("/transcripts", json={"name": "Test audio download"})
    assert response.status_code == 200
    tid = response.json()["id"]

    transcript = await transcripts_controller.get_by_id(tid)
    assert transcript is not None

    await transcripts_controller.update(transcript, {"status": "finished"})

    # manually copy a file at the expected location
    audio_filename = transcript.audio_mp3_filename
    path = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
    audio_filename.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy(path, audio_filename)

    # create some topics
    await transcripts_controller.upsert_topic(
        transcript,
        TranscriptTopic(
            title="Topic 1",
            summary="Topic 1 summary",
            timestamp=0,
            transcript="Hello world",
            words=[
                Word(text="Hello", start=0, end=1, speaker=0),
                Word(text="world", start=1, end=2, speaker=0),
            ],
        ),
    )
    await transcripts_controller.upsert_topic(
        transcript,
        TranscriptTopic(
            title="Topic 2",
            summary="Topic 2 summary",
            timestamp=2,
            transcript="Hello world",
            words=[
                Word(text="Hello", start=2, end=3, speaker=0),
                Word(text="world", start=3, end=4, speaker=0),
            ],
        ),
    )

    yield transcript
|
||||||
|
|||||||
117
server/tests/test_transcripts_speaker.py
Normal file
117
server/tests/test_transcripts_speaker.py
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
import pytest
|
||||||
|
from httpx import AsyncClient
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_transcript_reassign_speaker(fake_transcript_with_topics):
    """
    Reassign speakers over several time ranges and check the resulting
    per-word and per-segment speakers after each step.
    """
    from reflector.app import app

    transcript_id = fake_transcript_with_topics.id
    topics_url = f"/transcripts/{transcript_id}/topics/with-words"
    assign_url = f"/transcripts/{transcript_id}/speaker/assign"

    async with AsyncClient(app=app, base_url="http://test/v1") as ac:

        async def reassign(speaker, timestamp_from, timestamp_to):
            # ask the API to reassign the given range and expect success
            response = await ac.patch(
                assign_url,
                json={
                    "speaker": speaker,
                    "timestamp_from": timestamp_from,
                    "timestamp_to": timestamp_to,
                },
            )
            assert response.status_code == 200

        async def check_topics(word_speakers, segment_speakers):
            # fetch both topics, then compare the speaker of the first two
            # words and the speakers of all segments of each topic
            response = await ac.get(topics_url)
            assert response.status_code == 200
            topics = response.json()
            assert len(topics) == 2
            expectations = zip(topics, word_speakers, segment_speakers)
            for topic, expected_words, expected_segments in expectations:
                assert topic["words"][0]["speaker"] == expected_words[0]
                assert topic["words"][1]["speaker"] == expected_words[1]
                found_segments = [
                    segment["speaker"] for segment in topic["segments"]
                ]
                assert found_segments == expected_segments

        # check the transcript exists
        response = await ac.get(f"/transcripts/{transcript_id}")
        assert response.status_code == 200

        # check initial topics of the transcript
        await check_topics([(0, 0), (0, 0)], [[0], [0]])

        # reassign speaker
        await reassign(1, 0, 1)
        await check_topics([(1, 1), (0, 0)], [[1], [0]])

        # reassign speaker, middle of 2 topics
        await reassign(2, 1, 2.5)
        await check_topics([(1, 2), (2, 0)], [[1, 2], [2, 0]])

        # reassign speaker, everything
        await reassign(4, 0, 100)
        await check_topics([(4, 4), (4, 4)], [[4], [4]])
|
||||||
Reference in New Issue
Block a user