mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-21 04:39:06 +00:00
server: implement user authentication (none by default)
This commit is contained in:
@@ -3,6 +3,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi_pagination import add_pagination
|
||||
from fastapi.routing import APIRoute
|
||||
import reflector.db # noqa
|
||||
import reflector.auth # noqa
|
||||
from reflector.views.rtc_offer import router as rtc_offer_router
|
||||
from reflector.views.transcripts import router as transcripts_router
|
||||
from reflector.events import subscribers_startup, subscribers_shutdown
|
||||
|
||||
13
server/reflector/auth/__init__.py
Normal file
13
server/reflector/auth/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from reflector.settings import settings
|
||||
from reflector.logger import logger
|
||||
import importlib
|
||||
|
||||
logger.info(f"User authentication using {settings.AUTH_BACKEND}")
|
||||
module_name = f"reflector.auth.auth_{settings.AUTH_BACKEND}"
|
||||
auth_module = importlib.import_module(module_name)
|
||||
|
||||
UserInfo = auth_module.UserInfo
|
||||
AccessTokenInfo = auth_module.AccessTokenInfo
|
||||
authenticated = auth_module.authenticated
|
||||
current_user = auth_module.current_user
|
||||
current_user_optional = auth_module.current_user_optional
|
||||
25
server/reflector/auth/auth_fief.py
Normal file
25
server/reflector/auth/auth_fief.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from fastapi.security import OAuth2AuthorizationCodeBearer
|
||||
from fief_client import FiefAccessTokenInfo, FiefAsync, FiefUserInfo
|
||||
from fief_client.integrations.fastapi import FiefAuth
|
||||
from reflector.settings import settings
|
||||
|
||||
fief = FiefAsync(
|
||||
settings.AUTH_FIEF_URL,
|
||||
settings.AUTH_FIEF_CLIENT_ID,
|
||||
settings.AUTH_FIEF_CLIENT_SECRET,
|
||||
)
|
||||
|
||||
scheme = OAuth2AuthorizationCodeBearer(
|
||||
f"{settings.AUTH_FIEF_URL}/authorize",
|
||||
f"{settings.AUTH_FIEF_URL}/api/token",
|
||||
scopes={"openid": "openid", "offline_access": "offline_access"},
|
||||
auto_error=False,
|
||||
)
|
||||
|
||||
auth = FiefAuth(fief, scheme)
|
||||
|
||||
UserInfo = FiefUserInfo
|
||||
AccessTokenInfo = FiefAccessTokenInfo
|
||||
authenticated = auth.authenticated()
|
||||
current_user = auth.current_user()
|
||||
current_user_optional = auth.current_user(optional=True)
|
||||
35
server/reflector/auth/auth_none.py
Normal file
35
server/reflector/auth/auth_none.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Annotated
|
||||
from fastapi import Depends
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
|
||||
class UserInfo(BaseModel):
|
||||
sub: str
|
||||
|
||||
|
||||
class AccessTokenInfo(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
def authenticated(token: Annotated[str, Depends(oauth2_scheme)]):
|
||||
def _authenticated():
|
||||
return None
|
||||
|
||||
return _authenticated
|
||||
|
||||
|
||||
def current_user(token: Annotated[str, Depends(oauth2_scheme)]):
|
||||
def _current_user():
|
||||
return None
|
||||
|
||||
return _current_user
|
||||
|
||||
|
||||
def current_user_optional(token: Annotated[str, Depends(oauth2_scheme)]):
|
||||
def _current_user_optional():
|
||||
return None
|
||||
|
||||
return _current_user_optional
|
||||
@@ -79,5 +79,13 @@ class Settings(BaseSettings):
|
||||
# Sentry
|
||||
SENTRY_DSN: str | None = None
|
||||
|
||||
# User authentication (none, fief)
|
||||
AUTH_BACKEND: str = "none"
|
||||
|
||||
# User authentication using fief
|
||||
AUTH_FIEF_URL: str | None = None
|
||||
AUTH_FIEF_CLIENT_ID: str | None = None
|
||||
AUTH_FIEF_CLIENT_SECRET: str | None = None
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
@@ -4,6 +4,7 @@ from fastapi import (
|
||||
Request,
|
||||
WebSocket,
|
||||
WebSocketDisconnect,
|
||||
Depends,
|
||||
)
|
||||
from fastapi.responses import FileResponse
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
@@ -14,8 +15,9 @@ from fastapi_pagination import Page, paginate
|
||||
from reflector.logger import logger
|
||||
from reflector.db import database, transcripts
|
||||
from reflector.settings import settings
|
||||
import reflector.auth as auth
|
||||
from .rtc_offer import rtc_offer_base, RtcOffer, PipelineEvent
|
||||
from typing import Optional
|
||||
from typing import Annotated, Optional
|
||||
from pathlib import Path
|
||||
from tempfile import NamedTemporaryFile
|
||||
import av
|
||||
@@ -60,6 +62,7 @@ class TranscriptEvent(BaseModel):
|
||||
|
||||
class Transcript(BaseModel):
|
||||
id: str = Field(default_factory=generate_uuid4)
|
||||
user_id: str | None = None
|
||||
name: str = Field(default_factory=generate_transcript_name)
|
||||
status: str = "idle"
|
||||
locked: bool = False
|
||||
@@ -127,20 +130,28 @@ class Transcript(BaseModel):
|
||||
|
||||
|
||||
class TranscriptController:
|
||||
async def get_all(self) -> list[Transcript]:
|
||||
async def get_all(self, user_id: str | None = None) -> list[Transcript]:
|
||||
query = transcripts.select()
|
||||
if user_id is not None:
|
||||
query = query.where(transcripts.c.user_id == user_id)
|
||||
print(query)
|
||||
results = await database.fetch_all(query)
|
||||
print(results)
|
||||
return results
|
||||
|
||||
async def get_by_id(self, transcript_id: str) -> Transcript | None:
|
||||
async def get_by_id(
|
||||
self, transcript_id: str, user_id: str | None = None
|
||||
) -> Transcript | None:
|
||||
query = transcripts.select().where(transcripts.c.id == transcript_id)
|
||||
result = await database.fetch_one(query)
|
||||
if not result:
|
||||
return None
|
||||
if user_id is not None and result["user_id"] != user_id:
|
||||
return None
|
||||
return Transcript(**result)
|
||||
|
||||
async def add(self, name: str):
|
||||
transcript = Transcript(name=name)
|
||||
async def add(self, name: str, user_id: str | None = None):
|
||||
transcript = Transcript(name=name, user_id=user_id)
|
||||
query = transcripts.insert().values(**transcript.model_dump())
|
||||
await database.execute(query)
|
||||
return transcript
|
||||
@@ -155,10 +166,14 @@ class TranscriptController:
|
||||
for key, value in values.items():
|
||||
setattr(transcript, key, value)
|
||||
|
||||
async def remove_by_id(self, transcript_id: str) -> None:
|
||||
async def remove_by_id(
|
||||
self, transcript_id: str, user_id: str | None = None
|
||||
) -> None:
|
||||
transcript = await self.get_by_id(transcript_id)
|
||||
if not transcript:
|
||||
return
|
||||
if user_id is not None and transcript.user_id != user_id:
|
||||
return
|
||||
transcript.unlink()
|
||||
query = transcripts.delete().where(transcripts.c.id == transcript_id)
|
||||
await database.execute(query)
|
||||
@@ -199,13 +214,19 @@ class DeletionStatus(BaseModel):
|
||||
|
||||
|
||||
@router.get("/transcripts", response_model=Page[GetTranscript])
|
||||
async def transcripts_list():
|
||||
return paginate(await transcripts_controller.get_all())
|
||||
async def transcripts_list(
|
||||
user: auth.UserInfo = Depends(auth.current_user),
|
||||
):
|
||||
return paginate(await transcripts_controller.get_all(user_id=user["sub"]))
|
||||
|
||||
|
||||
@router.post("/transcripts", response_model=GetTranscript)
|
||||
async def transcripts_create(info: CreateTranscript):
|
||||
return await transcripts_controller.add(info.name)
|
||||
async def transcripts_create(
|
||||
info: CreateTranscript,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
return await transcripts_controller.add(info.name, user_id=user_id)
|
||||
|
||||
|
||||
# ==============================================================
|
||||
@@ -214,16 +235,25 @@ async def transcripts_create(info: CreateTranscript):
|
||||
|
||||
|
||||
@router.get("/transcripts/{transcript_id}", response_model=GetTranscript)
|
||||
async def transcript_get(transcript_id: str):
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
async def transcript_get(
|
||||
transcript_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
return transcript
|
||||
|
||||
|
||||
@router.patch("/transcripts/{transcript_id}", response_model=GetTranscript)
|
||||
async def transcript_update(transcript_id: str, info: UpdateTranscript):
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
async def transcript_update(
|
||||
transcript_id: str,
|
||||
info: UpdateTranscript,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
values = {}
|
||||
@@ -236,11 +266,15 @@ async def transcript_update(transcript_id: str, info: UpdateTranscript):
|
||||
|
||||
|
||||
@router.delete("/transcripts/{transcript_id}", response_model=DeletionStatus)
|
||||
async def transcript_delete(transcript_id: str):
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
async def transcript_delete(
|
||||
transcript_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
await transcripts_controller.remove_by_id(transcript.id)
|
||||
await transcripts_controller.remove_by_id(transcript.id, user_id=user_id)
|
||||
return DeletionStatus(status="ok")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user