server: implement user authentication (none by default)

This commit is contained in:
2023-08-16 17:24:05 +02:00
parent d470696a49
commit e12f9afe7b
13 changed files with 263 additions and 21 deletions

View File

@@ -11,6 +11,20 @@
#DATABASE_URL=postgresql://reflector:reflector@localhost:5432/reflector
## =======================================================
## User authentication
## =======================================================
## No authentication
#AUTH_BACKEND=none
## Using fief (fief.dev)
#AUTH_BACKEND=fief
#AUTH_FIEF_URL=https://your-fief-instance....
#AUTH_FIEF_CLIENT_ID=xxx
#AUTH_FIEF_CLIENT_SECRET=xxx
## =======================================================
## Transcription backend
##

66
server/poetry.lock generated
View File

@@ -930,6 +930,23 @@ mysql = ["aiomysql"]
postgresql = ["asyncpg"]
sqlite = ["aiosqlite"]
[[package]]
name = "deprecated"
version = "1.2.14"
description = "Python @deprecated decorator to deprecate old python classes, functions or methods."
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
files = [
{file = "Deprecated-1.2.14-py2.py3-none-any.whl", hash = "sha256:6fac8b097794a90302bdbb17b9b815e732d3c4720583ff1b198499d78470466c"},
{file = "Deprecated-1.2.14.tar.gz", hash = "sha256:e5323eb936458dccc2582dc6f9c322c852a775a27065ff2b0c4970b9d53d01b3"},
]
[package.dependencies]
wrapt = ">=1.10,<2"
[package.extras]
dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"]
[[package]]
name = "dnspython"
version = "2.4.1"
@@ -1022,6 +1039,28 @@ tokenizers = "==0.13.*"
conversion = ["transformers[torch] (>=4.23)"]
dev = ["black (==23.*)", "flake8 (==6.*)", "isort (==5.*)", "pytest (==7.*)"]
[[package]]
name = "fief-client"
version = "0.17.0"
description = "Fief Client for Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "fief_client-0.17.0-py3-none-any.whl", hash = "sha256:ecc8674ecaf58fc7d2926f5a0f49fabd3a1a03e278f030977a97ecb716b8884d"},
{file = "fief_client-0.17.0.tar.gz", hash = "sha256:f1f9a10c760c29811a8cce2c1d58938090901772826dda973b67dde1bce3bafd"},
]
[package.dependencies]
fastapi = {version = "*", optional = true, markers = "extra == \"fastapi\""}
httpx = ">=0.21.3,<0.25.0"
jwcrypto = ">=1.4,<2.0.0"
makefun = {version = ">=1.14.0,<2.0.0", optional = true, markers = "extra == \"fastapi\""}
[package.extras]
cli = ["halo"]
fastapi = ["fastapi", "makefun (>=1.14.0,<2.0.0)"]
flask = ["flask"]
[[package]]
name = "filelock"
version = "3.12.2"
@@ -1530,6 +1569,20 @@ files = [
{file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"},
]
[[package]]
name = "jwcrypto"
version = "1.5.0"
description = "Implementation of JOSE Web standards"
optional = false
python-versions = ">= 3.6"
files = [
{file = "jwcrypto-1.5.0.tar.gz", hash = "sha256:2c1dc51cf8e38ddf324795dfe9426dee9dd46caf47f535ccbc18781fba810b8d"},
]
[package.dependencies]
cryptography = ">=3.4"
deprecated = "*"
[[package]]
name = "levenshtein"
version = "0.21.1"
@@ -1662,6 +1715,17 @@ win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""}
[package.extras]
dev = ["Sphinx (==5.3.0)", "colorama (==0.4.5)", "colorama (==0.4.6)", "freezegun (==1.1.0)", "freezegun (==1.2.2)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v0.990)", "pre-commit (==3.2.1)", "pytest (==6.1.2)", "pytest (==7.2.1)", "pytest-cov (==2.12.1)", "pytest-cov (==4.0.0)", "pytest-mypy-plugins (==1.10.1)", "pytest-mypy-plugins (==1.9.3)", "sphinx-autobuild (==2021.3.14)", "sphinx-rtd-theme (==1.2.0)", "tox (==3.27.1)", "tox (==4.4.6)"]
[[package]]
name = "makefun"
version = "1.15.1"
description = "Small library to dynamically create python functions."
optional = false
python-versions = "*"
files = [
{file = "makefun-1.15.1-py2.py3-none-any.whl", hash = "sha256:a63cfc7b47a539c76d97bd4fdb833c7d0461e759fd1225f580cb4be6200294d4"},
{file = "makefun-1.15.1.tar.gz", hash = "sha256:40b0f118b6ded0d8d78c78f1eb679b8b6b2462e3c1b3e05fb1b2da8cd46b48a5"},
]
[[package]]
name = "mpmath"
version = "1.3.0"
@@ -3234,4 +3298,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
content-hash = "ea523f9b74581a7867097a6249d416d8836f4daaf33fde65ea343e4d3502c71c"
content-hash = "d84edfea8ac7a849340af8eb5db47df9c13a7cc1c640062ebedb2a808be0de4e"

View File

@@ -25,6 +25,7 @@ httpx = "^0.24.1"
fastapi-pagination = "^0.12.6"
databases = {extras = ["aiosqlite", "asyncpg"], version = "^0.7.0"}
sqlalchemy = "<1.5"
fief-client = {extras = ["fastapi"], version = "^0.17.0"}
[tool.poetry.group.dev.dependencies]

View File

@@ -3,6 +3,7 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi_pagination import add_pagination
from fastapi.routing import APIRoute
import reflector.db # noqa
import reflector.auth # noqa
from reflector.views.rtc_offer import router as rtc_offer_router
from reflector.views.transcripts import router as transcripts_router
from reflector.events import subscribers_startup, subscribers_shutdown

View File

@@ -0,0 +1,13 @@
from reflector.settings import settings
from reflector.logger import logger
import importlib
logger.info(f"User authentication using {settings.AUTH_BACKEND}")
module_name = f"reflector.auth.auth_{settings.AUTH_BACKEND}"
auth_module = importlib.import_module(module_name)
UserInfo = auth_module.UserInfo
AccessTokenInfo = auth_module.AccessTokenInfo
authenticated = auth_module.authenticated
current_user = auth_module.current_user
current_user_optional = auth_module.current_user_optional

View File

@@ -0,0 +1,25 @@
from fastapi.security import OAuth2AuthorizationCodeBearer
from fief_client import FiefAccessTokenInfo, FiefAsync, FiefUserInfo
from fief_client.integrations.fastapi import FiefAuth
from reflector.settings import settings
fief = FiefAsync(
settings.AUTH_FIEF_URL,
settings.AUTH_FIEF_CLIENT_ID,
settings.AUTH_FIEF_CLIENT_SECRET,
)
scheme = OAuth2AuthorizationCodeBearer(
f"{settings.AUTH_FIEF_URL}/authorize",
f"{settings.AUTH_FIEF_URL}/api/token",
scopes={"openid": "openid", "offline_access": "offline_access"},
auto_error=False,
)
auth = FiefAuth(fief, scheme)
UserInfo = FiefUserInfo
AccessTokenInfo = FiefAccessTokenInfo
authenticated = auth.authenticated()
current_user = auth.current_user()
current_user_optional = auth.current_user(optional=True)

View File

@@ -0,0 +1,35 @@
from pydantic import BaseModel
from typing import Annotated
from fastapi import Depends
from fastapi.security import OAuth2PasswordBearer
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
class UserInfo(BaseModel):
sub: str
class AccessTokenInfo(BaseModel):
pass
def authenticated(token: Annotated[str, Depends(oauth2_scheme)]):
def _authenticated():
return None
return _authenticated
def current_user(token: Annotated[str, Depends(oauth2_scheme)]):
def _current_user():
return None
return _current_user
def current_user_optional(token: Annotated[str, Depends(oauth2_scheme)]):
def _current_user_optional():
return None
return _current_user_optional

View File

@@ -79,5 +79,13 @@ class Settings(BaseSettings):
# Sentry
SENTRY_DSN: str | None = None
# User authentication (none, fief)
AUTH_BACKEND: str = "none"
# User authentication using fief
AUTH_FIEF_URL: str | None = None
AUTH_FIEF_CLIENT_ID: str | None = None
AUTH_FIEF_CLIENT_SECRET: str | None = None
settings = Settings()

View File

@@ -4,6 +4,7 @@ from fastapi import (
Request,
WebSocket,
WebSocketDisconnect,
Depends,
)
from fastapi.responses import FileResponse
from starlette.concurrency import run_in_threadpool
@@ -14,8 +15,9 @@ from fastapi_pagination import Page, paginate
from reflector.logger import logger
from reflector.db import database, transcripts
from reflector.settings import settings
import reflector.auth as auth
from .rtc_offer import rtc_offer_base, RtcOffer, PipelineEvent
from typing import Optional
from typing import Annotated, Optional
from pathlib import Path
from tempfile import NamedTemporaryFile
import av
@@ -60,6 +62,7 @@ class TranscriptEvent(BaseModel):
class Transcript(BaseModel):
id: str = Field(default_factory=generate_uuid4)
user_id: str | None = None
name: str = Field(default_factory=generate_transcript_name)
status: str = "idle"
locked: bool = False
@@ -127,20 +130,28 @@ class Transcript(BaseModel):
class TranscriptController:
async def get_all(self) -> list[Transcript]:
async def get_all(self, user_id: str | None = None) -> list[Transcript]:
query = transcripts.select()
if user_id is not None:
query = query.where(transcripts.c.user_id == user_id)
print(query)
results = await database.fetch_all(query)
print(results)
return results
async def get_by_id(self, transcript_id: str) -> Transcript | None:
async def get_by_id(
self, transcript_id: str, user_id: str | None = None
) -> Transcript | None:
query = transcripts.select().where(transcripts.c.id == transcript_id)
result = await database.fetch_one(query)
if not result:
return None
if user_id is not None and result["user_id"] != user_id:
return None
return Transcript(**result)
async def add(self, name: str):
transcript = Transcript(name=name)
async def add(self, name: str, user_id: str | None = None):
transcript = Transcript(name=name, user_id=user_id)
query = transcripts.insert().values(**transcript.model_dump())
await database.execute(query)
return transcript
@@ -155,10 +166,14 @@ class TranscriptController:
for key, value in values.items():
setattr(transcript, key, value)
async def remove_by_id(self, transcript_id: str) -> None:
async def remove_by_id(
self, transcript_id: str, user_id: str | None = None
) -> None:
transcript = await self.get_by_id(transcript_id)
if not transcript:
return
if user_id is not None and transcript.user_id != user_id:
return
transcript.unlink()
query = transcripts.delete().where(transcripts.c.id == transcript_id)
await database.execute(query)
@@ -199,13 +214,19 @@ class DeletionStatus(BaseModel):
@router.get("/transcripts", response_model=Page[GetTranscript])
async def transcripts_list():
return paginate(await transcripts_controller.get_all())
async def transcripts_list(
user: auth.UserInfo = Depends(auth.current_user),
):
return paginate(await transcripts_controller.get_all(user_id=user["sub"]))
@router.post("/transcripts", response_model=GetTranscript)
async def transcripts_create(info: CreateTranscript):
return await transcripts_controller.add(info.name)
async def transcripts_create(
info: CreateTranscript,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
return await transcripts_controller.add(info.name, user_id=user_id)
# ==============================================================
@@ -214,16 +235,25 @@ async def transcripts_create(info: CreateTranscript):
@router.get("/transcripts/{transcript_id}", response_model=GetTranscript)
async def transcript_get(transcript_id: str):
transcript = await transcripts_controller.get_by_id(transcript_id)
async def transcript_get(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
return transcript
@router.patch("/transcripts/{transcript_id}", response_model=GetTranscript)
async def transcript_update(transcript_id: str, info: UpdateTranscript):
transcript = await transcripts_controller.get_by_id(transcript_id)
async def transcript_update(
transcript_id: str,
info: UpdateTranscript,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
values = {}
@@ -236,11 +266,15 @@ async def transcript_update(transcript_id: str, info: UpdateTranscript):
@router.delete("/transcripts/{transcript_id}", response_model=DeletionStatus)
async def transcript_delete(transcript_id: str):
transcript = await transcripts_controller.get_by_id(transcript_id)
async def transcript_delete(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
await transcripts_controller.remove_by_id(transcript.id)
await transcripts_controller.remove_by_id(transcript.id, user_id=user_id)
return DeletionStatus(status="ok")

Binary file not shown.

16
server/tests/conftest.py Normal file
View File

@@ -0,0 +1,16 @@
import pytest
@pytest.fixture(scope="function", autouse=True)
@pytest.mark.asyncio
async def setup_database():
from reflector.settings import settings
from tempfile import NamedTemporaryFile
with NamedTemporaryFile() as f:
settings.DATABASE_URL = f"sqlite:///{f.name}"
from reflector.db import engine, metadata
metadata.create_all(bind=engine)
yield

View File

@@ -1,10 +1,11 @@
import pytest
from httpx import AsyncClient
from reflector.app import app
@pytest.mark.asyncio
async def test_transcript_create():
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
@@ -21,6 +22,8 @@ async def test_transcript_create():
@pytest.mark.asyncio
async def test_transcript_get_update_name():
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
@@ -42,9 +45,35 @@ async def test_transcript_get_update_name():
@pytest.mark.asyncio
async def test_transcripts_list():
async def test_transcripts_list_anonymous():
# XXX this test is a bit fragile, as it depends on the storage which
# is shared between tests
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.get("/transcripts")
assert response.status_code == 401
@pytest.fixture
@pytest.mark.asyncio
async def authenticated_client():
from reflector.app import app
from reflector.auth import current_user, current_user_optional
app.dependency_overrides[current_user] = lambda: {"sub": "randomuserid"}
app.dependency_overrides[current_user_optional] = lambda: {"sub": "randomuserid"}
yield
del app.dependency_overrides[current_user]
del app.dependency_overrides[current_user_optional]
@pytest.mark.asyncio
async def test_transcripts_list_authenticated(authenticated_client):
# XXX this test is a bit fragile, as it depends on the storage which
# is shared between tests
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "testxx1"})
assert response.status_code == 200
@@ -64,6 +93,8 @@ async def test_transcripts_list():
@pytest.mark.asyncio
async def test_transcript_delete():
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "testdel1"})
assert response.status_code == 200

View File

@@ -8,7 +8,6 @@ import json
from unittest.mock import patch
from httpx import AsyncClient
from reflector.app import app
from uvicorn import Config, Server
import threading
import asyncio
@@ -76,6 +75,7 @@ async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm)
# to be able to connect with aiortc
from reflector.settings import settings
from reflector.app import app
settings.DATA_DIR = Path(tmpdir)