From 7ac6d2521737248f81aa0cd31483fbbe8216cedd Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 30 Nov 2023 17:30:08 +0100 Subject: [PATCH] server: add participant API Also break out views into different files for easier reading --- .../versions/125031f7cb78_participants.py | 30 +++ server/reflector/app.py | 10 + server/reflector/db/transcripts.py | 56 ++++- server/reflector/views/transcripts.py | 206 +----------------- server/reflector/views/transcripts_audio.py | 109 +++++++++ .../views/transcripts_participants.py | 142 ++++++++++++ server/reflector/views/transcripts_webrtc.py | 37 ++++ .../reflector/views/transcripts_websocket.py | 53 +++++ server/reflector/views/types.py | 5 + server/tests/test_transcripts_participants.py | 164 ++++++++++++++ 10 files changed, 610 insertions(+), 202 deletions(-) create mode 100644 server/migrations/versions/125031f7cb78_participants.py create mode 100644 server/reflector/views/transcripts_audio.py create mode 100644 server/reflector/views/transcripts_participants.py create mode 100644 server/reflector/views/transcripts_webrtc.py create mode 100644 server/reflector/views/transcripts_websocket.py create mode 100644 server/reflector/views/types.py create mode 100644 server/tests/test_transcripts_participants.py diff --git a/server/migrations/versions/125031f7cb78_participants.py b/server/migrations/versions/125031f7cb78_participants.py new file mode 100644 index 00000000..c345b083 --- /dev/null +++ b/server/migrations/versions/125031f7cb78_participants.py @@ -0,0 +1,30 @@ +"""participants + +Revision ID: 125031f7cb78 +Revises: 0fea6d96b096 +Create Date: 2023-11-30 15:56:03.341466 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '125031f7cb78' +down_revision: Union[str, None] = '0fea6d96b096' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('transcript', sa.Column('participants', sa.JSON(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('transcript', 'participants') + # ### end Alembic commands ### diff --git a/server/reflector/app.py b/server/reflector/app.py index 5bfffeca..8f45efd5 100644 --- a/server/reflector/app.py +++ b/server/reflector/app.py @@ -13,6 +13,12 @@ from reflector.metrics import metrics_init from reflector.settings import settings from reflector.views.rtc_offer import router as rtc_offer_router from reflector.views.transcripts import router as transcripts_router +from reflector.views.transcripts_audio import router as transcripts_audio_router +from reflector.views.transcripts_participants import ( + router as transcripts_participants_router, +) +from reflector.views.transcripts_webrtc import router as transcripts_webrtc_router +from reflector.views.transcripts_websocket import router as transcripts_websocket_router from reflector.views.user import router as user_router try: @@ -60,6 +66,10 @@ metrics_init(app, instrumentator) # register views app.include_router(rtc_offer_router) app.include_router(transcripts_router, prefix="/v1") +app.include_router(transcripts_audio_router, prefix="/v1") +app.include_router(transcripts_participants_router, prefix="/v1") +app.include_router(transcripts_websocket_router, prefix="/v1") +app.include_router(transcripts_webrtc_router, prefix="/v1") app.include_router(user_router, prefix="/v1") add_pagination(app) diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py index 0fba82ef..44688eaa 100644 --- a/server/reflector/db/transcripts.py +++ b/server/reflector/db/transcripts.py @@ -7,7 +7,7 @@ from uuid import uuid4 import sqlalchemy from fastapi import HTTPException -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from reflector.db import database, metadata from reflector.processors.types import Word as ProcessorWord from reflector.settings import settings @@ -27,6 +27,7 @@ transcripts = sqlalchemy.Table( sqlalchemy.Column("long_summary", sqlalchemy.String, nullable=True), sqlalchemy.Column("topics", sqlalchemy.JSON), sqlalchemy.Column("events", sqlalchemy.JSON), + sqlalchemy.Column("participants", sqlalchemy.JSON), sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True), sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True), sqlalchemy.Column( @@ -112,6 +113,13 @@ class TranscriptEvent(BaseModel): data: dict +class TranscriptParticipant(BaseModel): + model_config = ConfigDict(from_attributes=True) + id: str = Field(default_factory=generate_uuid4) + speaker: int | None + name: str + + class Transcript(BaseModel): id: str = Field(default_factory=generate_uuid4) user_id: str | None = None @@ -125,6 +133,7 @@ class Transcript(BaseModel): long_summary: str | None = None topics: list[TranscriptTopic] = [] events: list[TranscriptEvent] = [] + participants: list[TranscriptParticipant] = [] source_language: str = "en" target_language: str = "en" share_mode: Literal["private", "semi-private", "public"] = "private" @@ -142,12 +151,34 @@ class Transcript(BaseModel): else: self.topics.append(topic) + def upsert_participant(self, participant: TranscriptParticipant): + index = next( + (i for i, p in enumerate(self.participants) if p.id == participant.id), + None, + ) + if index is not None: + self.participants[index] = participant + else: + self.participants.append(participant) + return participant + + def delete_participant(self, participant_id: str): + index = next( + (i for i, p in enumerate(self.participants) if p.id == participant_id), + None, + ) + if index is not None: + del self.participants[index] + def events_dump(self, mode="json"): return [event.model_dump(mode=mode) for event in self.events] def topics_dump(self, mode="json"): return [topic.model_dump(mode=mode) for topic in self.topics] + def participants_dump(self, mode="json"): + return [participant.model_dump(mode=mode) for participant in self.participants] + def unlink(self): self.data_path.unlink(missing_ok=True) @@ -410,5 +441,28 @@ class TranscriptController: # unlink the local file transcript.audio_mp3_filename.unlink(missing_ok=True) + async def upsert_participant( + self, + transcript: Transcript, + participant: TranscriptParticipant, + ) -> TranscriptParticipant: + """ + Add/update a participant to a transcript + """ + result = transcript.upsert_participant(participant) + await self.update(transcript, {"participants": transcript.participants_dump()}) + return result + + async def delete_participant( + self, + transcript: Transcript, + participant_id: str, + ): + """ + Delete a participant from a transcript + """ + transcript.delete_participant(participant_id) + await self.update(transcript, {"participants": transcript.participants_dump()}) + transcripts_controller = TranscriptController() diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 44b55629..9e62192b 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -1,33 +1,19 @@ from datetime import datetime, timedelta from typing import Annotated, Literal, Optional -import httpx import reflector.auth as auth -from fastapi import ( - APIRouter, - Depends, - HTTPException, - Request, - Response, - WebSocket, - WebSocketDisconnect, - status, -) +from fastapi import APIRouter, Depends, HTTPException from fastapi_pagination import Page from fastapi_pagination.ext.databases import paginate from jose import jwt from pydantic import BaseModel, Field from reflector.db.transcripts import ( - AudioWaveform, + TranscriptParticipant, TranscriptTopic, transcripts_controller, ) from reflector.processors.types import Transcript as ProcessorTranscript from reflector.settings import settings -from reflector.ws_manager import get_ws_manager - -from ._range_requests_response import range_requests_response -from .rtc_offer import RtcOffer, rtc_offer_base router = APIRouter() @@ -62,6 +48,7 @@ class GetTranscript(BaseModel): share_mode: str = Field("private") source_language: str | None target_language: str | None + participants: list[TranscriptParticipant] | None class CreateTranscript(BaseModel): @@ -77,6 +64,7 @@ class UpdateTranscript(BaseModel): short_summary: Optional[str] = Field(None) long_summary: Optional[str] = Field(None) share_mode: Optional[Literal["public", "semi-private", "private"]] = Field(None) + participants: Optional[list[TranscriptParticipant]] = Field(None) class DeletionStatus(BaseModel): @@ -192,19 +180,7 @@ async def transcript_update( transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id) if not transcript: raise HTTPException(status_code=404, detail="Transcript not found") - values = {} - if info.name is not None: - values["name"] = info.name - if info.locked is not None: - values["locked"] = info.locked - if info.long_summary is not None: - values["long_summary"] = info.long_summary - if info.short_summary is not None: - values["short_summary"] = info.short_summary - if info.title is not None: - values["title"] = info.title - if info.share_mode is not None: - values["share_mode"] = info.share_mode + values = info.dict(exclude_unset=True) await transcripts_controller.update(transcript, values) return transcript @@ -222,97 +198,6 @@ async def transcript_delete( return DeletionStatus(status="ok") -@router.get("/transcripts/{transcript_id}/audio/mp3") -@router.head("/transcripts/{transcript_id}/audio/mp3") -async def transcript_get_audio_mp3( - request: Request, - transcript_id: str, - user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], - token: str | None = None, -): - user_id = user["sub"] if user else None - if not user_id and token: - unauthorized_exception = HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid or expired token", - headers={"WWW-Authenticate": "Bearer"}, - ) - try: - payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM]) - user_id: str = payload.get("sub") - except jwt.JWTError: - raise unauthorized_exception - - transcript = await transcripts_controller.get_by_id_for_http( - transcript_id, user_id=user_id - ) - - if transcript.audio_location == "storage": - # proxy S3 file, to prevent issue with CORS - url = await transcript.get_audio_url() - headers = {} - - copy_headers = ["range", "accept-encoding"] - for header in copy_headers: - if header in request.headers: - headers[header] = request.headers[header] - - async with httpx.AsyncClient() as client: - resp = await client.request(request.method, url, headers=headers) - return Response( - content=resp.content, - status_code=resp.status_code, - headers=resp.headers, - ) - - if transcript.audio_location == "storage": - # proxy S3 file, to prevent issue with CORS - url = await transcript.get_audio_url() - headers = {} - - copy_headers = ["range", "accept-encoding"] - for header in copy_headers: - if header in request.headers: - headers[header] = request.headers[header] - - async with httpx.AsyncClient() as client: - resp = await client.request(request.method, url, headers=headers) - return Response( - content=resp.content, - status_code=resp.status_code, - headers=resp.headers, - ) - - if not transcript.audio_mp3_filename.exists(): - raise HTTPException(status_code=500, detail="Audio not found") - - truncated_id = str(transcript.id).split("-")[0] - filename = f"recording_{truncated_id}.mp3" - - return range_requests_response( - request, - transcript.audio_mp3_filename, - content_type="audio/mpeg", - content_disposition=f"attachment; filename={filename}", - ) - - -@router.get("/transcripts/{transcript_id}/audio/waveform") -async def transcript_get_audio_waveform( - transcript_id: str, - user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], -) -> AudioWaveform: - user_id = user["sub"] if user else None - transcript = await transcripts_controller.get_by_id_for_http( - transcript_id, user_id=user_id - ) - - if not transcript.audio_waveform_filename.exists(): - raise HTTPException(status_code=404, detail="Audio not found") - - return transcript.audio_waveform - - @router.get( "/transcripts/{transcript_id}/topics", response_model=list[GetTranscriptTopic], @@ -330,84 +215,3 @@ async def transcript_get_topics( return [ GetTranscriptTopic.from_transcript_topic(topic) for topic in transcript.topics ] - - -# ============================================================== -# Websocket -# ============================================================== - - -@router.get("/transcripts/{transcript_id}/events") -async def transcript_get_websocket_events(transcript_id: str): - pass - - -@router.websocket("/transcripts/{transcript_id}/events") -async def transcript_events_websocket( - transcript_id: str, - websocket: WebSocket, - # user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], -): - # user_id = user["sub"] if user else None - transcript = await transcripts_controller.get_by_id(transcript_id) - if not transcript: - raise HTTPException(status_code=404, detail="Transcript not found") - - # connect to websocket manager - # use ts:transcript_id as room id - room_id = f"ts:{transcript_id}" - ws_manager = get_ws_manager() - await ws_manager.add_user_to_room(room_id, websocket) - - try: - # on first connection, send all events only to the current user - for event in transcript.events: - # for now, do not send TRANSCRIPT or STATUS options - theses are live event - # not necessary to be sent to the client; but keep the rest - name = event.event - if name in ("TRANSCRIPT", "STATUS"): - continue - await websocket.send_json(event.model_dump(mode="json")) - - # XXX if transcript is final (locked=True and status=ended) - # XXX send a final event to the client and close the connection - - # endless loop to wait for new events - # we do not have command system now, - while True: - await websocket.receive() - except (RuntimeError, WebSocketDisconnect): - await ws_manager.remove_user_from_room(room_id, websocket) - - -# ============================================================== -# Web RTC -# ============================================================== - - -@router.post("/transcripts/{transcript_id}/record/webrtc") -async def transcript_record_webrtc( - transcript_id: str, - params: RtcOffer, - request: Request, - user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], -): - user_id = user["sub"] if user else None - transcript = await transcripts_controller.get_by_id_for_http( - transcript_id, user_id=user_id - ) - - if transcript.locked: - raise HTTPException(status_code=400, detail="Transcript is locked") - - # create a pipeline runner - from reflector.pipelines.main_live_pipeline import PipelineMainLive - - pipeline_runner = PipelineMainLive(transcript_id=transcript_id) - - # FIXME do not allow multiple recording at the same time - return await rtc_offer_base( - params, - request, - pipeline_runner=pipeline_runner, - ) diff --git a/server/reflector/views/transcripts_audio.py b/server/reflector/views/transcripts_audio.py new file mode 100644 index 00000000..a174d992 --- /dev/null +++ b/server/reflector/views/transcripts_audio.py @@ -0,0 +1,109 @@ +""" +Transcripts audio related endpoints +=================================== + +""" +from typing import Annotated, Optional + +import httpx +import reflector.auth as auth +from fastapi import APIRouter, Depends, HTTPException, Request, Response, status +from jose import jwt +from reflector.db.transcripts import AudioWaveform, transcripts_controller +from reflector.settings import settings +from reflector.views.transcripts import ALGORITHM + +from ._range_requests_response import range_requests_response + +router = APIRouter() + + +@router.get("/transcripts/{transcript_id}/audio/mp3") +@router.head("/transcripts/{transcript_id}/audio/mp3") +async def transcript_get_audio_mp3( + request: Request, + transcript_id: str, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], + token: str | None = None, +): + user_id = user["sub"] if user else None + if not user_id and token: + unauthorized_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or expired token", + headers={"WWW-Authenticate": "Bearer"}, + ) + try: + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM]) + user_id: str = payload.get("sub") + except jwt.JWTError: + raise unauthorized_exception + + transcript = await transcripts_controller.get_by_id_for_http( + transcript_id, user_id=user_id + ) + + if transcript.audio_location == "storage": + # proxy S3 file, to prevent issue with CORS + url = await transcript.get_audio_url() + headers = {} + + copy_headers = ["range", "accept-encoding"] + for header in copy_headers: + if header in request.headers: + headers[header] = request.headers[header] + + async with httpx.AsyncClient() as client: + resp = await client.request(request.method, url, headers=headers) + return Response( + content=resp.content, + status_code=resp.status_code, + headers=resp.headers, + ) + + if transcript.audio_location == "storage": + # proxy S3 file, to prevent issue with CORS + url = await transcript.get_audio_url() + headers = {} + + copy_headers = ["range", "accept-encoding"] + for header in copy_headers: + if header in request.headers: + headers[header] = request.headers[header] + + async with httpx.AsyncClient() as client: + resp = await client.request(request.method, url, headers=headers) + return Response( + content=resp.content, + status_code=resp.status_code, + headers=resp.headers, + ) + + if not transcript.audio_mp3_filename.exists(): + raise HTTPException(status_code=500, detail="Audio not found") + + truncated_id = str(transcript.id).split("-")[0] + filename = f"recording_{truncated_id}.mp3" + + return range_requests_response( + request, + transcript.audio_mp3_filename, + content_type="audio/mpeg", + content_disposition=f"attachment; filename={filename}", + ) + + +@router.get("/transcripts/{transcript_id}/audio/waveform") +async def transcript_get_audio_waveform( + transcript_id: str, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +) -> AudioWaveform: + user_id = user["sub"] if user else None + transcript = await transcripts_controller.get_by_id_for_http( + transcript_id, user_id=user_id + ) + + if not transcript.audio_waveform_filename.exists(): + raise HTTPException(status_code=404, detail="Audio not found") + + return transcript.audio_waveform diff --git a/server/reflector/views/transcripts_participants.py b/server/reflector/views/transcripts_participants.py new file mode 100644 index 00000000..318d6018 --- /dev/null +++ b/server/reflector/views/transcripts_participants.py @@ -0,0 +1,142 @@ +""" +Transcript participants API endpoints +===================================== + +""" +from typing import Annotated, Optional + +import reflector.auth as auth +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel, ConfigDict, Field +from reflector.db.transcripts import TranscriptParticipant, transcripts_controller +from reflector.views.types import DeletionStatus + +router = APIRouter() + + +class Participant(BaseModel): + model_config = ConfigDict(from_attributes=True) + id: str + speaker: int | None + name: str + + +class CreateParticipant(BaseModel): + speaker: Optional[int] = Field(None) + name: str + + +class UpdateParticipant(BaseModel): + speaker: Optional[int] = Field(None) + name: Optional[str] = Field(None) + + +@router.get("/transcripts/{transcript_id}/participants") +async def transcript_get_participants( + transcript_id: str, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +) -> list[Participant]: + user_id = user["sub"] if user else None + transcript = await transcripts_controller.get_by_id_for_http( + transcript_id, user_id=user_id + ) + + return [ + Participant.model_validate(participant) + for participant in transcript.participants + ] + + +@router.post("/transcripts/{transcript_id}/participants") +async def transcript_add_participant( + transcript_id: str, + participant: CreateParticipant, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +) -> Participant: + user_id = user["sub"] if user else None + transcript = await transcripts_controller.get_by_id_for_http( + transcript_id, user_id=user_id + ) + + # ensure the speaker is unique + for p in transcript.participants: + if p.speaker == participant.speaker: + raise HTTPException( + status_code=400, + detail="Speaker already assigned", + ) + + obj = await transcripts_controller.upsert_participant( + transcript, TranscriptParticipant(**participant.dict()) + ) + return Participant.model_validate(obj) + + +@router.get("/transcripts/{transcript_id}/participants/{participant_id}") +async def transcript_get_participant( + transcript_id: str, + participant_id: str, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +) -> Participant: + user_id = user["sub"] if user else None + transcript = await transcripts_controller.get_by_id_for_http( + transcript_id, user_id=user_id + ) + + for p in transcript.participants: + if p.id == participant_id: + return Participant.model_validate(p) + + raise HTTPException(status_code=404, detail="Participant not found") + + +@router.patch("/transcripts/{transcript_id}/participants/{participant_id}") +async def transcript_update_participant( + transcript_id: str, + participant_id: str, + participant: UpdateParticipant, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +) -> Participant: + user_id = user["sub"] if user else None + transcript = await transcripts_controller.get_by_id_for_http( + transcript_id, user_id=user_id + ) + + # ensure the speaker is unique + for p in transcript.participants: + if p.speaker == participant.speaker and p.id != participant_id: + raise HTTPException( + status_code=400, + detail="Speaker already assigned", + ) + + # find the participant + obj = None + for p in transcript.participants: + if p.id == participant_id: + obj = p + break + + if not obj: + raise HTTPException(status_code=404, detail="Participant not found") + + # update participant but just the fields that are set + fields = participant.dict(exclude_unset=True) + obj = obj.copy(update=fields) + + await transcripts_controller.upsert_participant(transcript, obj) + return Participant.model_validate(obj) + + +@router.delete("/transcripts/{transcript_id}/participants/{participant_id}") +async def transcript_delete_participant( + transcript_id: str, + participant_id: str, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +) -> DeletionStatus: + user_id = user["sub"] if user else None + transcript = await transcripts_controller.get_by_id_for_http( + transcript_id, user_id=user_id + ) + await transcripts_controller.delete_participant(transcript, participant_id) + return DeletionStatus(status="ok") diff --git a/server/reflector/views/transcripts_webrtc.py b/server/reflector/views/transcripts_webrtc.py new file mode 100644 index 00000000..af451411 --- /dev/null +++ b/server/reflector/views/transcripts_webrtc.py @@ -0,0 +1,37 @@ +from typing import Annotated, Optional + +import reflector.auth as auth +from fastapi import APIRouter, Depends, HTTPException, Request +from reflector.db.transcripts import transcripts_controller + +from .rtc_offer import RtcOffer, rtc_offer_base + +router = APIRouter() + + +@router.post("/transcripts/{transcript_id}/record/webrtc") +async def transcript_record_webrtc( + transcript_id: str, + params: RtcOffer, + request: Request, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +): + user_id = user["sub"] if user else None + transcript = await transcripts_controller.get_by_id_for_http( + transcript_id, user_id=user_id + ) + + if transcript.locked: + raise HTTPException(status_code=400, detail="Transcript is locked") + + # create a pipeline runner + from reflector.pipelines.main_live_pipeline import PipelineMainLive + + pipeline_runner = PipelineMainLive(transcript_id=transcript_id) + + # FIXME do not allow multiple recording at the same time + return await rtc_offer_base( + params, + request, + pipeline_runner=pipeline_runner, + ) diff --git a/server/reflector/views/transcripts_websocket.py b/server/reflector/views/transcripts_websocket.py new file mode 100644 index 00000000..65571aab --- /dev/null +++ b/server/reflector/views/transcripts_websocket.py @@ -0,0 +1,53 @@ +""" +Transcripts websocket API +========================= + +""" +from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect +from reflector.db.transcripts import transcripts_controller +from reflector.ws_manager import get_ws_manager + +router = APIRouter() + + +@router.get("/transcripts/{transcript_id}/events") +async def transcript_get_websocket_events(transcript_id: str): + pass + + +@router.websocket("/transcripts/{transcript_id}/events") +async def transcript_events_websocket( + transcript_id: str, + websocket: WebSocket, + # user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +): + # user_id = user["sub"] if user else None + transcript = await transcripts_controller.get_by_id(transcript_id) + if not transcript: + raise HTTPException(status_code=404, detail="Transcript not found") + + # connect to websocket manager + # use ts:transcript_id as room id + room_id = f"ts:{transcript_id}" + ws_manager = get_ws_manager() + await ws_manager.add_user_to_room(room_id, websocket) + + try: + # on first connection, send all events only to the current user + for event in transcript.events: + # for now, do not send TRANSCRIPT or STATUS options - theses are live event + # not necessary to be sent to the client; but keep the rest + name = event.event + if name in ("TRANSCRIPT", "STATUS"): + continue + await websocket.send_json(event.model_dump(mode="json")) + + # XXX if transcript is final (locked=True and status=ended) + # XXX send a final event to the client and close the connection + + # endless loop to wait for new events + # we do not have command system now, + while True: + await websocket.receive() + except (RuntimeError, WebSocketDisconnect): + await ws_manager.remove_user_from_room(room_id, websocket) diff --git a/server/reflector/views/types.py b/server/reflector/views/types.py new file mode 100644 index 00000000..70361131 --- /dev/null +++ b/server/reflector/views/types.py @@ -0,0 +1,5 @@ +from pydantic import BaseModel + + +class DeletionStatus(BaseModel): + status: str diff --git a/server/tests/test_transcripts_participants.py b/server/tests/test_transcripts_participants.py new file mode 100644 index 00000000..b55b16a8 --- /dev/null +++ b/server/tests/test_transcripts_participants.py @@ -0,0 +1,164 @@ +import pytest +from httpx import AsyncClient + + +@pytest.mark.asyncio +async def test_transcript_participants(): + from reflector.app import app + + async with AsyncClient(app=app, base_url="http://test/v1") as ac: + response = await ac.post("/transcripts", json={"name": "test"}) + assert response.status_code == 200 + assert response.json()["participants"] == [] + + # create a participant + transcript_id = response.json()["id"] + response = await ac.post( + f"/transcripts/{transcript_id}/participants", json={"name": "test"} + ) + assert response.status_code == 200 + assert response.json()["id"] is not None + assert response.json()["speaker"] is None + assert response.json()["name"] == "test" + + # create another one with a speaker + response = await ac.post( + f"/transcripts/{transcript_id}/participants", + json={"name": "test2", "speaker": 1}, + ) + assert response.status_code == 200 + assert response.json()["id"] is not None + assert response.json()["speaker"] == 1 + assert response.json()["name"] == "test2" + + # get all participants via transcript + response = await ac.get(f"/transcripts/{transcript_id}") + assert response.status_code == 200 + assert len(response.json()["participants"]) == 2 + + # get participants via participants endpoint + response = await ac.get(f"/transcripts/{transcript_id}/participants") + assert response.status_code == 200 + assert len(response.json()) == 2 + + +@pytest.mark.asyncio +async def test_transcript_participants_same_speaker(): + from reflector.app import app + + async with AsyncClient(app=app, base_url="http://test/v1") as ac: + response = await ac.post("/transcripts", json={"name": "test"}) + assert response.status_code == 200 + assert response.json()["participants"] == [] + transcript_id = response.json()["id"] + + # create a participant + response = await ac.post( + f"/transcripts/{transcript_id}/participants", + json={"name": "test", "speaker": 1}, + ) + assert response.status_code == 200 + assert response.json()["speaker"] == 1 + + # create another one with the same speaker + response = await ac.post( + f"/transcripts/{transcript_id}/participants", + json={"name": "test2", "speaker": 1}, + ) + assert response.status_code == 400 + + +@pytest.mark.asyncio +async def test_transcript_participants_update_name(): + from reflector.app import app + + async with AsyncClient(app=app, base_url="http://test/v1") as ac: + response = await ac.post("/transcripts", json={"name": "test"}) + assert response.status_code == 200 + assert response.json()["participants"] == [] + transcript_id = response.json()["id"] + + # create a participant + response = await ac.post( + f"/transcripts/{transcript_id}/participants", + json={"name": "test", "speaker": 1}, + ) + assert response.status_code == 200 + assert response.json()["speaker"] == 1 + + # update the participant + participant_id = response.json()["id"] + response = await ac.patch( + f"/transcripts/{transcript_id}/participants/{participant_id}", + json={"name": "test2"}, + ) + assert response.status_code == 200 + assert response.json()["name"] == "test2" + + # verify the participant was updated + response = await ac.get( + f"/transcripts/{transcript_id}/participants/{participant_id}" + ) + assert response.status_code == 200 + assert response.json()["name"] == "test2" + + # verify the participant was updated in transcript + response = await ac.get(f"/transcripts/{transcript_id}") + assert response.status_code == 200 + assert len(response.json()["participants"]) == 1 + assert response.json()["participants"][0]["name"] == "test2" + + +@pytest.mark.asyncio +async def test_transcript_participants_update_speaker(): + from reflector.app import app + + async with AsyncClient(app=app, base_url="http://test/v1") as ac: + response = await ac.post("/transcripts", json={"name": "test"}) + assert response.status_code == 200 + assert response.json()["participants"] == [] + transcript_id = response.json()["id"] + + # create a participant + response = await ac.post( + f"/transcripts/{transcript_id}/participants", + json={"name": "test", "speaker": 1}, + ) + assert response.status_code == 200 + participant1_id = response.json()["id"] + + # create another participant + response = await ac.post( + f"/transcripts/{transcript_id}/participants", + json={"name": "test2", "speaker": 2}, + ) + assert response.status_code == 200 + participant2_id = response.json()["id"] + + # update the participant, refused as speaker is already taken + response = await ac.patch( + f"/transcripts/{transcript_id}/participants/{participant2_id}", + json={"speaker": 1}, + ) + assert response.status_code == 400 + + # delete the participant 1 + response = await ac.delete( + f"/transcripts/{transcript_id}/participants/{participant1_id}" + ) + assert response.status_code == 200 + + # update the participant 2 again, should be accepted now + response = await ac.patch( + f"/transcripts/{transcript_id}/participants/{participant2_id}", + json={"speaker": 1}, + ) + assert response.status_code == 200 + + # ensure participant2 name is still there + response = await ac.get( + f"/transcripts/{transcript_id}/participants/{participant2_id}" + ) + assert response.status_code == 200 + assert response.json()["name"] == "test2" + assert response.json()["speaker"] == 1