server: implement user authentication (none by default)

This commit is contained in:
2023-08-16 17:24:05 +02:00
parent d470696a49
commit e12f9afe7b
13 changed files with 263 additions and 21 deletions

View File

@@ -4,6 +4,7 @@ from fastapi import (
Request,
WebSocket,
WebSocketDisconnect,
Depends,
)
from fastapi.responses import FileResponse
from starlette.concurrency import run_in_threadpool
@@ -14,8 +15,9 @@ from fastapi_pagination import Page, paginate
from reflector.logger import logger
from reflector.db import database, transcripts
from reflector.settings import settings
import reflector.auth as auth
from .rtc_offer import rtc_offer_base, RtcOffer, PipelineEvent
from typing import Optional
from typing import Annotated, Optional
from pathlib import Path
from tempfile import NamedTemporaryFile
import av
@@ -60,6 +62,7 @@ class TranscriptEvent(BaseModel):
class Transcript(BaseModel):
id: str = Field(default_factory=generate_uuid4)
user_id: str | None = None
name: str = Field(default_factory=generate_transcript_name)
status: str = "idle"
locked: bool = False
@@ -127,20 +130,28 @@ class Transcript(BaseModel):
class TranscriptController:
async def get_all(self) -> list[Transcript]:
async def get_all(self, user_id: str | None = None) -> list[Transcript]:
query = transcripts.select()
if user_id is not None:
query = query.where(transcripts.c.user_id == user_id)
print(query)
results = await database.fetch_all(query)
print(results)
return results
async def get_by_id(self, transcript_id: str) -> Transcript | None:
async def get_by_id(
self, transcript_id: str, user_id: str | None = None
) -> Transcript | None:
query = transcripts.select().where(transcripts.c.id == transcript_id)
result = await database.fetch_one(query)
if not result:
return None
if user_id is not None and result["user_id"] != user_id:
return None
return Transcript(**result)
async def add(self, name: str):
transcript = Transcript(name=name)
async def add(self, name: str, user_id: str | None = None):
transcript = Transcript(name=name, user_id=user_id)
query = transcripts.insert().values(**transcript.model_dump())
await database.execute(query)
return transcript
@@ -155,10 +166,14 @@ class TranscriptController:
for key, value in values.items():
setattr(transcript, key, value)
async def remove_by_id(self, transcript_id: str) -> None:
async def remove_by_id(
self, transcript_id: str, user_id: str | None = None
) -> None:
transcript = await self.get_by_id(transcript_id)
if not transcript:
return
if user_id is not None and transcript.user_id != user_id:
return
transcript.unlink()
query = transcripts.delete().where(transcripts.c.id == transcript_id)
await database.execute(query)
@@ -199,13 +214,19 @@ class DeletionStatus(BaseModel):
@router.get("/transcripts", response_model=Page[GetTranscript])
async def transcripts_list():
return paginate(await transcripts_controller.get_all())
async def transcripts_list(
user: auth.UserInfo = Depends(auth.current_user),
):
return paginate(await transcripts_controller.get_all(user_id=user["sub"]))
@router.post("/transcripts", response_model=GetTranscript)
async def transcripts_create(info: CreateTranscript):
return await transcripts_controller.add(info.name)
async def transcripts_create(
info: CreateTranscript,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
return await transcripts_controller.add(info.name, user_id=user_id)
# ==============================================================
@@ -214,16 +235,25 @@ async def transcripts_create(info: CreateTranscript):
@router.get("/transcripts/{transcript_id}", response_model=GetTranscript)
async def transcript_get(transcript_id: str):
transcript = await transcripts_controller.get_by_id(transcript_id)
async def transcript_get(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
return transcript
@router.patch("/transcripts/{transcript_id}", response_model=GetTranscript)
async def transcript_update(transcript_id: str, info: UpdateTranscript):
transcript = await transcripts_controller.get_by_id(transcript_id)
async def transcript_update(
transcript_id: str,
info: UpdateTranscript,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
values = {}
@@ -236,11 +266,15 @@ async def transcript_update(transcript_id: str, info: UpdateTranscript):
@router.delete("/transcripts/{transcript_id}", response_model=DeletionStatus)
async def transcript_delete(transcript_id: str):
transcript = await transcripts_controller.get_by_id(transcript_id)
async def transcript_delete(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
await transcripts_controller.remove_by_id(transcript.id)
await transcripts_controller.remove_by_id(transcript.id, user_id=user_id)
return DeletionStatus(status="ok")