Merge pull request #330 from Monadical-SAS/feat-participants

Participant API
This commit is contained in:
2023-12-01 14:23:55 +01:00
committed by GitHub
10 changed files with 610 additions and 202 deletions

View File

@@ -0,0 +1,30 @@
"""participants
Revision ID: 125031f7cb78
Revises: 0fea6d96b096
Create Date: 2023-11-30 15:56:03.341466
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '125031f7cb78'
down_revision: Union[str, None] = '0fea6d96b096'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('transcript', sa.Column('participants', sa.JSON(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('transcript', 'participants')
# ### end Alembic commands ###

View File

@@ -13,6 +13,12 @@ from reflector.metrics import metrics_init
from reflector.settings import settings from reflector.settings import settings
from reflector.views.rtc_offer import router as rtc_offer_router from reflector.views.rtc_offer import router as rtc_offer_router
from reflector.views.transcripts import router as transcripts_router from reflector.views.transcripts import router as transcripts_router
from reflector.views.transcripts_audio import router as transcripts_audio_router
from reflector.views.transcripts_participants import (
router as transcripts_participants_router,
)
from reflector.views.transcripts_webrtc import router as transcripts_webrtc_router
from reflector.views.transcripts_websocket import router as transcripts_websocket_router
from reflector.views.user import router as user_router from reflector.views.user import router as user_router
try: try:
@@ -60,6 +66,10 @@ metrics_init(app, instrumentator)
# register views # register views
app.include_router(rtc_offer_router) app.include_router(rtc_offer_router)
app.include_router(transcripts_router, prefix="/v1") app.include_router(transcripts_router, prefix="/v1")
app.include_router(transcripts_audio_router, prefix="/v1")
app.include_router(transcripts_participants_router, prefix="/v1")
app.include_router(transcripts_websocket_router, prefix="/v1")
app.include_router(transcripts_webrtc_router, prefix="/v1")
app.include_router(user_router, prefix="/v1") app.include_router(user_router, prefix="/v1")
add_pagination(app) add_pagination(app)

View File

@@ -7,7 +7,7 @@ from uuid import uuid4
import sqlalchemy import sqlalchemy
from fastapi import HTTPException from fastapi import HTTPException
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
from reflector.db import database, metadata from reflector.db import database, metadata
from reflector.processors.types import Word as ProcessorWord from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings from reflector.settings import settings
@@ -27,6 +27,7 @@ transcripts = sqlalchemy.Table(
sqlalchemy.Column("long_summary", sqlalchemy.String, nullable=True), sqlalchemy.Column("long_summary", sqlalchemy.String, nullable=True),
sqlalchemy.Column("topics", sqlalchemy.JSON), sqlalchemy.Column("topics", sqlalchemy.JSON),
sqlalchemy.Column("events", sqlalchemy.JSON), sqlalchemy.Column("events", sqlalchemy.JSON),
sqlalchemy.Column("participants", sqlalchemy.JSON),
sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True), sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True), sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
sqlalchemy.Column( sqlalchemy.Column(
@@ -112,6 +113,13 @@ class TranscriptEvent(BaseModel):
data: dict data: dict
class TranscriptParticipant(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str = Field(default_factory=generate_uuid4)
speaker: int | None
name: str
class Transcript(BaseModel): class Transcript(BaseModel):
id: str = Field(default_factory=generate_uuid4) id: str = Field(default_factory=generate_uuid4)
user_id: str | None = None user_id: str | None = None
@@ -125,6 +133,7 @@ class Transcript(BaseModel):
long_summary: str | None = None long_summary: str | None = None
topics: list[TranscriptTopic] = [] topics: list[TranscriptTopic] = []
events: list[TranscriptEvent] = [] events: list[TranscriptEvent] = []
participants: list[TranscriptParticipant] = []
source_language: str = "en" source_language: str = "en"
target_language: str = "en" target_language: str = "en"
share_mode: Literal["private", "semi-private", "public"] = "private" share_mode: Literal["private", "semi-private", "public"] = "private"
@@ -142,12 +151,34 @@ class Transcript(BaseModel):
else: else:
self.topics.append(topic) self.topics.append(topic)
def upsert_participant(self, participant: TranscriptParticipant):
index = next(
(i for i, p in enumerate(self.participants) if p.id == participant.id),
None,
)
if index is not None:
self.participants[index] = participant
else:
self.participants.append(participant)
return participant
def delete_participant(self, participant_id: str):
index = next(
(i for i, p in enumerate(self.participants) if p.id == participant_id),
None,
)
if index is not None:
del self.participants[index]
def events_dump(self, mode="json"): def events_dump(self, mode="json"):
return [event.model_dump(mode=mode) for event in self.events] return [event.model_dump(mode=mode) for event in self.events]
def topics_dump(self, mode="json"): def topics_dump(self, mode="json"):
return [topic.model_dump(mode=mode) for topic in self.topics] return [topic.model_dump(mode=mode) for topic in self.topics]
def participants_dump(self, mode="json"):
return [participant.model_dump(mode=mode) for participant in self.participants]
def unlink(self): def unlink(self):
self.data_path.unlink(missing_ok=True) self.data_path.unlink(missing_ok=True)
@@ -410,5 +441,28 @@ class TranscriptController:
# unlink the local file # unlink the local file
transcript.audio_mp3_filename.unlink(missing_ok=True) transcript.audio_mp3_filename.unlink(missing_ok=True)
async def upsert_participant(
self,
transcript: Transcript,
participant: TranscriptParticipant,
) -> TranscriptParticipant:
"""
Add/update a participant to a transcript
"""
result = transcript.upsert_participant(participant)
await self.update(transcript, {"participants": transcript.participants_dump()})
return result
async def delete_participant(
self,
transcript: Transcript,
participant_id: str,
):
"""
Delete a participant from a transcript
"""
transcript.delete_participant(participant_id)
await self.update(transcript, {"participants": transcript.participants_dump()})
transcripts_controller = TranscriptController() transcripts_controller = TranscriptController()

View File

@@ -1,33 +1,19 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Annotated, Literal, Optional from typing import Annotated, Literal, Optional
import httpx
import reflector.auth as auth import reflector.auth as auth
from fastapi import ( from fastapi import APIRouter, Depends, HTTPException
APIRouter,
Depends,
HTTPException,
Request,
Response,
WebSocket,
WebSocketDisconnect,
status,
)
from fastapi_pagination import Page from fastapi_pagination import Page
from fastapi_pagination.ext.databases import paginate from fastapi_pagination.ext.databases import paginate
from jose import jwt from jose import jwt
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from reflector.db.transcripts import ( from reflector.db.transcripts import (
AudioWaveform, TranscriptParticipant,
TranscriptTopic, TranscriptTopic,
transcripts_controller, transcripts_controller,
) )
from reflector.processors.types import Transcript as ProcessorTranscript from reflector.processors.types import Transcript as ProcessorTranscript
from reflector.settings import settings from reflector.settings import settings
from reflector.ws_manager import get_ws_manager
from ._range_requests_response import range_requests_response
from .rtc_offer import RtcOffer, rtc_offer_base
router = APIRouter() router = APIRouter()
@@ -62,6 +48,7 @@ class GetTranscript(BaseModel):
share_mode: str = Field("private") share_mode: str = Field("private")
source_language: str | None source_language: str | None
target_language: str | None target_language: str | None
participants: list[TranscriptParticipant] | None
class CreateTranscript(BaseModel): class CreateTranscript(BaseModel):
@@ -77,6 +64,7 @@ class UpdateTranscript(BaseModel):
short_summary: Optional[str] = Field(None) short_summary: Optional[str] = Field(None)
long_summary: Optional[str] = Field(None) long_summary: Optional[str] = Field(None)
share_mode: Optional[Literal["public", "semi-private", "private"]] = Field(None) share_mode: Optional[Literal["public", "semi-private", "private"]] = Field(None)
participants: Optional[list[TranscriptParticipant]] = Field(None)
class DeletionStatus(BaseModel): class DeletionStatus(BaseModel):
@@ -192,19 +180,7 @@ async def transcript_update(
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id) transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript: if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found") raise HTTPException(status_code=404, detail="Transcript not found")
values = {} values = info.dict(exclude_unset=True)
if info.name is not None:
values["name"] = info.name
if info.locked is not None:
values["locked"] = info.locked
if info.long_summary is not None:
values["long_summary"] = info.long_summary
if info.short_summary is not None:
values["short_summary"] = info.short_summary
if info.title is not None:
values["title"] = info.title
if info.share_mode is not None:
values["share_mode"] = info.share_mode
await transcripts_controller.update(transcript, values) await transcripts_controller.update(transcript, values)
return transcript return transcript
@@ -222,97 +198,6 @@ async def transcript_delete(
return DeletionStatus(status="ok") return DeletionStatus(status="ok")
@router.get("/transcripts/{transcript_id}/audio/mp3")
@router.head("/transcripts/{transcript_id}/audio/mp3")
async def transcript_get_audio_mp3(
request: Request,
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
token: str | None = None,
):
user_id = user["sub"] if user else None
if not user_id and token:
unauthorized_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired token",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
user_id: str = payload.get("sub")
except jwt.JWTError:
raise unauthorized_exception
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
if transcript.audio_location == "storage":
# proxy S3 file, to prevent issue with CORS
url = await transcript.get_audio_url()
headers = {}
copy_headers = ["range", "accept-encoding"]
for header in copy_headers:
if header in request.headers:
headers[header] = request.headers[header]
async with httpx.AsyncClient() as client:
resp = await client.request(request.method, url, headers=headers)
return Response(
content=resp.content,
status_code=resp.status_code,
headers=resp.headers,
)
if transcript.audio_location == "storage":
# proxy S3 file, to prevent issue with CORS
url = await transcript.get_audio_url()
headers = {}
copy_headers = ["range", "accept-encoding"]
for header in copy_headers:
if header in request.headers:
headers[header] = request.headers[header]
async with httpx.AsyncClient() as client:
resp = await client.request(request.method, url, headers=headers)
return Response(
content=resp.content,
status_code=resp.status_code,
headers=resp.headers,
)
if not transcript.audio_mp3_filename.exists():
raise HTTPException(status_code=500, detail="Audio not found")
truncated_id = str(transcript.id).split("-")[0]
filename = f"recording_{truncated_id}.mp3"
return range_requests_response(
request,
transcript.audio_mp3_filename,
content_type="audio/mpeg",
content_disposition=f"attachment; filename={filename}",
)
@router.get("/transcripts/{transcript_id}/audio/waveform")
async def transcript_get_audio_waveform(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> AudioWaveform:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
if not transcript.audio_waveform_filename.exists():
raise HTTPException(status_code=404, detail="Audio not found")
return transcript.audio_waveform
@router.get( @router.get(
"/transcripts/{transcript_id}/topics", "/transcripts/{transcript_id}/topics",
response_model=list[GetTranscriptTopic], response_model=list[GetTranscriptTopic],
@@ -330,84 +215,3 @@ async def transcript_get_topics(
return [ return [
GetTranscriptTopic.from_transcript_topic(topic) for topic in transcript.topics GetTranscriptTopic.from_transcript_topic(topic) for topic in transcript.topics
] ]
# ==============================================================
# Websocket
# ==============================================================
@router.get("/transcripts/{transcript_id}/events")
async def transcript_get_websocket_events(transcript_id: str):
pass
@router.websocket("/transcripts/{transcript_id}/events")
async def transcript_events_websocket(
transcript_id: str,
websocket: WebSocket,
# user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
# user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
# connect to websocket manager
# use ts:transcript_id as room id
room_id = f"ts:{transcript_id}"
ws_manager = get_ws_manager()
await ws_manager.add_user_to_room(room_id, websocket)
try:
# on first connection, send all events only to the current user
for event in transcript.events:
# for now, do not send TRANSCRIPT or STATUS options - theses are live event
# not necessary to be sent to the client; but keep the rest
name = event.event
if name in ("TRANSCRIPT", "STATUS"):
continue
await websocket.send_json(event.model_dump(mode="json"))
# XXX if transcript is final (locked=True and status=ended)
# XXX send a final event to the client and close the connection
# endless loop to wait for new events
# we do not have command system now,
while True:
await websocket.receive()
except (RuntimeError, WebSocketDisconnect):
await ws_manager.remove_user_from_room(room_id, websocket)
# ==============================================================
# Web RTC
# ==============================================================
@router.post("/transcripts/{transcript_id}/record/webrtc")
async def transcript_record_webrtc(
transcript_id: str,
params: RtcOffer,
request: Request,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
if transcript.locked:
raise HTTPException(status_code=400, detail="Transcript is locked")
# create a pipeline runner
from reflector.pipelines.main_live_pipeline import PipelineMainLive
pipeline_runner = PipelineMainLive(transcript_id=transcript_id)
# FIXME do not allow multiple recording at the same time
return await rtc_offer_base(
params,
request,
pipeline_runner=pipeline_runner,
)

View File

@@ -0,0 +1,109 @@
"""
Transcripts audio related endpoints
===================================
"""
from typing import Annotated, Optional
import httpx
import reflector.auth as auth
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from jose import jwt
from reflector.db.transcripts import AudioWaveform, transcripts_controller
from reflector.settings import settings
from reflector.views.transcripts import ALGORITHM
from ._range_requests_response import range_requests_response
router = APIRouter()
@router.get("/transcripts/{transcript_id}/audio/mp3")
@router.head("/transcripts/{transcript_id}/audio/mp3")
async def transcript_get_audio_mp3(
request: Request,
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
token: str | None = None,
):
user_id = user["sub"] if user else None
if not user_id and token:
unauthorized_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired token",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
user_id: str = payload.get("sub")
except jwt.JWTError:
raise unauthorized_exception
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
if transcript.audio_location == "storage":
# proxy S3 file, to prevent issue with CORS
url = await transcript.get_audio_url()
headers = {}
copy_headers = ["range", "accept-encoding"]
for header in copy_headers:
if header in request.headers:
headers[header] = request.headers[header]
async with httpx.AsyncClient() as client:
resp = await client.request(request.method, url, headers=headers)
return Response(
content=resp.content,
status_code=resp.status_code,
headers=resp.headers,
)
if transcript.audio_location == "storage":
# proxy S3 file, to prevent issue with CORS
url = await transcript.get_audio_url()
headers = {}
copy_headers = ["range", "accept-encoding"]
for header in copy_headers:
if header in request.headers:
headers[header] = request.headers[header]
async with httpx.AsyncClient() as client:
resp = await client.request(request.method, url, headers=headers)
return Response(
content=resp.content,
status_code=resp.status_code,
headers=resp.headers,
)
if not transcript.audio_mp3_filename.exists():
raise HTTPException(status_code=500, detail="Audio not found")
truncated_id = str(transcript.id).split("-")[0]
filename = f"recording_{truncated_id}.mp3"
return range_requests_response(
request,
transcript.audio_mp3_filename,
content_type="audio/mpeg",
content_disposition=f"attachment; filename={filename}",
)
@router.get("/transcripts/{transcript_id}/audio/waveform")
async def transcript_get_audio_waveform(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> AudioWaveform:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
if not transcript.audio_waveform_filename.exists():
raise HTTPException(status_code=404, detail="Audio not found")
return transcript.audio_waveform

View File

@@ -0,0 +1,142 @@
"""
Transcript participants API endpoints
=====================================
"""
from typing import Annotated, Optional
import reflector.auth as auth
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, ConfigDict, Field
from reflector.db.transcripts import TranscriptParticipant, transcripts_controller
from reflector.views.types import DeletionStatus
router = APIRouter()
class Participant(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
speaker: int | None
name: str
class CreateParticipant(BaseModel):
speaker: Optional[int] = Field(None)
name: str
class UpdateParticipant(BaseModel):
speaker: Optional[int] = Field(None)
name: Optional[str] = Field(None)
@router.get("/transcripts/{transcript_id}/participants")
async def transcript_get_participants(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> list[Participant]:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
return [
Participant.model_validate(participant)
for participant in transcript.participants
]
@router.post("/transcripts/{transcript_id}/participants")
async def transcript_add_participant(
transcript_id: str,
participant: CreateParticipant,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> Participant:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
# ensure the speaker is unique
for p in transcript.participants:
if p.speaker == participant.speaker:
raise HTTPException(
status_code=400,
detail="Speaker already assigned",
)
obj = await transcripts_controller.upsert_participant(
transcript, TranscriptParticipant(**participant.dict())
)
return Participant.model_validate(obj)
@router.get("/transcripts/{transcript_id}/participants/{participant_id}")
async def transcript_get_participant(
transcript_id: str,
participant_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> Participant:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
for p in transcript.participants:
if p.id == participant_id:
return Participant.model_validate(p)
raise HTTPException(status_code=404, detail="Participant not found")
@router.patch("/transcripts/{transcript_id}/participants/{participant_id}")
async def transcript_update_participant(
transcript_id: str,
participant_id: str,
participant: UpdateParticipant,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> Participant:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
# ensure the speaker is unique
for p in transcript.participants:
if p.speaker == participant.speaker and p.id != participant_id:
raise HTTPException(
status_code=400,
detail="Speaker already assigned",
)
# find the participant
obj = None
for p in transcript.participants:
if p.id == participant_id:
obj = p
break
if not obj:
raise HTTPException(status_code=404, detail="Participant not found")
# update participant but just the fields that are set
fields = participant.dict(exclude_unset=True)
obj = obj.copy(update=fields)
await transcripts_controller.upsert_participant(transcript, obj)
return Participant.model_validate(obj)
@router.delete("/transcripts/{transcript_id}/participants/{participant_id}")
async def transcript_delete_participant(
transcript_id: str,
participant_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> DeletionStatus:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
await transcripts_controller.delete_participant(transcript, participant_id)
return DeletionStatus(status="ok")

View File

@@ -0,0 +1,37 @@
from typing import Annotated, Optional
import reflector.auth as auth
from fastapi import APIRouter, Depends, HTTPException, Request
from reflector.db.transcripts import transcripts_controller
from .rtc_offer import RtcOffer, rtc_offer_base
router = APIRouter()
@router.post("/transcripts/{transcript_id}/record/webrtc")
async def transcript_record_webrtc(
transcript_id: str,
params: RtcOffer,
request: Request,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
if transcript.locked:
raise HTTPException(status_code=400, detail="Transcript is locked")
# create a pipeline runner
from reflector.pipelines.main_live_pipeline import PipelineMainLive
pipeline_runner = PipelineMainLive(transcript_id=transcript_id)
# FIXME do not allow multiple recording at the same time
return await rtc_offer_base(
params,
request,
pipeline_runner=pipeline_runner,
)

View File

@@ -0,0 +1,53 @@
"""
Transcripts websocket API
=========================
"""
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
from reflector.db.transcripts import transcripts_controller
from reflector.ws_manager import get_ws_manager
router = APIRouter()
@router.get("/transcripts/{transcript_id}/events")
async def transcript_get_websocket_events(transcript_id: str):
pass
@router.websocket("/transcripts/{transcript_id}/events")
async def transcript_events_websocket(
transcript_id: str,
websocket: WebSocket,
# user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
# user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
# connect to websocket manager
# use ts:transcript_id as room id
room_id = f"ts:{transcript_id}"
ws_manager = get_ws_manager()
await ws_manager.add_user_to_room(room_id, websocket)
try:
# on first connection, send all events only to the current user
for event in transcript.events:
# for now, do not send TRANSCRIPT or STATUS options - theses are live event
# not necessary to be sent to the client; but keep the rest
name = event.event
if name in ("TRANSCRIPT", "STATUS"):
continue
await websocket.send_json(event.model_dump(mode="json"))
# XXX if transcript is final (locked=True and status=ended)
# XXX send a final event to the client and close the connection
# endless loop to wait for new events
# we do not have command system now,
while True:
await websocket.receive()
except (RuntimeError, WebSocketDisconnect):
await ws_manager.remove_user_from_room(room_id, websocket)

View File

@@ -0,0 +1,5 @@
from pydantic import BaseModel
class DeletionStatus(BaseModel):
status: str

View File

@@ -0,0 +1,164 @@
import pytest
from httpx import AsyncClient
@pytest.mark.asyncio
async def test_transcript_participants():
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["participants"] == []
# create a participant
transcript_id = response.json()["id"]
response = await ac.post(
f"/transcripts/{transcript_id}/participants", json={"name": "test"}
)
assert response.status_code == 200
assert response.json()["id"] is not None
assert response.json()["speaker"] is None
assert response.json()["name"] == "test"
# create another one with a speaker
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test2", "speaker": 1},
)
assert response.status_code == 200
assert response.json()["id"] is not None
assert response.json()["speaker"] == 1
assert response.json()["name"] == "test2"
# get all participants via transcript
response = await ac.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
assert len(response.json()["participants"]) == 2
# get participants via participants endpoint
response = await ac.get(f"/transcripts/{transcript_id}/participants")
assert response.status_code == 200
assert len(response.json()) == 2
@pytest.mark.asyncio
async def test_transcript_participants_same_speaker():
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["participants"] == []
transcript_id = response.json()["id"]
# create a participant
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test", "speaker": 1},
)
assert response.status_code == 200
assert response.json()["speaker"] == 1
# create another one with the same speaker
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test2", "speaker": 1},
)
assert response.status_code == 400
@pytest.mark.asyncio
async def test_transcript_participants_update_name():
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["participants"] == []
transcript_id = response.json()["id"]
# create a participant
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test", "speaker": 1},
)
assert response.status_code == 200
assert response.json()["speaker"] == 1
# update the participant
participant_id = response.json()["id"]
response = await ac.patch(
f"/transcripts/{transcript_id}/participants/{participant_id}",
json={"name": "test2"},
)
assert response.status_code == 200
assert response.json()["name"] == "test2"
# verify the participant was updated
response = await ac.get(
f"/transcripts/{transcript_id}/participants/{participant_id}"
)
assert response.status_code == 200
assert response.json()["name"] == "test2"
# verify the participant was updated in transcript
response = await ac.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
assert len(response.json()["participants"]) == 1
assert response.json()["participants"][0]["name"] == "test2"
@pytest.mark.asyncio
async def test_transcript_participants_update_speaker():
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["participants"] == []
transcript_id = response.json()["id"]
# create a participant
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test", "speaker": 1},
)
assert response.status_code == 200
participant1_id = response.json()["id"]
# create another participant
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test2", "speaker": 2},
)
assert response.status_code == 200
participant2_id = response.json()["id"]
# update the participant, refused as speaker is already taken
response = await ac.patch(
f"/transcripts/{transcript_id}/participants/{participant2_id}",
json={"speaker": 1},
)
assert response.status_code == 400
# delete the participant 1
response = await ac.delete(
f"/transcripts/{transcript_id}/participants/{participant1_id}"
)
assert response.status_code == 200
# update the participant 2 again, should be accepted now
response = await ac.patch(
f"/transcripts/{transcript_id}/participants/{participant2_id}",
json={"speaker": 1},
)
assert response.status_code == 200
# ensure participant2 name is still there
response = await ac.get(
f"/transcripts/{transcript_id}/participants/{participant2_id}"
)
assert response.status_code == 200
assert response.json()["name"] == "test2"
assert response.json()["speaker"] == 1