server: implement user authentication (none by default)

This commit is contained in:
2023-08-16 17:24:05 +02:00
parent d470696a49
commit e12f9afe7b
13 changed files with 263 additions and 21 deletions

View File

@@ -11,6 +11,20 @@
#DATABASE_URL=postgresql://reflector:reflector@localhost:5432/reflector #DATABASE_URL=postgresql://reflector:reflector@localhost:5432/reflector
## =======================================================
## User authentication
## =======================================================
## No authentication
#AUTH_BACKEND=none
## Using fief (fief.dev)
#AUTH_BACKEND=fief
#AUTH_FIEF_URL=https://your-fief-instance....
#AUTH_FIEF_CLIENT_ID=xxx
#AUTH_FIEF_CLIENT_SECRET=xxx
## ======================================================= ## =======================================================
## Transcription backend ## Transcription backend
## ##

66
server/poetry.lock generated
View File

@@ -930,6 +930,23 @@ mysql = ["aiomysql"]
postgresql = ["asyncpg"] postgresql = ["asyncpg"]
sqlite = ["aiosqlite"] sqlite = ["aiosqlite"]
[[package]]
name = "deprecated"
version = "1.2.14"
description = "Python @deprecated decorator to deprecate old python classes, functions or methods."
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
files = [
{file = "Deprecated-1.2.14-py2.py3-none-any.whl", hash = "sha256:6fac8b097794a90302bdbb17b9b815e732d3c4720583ff1b198499d78470466c"},
{file = "Deprecated-1.2.14.tar.gz", hash = "sha256:e5323eb936458dccc2582dc6f9c322c852a775a27065ff2b0c4970b9d53d01b3"},
]
[package.dependencies]
wrapt = ">=1.10,<2"
[package.extras]
dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"]
[[package]] [[package]]
name = "dnspython" name = "dnspython"
version = "2.4.1" version = "2.4.1"
@@ -1022,6 +1039,28 @@ tokenizers = "==0.13.*"
conversion = ["transformers[torch] (>=4.23)"] conversion = ["transformers[torch] (>=4.23)"]
dev = ["black (==23.*)", "flake8 (==6.*)", "isort (==5.*)", "pytest (==7.*)"] dev = ["black (==23.*)", "flake8 (==6.*)", "isort (==5.*)", "pytest (==7.*)"]
[[package]]
name = "fief-client"
version = "0.17.0"
description = "Fief Client for Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "fief_client-0.17.0-py3-none-any.whl", hash = "sha256:ecc8674ecaf58fc7d2926f5a0f49fabd3a1a03e278f030977a97ecb716b8884d"},
{file = "fief_client-0.17.0.tar.gz", hash = "sha256:f1f9a10c760c29811a8cce2c1d58938090901772826dda973b67dde1bce3bafd"},
]
[package.dependencies]
fastapi = {version = "*", optional = true, markers = "extra == \"fastapi\""}
httpx = ">=0.21.3,<0.25.0"
jwcrypto = ">=1.4,<2.0.0"
makefun = {version = ">=1.14.0,<2.0.0", optional = true, markers = "extra == \"fastapi\""}
[package.extras]
cli = ["halo"]
fastapi = ["fastapi", "makefun (>=1.14.0,<2.0.0)"]
flask = ["flask"]
[[package]] [[package]]
name = "filelock" name = "filelock"
version = "3.12.2" version = "3.12.2"
@@ -1530,6 +1569,20 @@ files = [
{file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"},
] ]
[[package]]
name = "jwcrypto"
version = "1.5.0"
description = "Implementation of JOSE Web standards"
optional = false
python-versions = ">= 3.6"
files = [
{file = "jwcrypto-1.5.0.tar.gz", hash = "sha256:2c1dc51cf8e38ddf324795dfe9426dee9dd46caf47f535ccbc18781fba810b8d"},
]
[package.dependencies]
cryptography = ">=3.4"
deprecated = "*"
[[package]] [[package]]
name = "levenshtein" name = "levenshtein"
version = "0.21.1" version = "0.21.1"
@@ -1662,6 +1715,17 @@ win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""}
[package.extras] [package.extras]
dev = ["Sphinx (==5.3.0)", "colorama (==0.4.5)", "colorama (==0.4.6)", "freezegun (==1.1.0)", "freezegun (==1.2.2)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v0.990)", "pre-commit (==3.2.1)", "pytest (==6.1.2)", "pytest (==7.2.1)", "pytest-cov (==2.12.1)", "pytest-cov (==4.0.0)", "pytest-mypy-plugins (==1.10.1)", "pytest-mypy-plugins (==1.9.3)", "sphinx-autobuild (==2021.3.14)", "sphinx-rtd-theme (==1.2.0)", "tox (==3.27.1)", "tox (==4.4.6)"] dev = ["Sphinx (==5.3.0)", "colorama (==0.4.5)", "colorama (==0.4.6)", "freezegun (==1.1.0)", "freezegun (==1.2.2)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v0.990)", "pre-commit (==3.2.1)", "pytest (==6.1.2)", "pytest (==7.2.1)", "pytest-cov (==2.12.1)", "pytest-cov (==4.0.0)", "pytest-mypy-plugins (==1.10.1)", "pytest-mypy-plugins (==1.9.3)", "sphinx-autobuild (==2021.3.14)", "sphinx-rtd-theme (==1.2.0)", "tox (==3.27.1)", "tox (==4.4.6)"]
[[package]]
name = "makefun"
version = "1.15.1"
description = "Small library to dynamically create python functions."
optional = false
python-versions = "*"
files = [
{file = "makefun-1.15.1-py2.py3-none-any.whl", hash = "sha256:a63cfc7b47a539c76d97bd4fdb833c7d0461e759fd1225f580cb4be6200294d4"},
{file = "makefun-1.15.1.tar.gz", hash = "sha256:40b0f118b6ded0d8d78c78f1eb679b8b6b2462e3c1b3e05fb1b2da8cd46b48a5"},
]
[[package]] [[package]]
name = "mpmath" name = "mpmath"
version = "1.3.0" version = "1.3.0"
@@ -3234,4 +3298,4 @@ multidict = ">=4.0"
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.11" python-versions = "^3.11"
content-hash = "ea523f9b74581a7867097a6249d416d8836f4daaf33fde65ea343e4d3502c71c" content-hash = "d84edfea8ac7a849340af8eb5db47df9c13a7cc1c640062ebedb2a808be0de4e"

View File

@@ -25,6 +25,7 @@ httpx = "^0.24.1"
fastapi-pagination = "^0.12.6" fastapi-pagination = "^0.12.6"
databases = {extras = ["aiosqlite", "asyncpg"], version = "^0.7.0"} databases = {extras = ["aiosqlite", "asyncpg"], version = "^0.7.0"}
sqlalchemy = "<1.5" sqlalchemy = "<1.5"
fief-client = {extras = ["fastapi"], version = "^0.17.0"}
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]

View File

@@ -3,6 +3,7 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi_pagination import add_pagination from fastapi_pagination import add_pagination
from fastapi.routing import APIRoute from fastapi.routing import APIRoute
import reflector.db # noqa import reflector.db # noqa
import reflector.auth # noqa
from reflector.views.rtc_offer import router as rtc_offer_router from reflector.views.rtc_offer import router as rtc_offer_router
from reflector.views.transcripts import router as transcripts_router from reflector.views.transcripts import router as transcripts_router
from reflector.events import subscribers_startup, subscribers_shutdown from reflector.events import subscribers_startup, subscribers_shutdown

View File

@@ -0,0 +1,13 @@
from reflector.settings import settings
from reflector.logger import logger
import importlib
logger.info(f"User authentication using {settings.AUTH_BACKEND}")
module_name = f"reflector.auth.auth_{settings.AUTH_BACKEND}"
auth_module = importlib.import_module(module_name)
UserInfo = auth_module.UserInfo
AccessTokenInfo = auth_module.AccessTokenInfo
authenticated = auth_module.authenticated
current_user = auth_module.current_user
current_user_optional = auth_module.current_user_optional

View File

@@ -0,0 +1,25 @@
from fastapi.security import OAuth2AuthorizationCodeBearer
from fief_client import FiefAccessTokenInfo, FiefAsync, FiefUserInfo
from fief_client.integrations.fastapi import FiefAuth
from reflector.settings import settings
fief = FiefAsync(
settings.AUTH_FIEF_URL,
settings.AUTH_FIEF_CLIENT_ID,
settings.AUTH_FIEF_CLIENT_SECRET,
)
scheme = OAuth2AuthorizationCodeBearer(
f"{settings.AUTH_FIEF_URL}/authorize",
f"{settings.AUTH_FIEF_URL}/api/token",
scopes={"openid": "openid", "offline_access": "offline_access"},
auto_error=False,
)
auth = FiefAuth(fief, scheme)
UserInfo = FiefUserInfo
AccessTokenInfo = FiefAccessTokenInfo
authenticated = auth.authenticated()
current_user = auth.current_user()
current_user_optional = auth.current_user(optional=True)

View File

@@ -0,0 +1,35 @@
from pydantic import BaseModel
from typing import Annotated
from fastapi import Depends
from fastapi.security import OAuth2PasswordBearer
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
class UserInfo(BaseModel):
sub: str
class AccessTokenInfo(BaseModel):
pass
def authenticated(token: Annotated[str, Depends(oauth2_scheme)]):
def _authenticated():
return None
return _authenticated
def current_user(token: Annotated[str, Depends(oauth2_scheme)]):
def _current_user():
return None
return _current_user
def current_user_optional(token: Annotated[str, Depends(oauth2_scheme)]):
def _current_user_optional():
return None
return _current_user_optional

View File

@@ -79,5 +79,13 @@ class Settings(BaseSettings):
# Sentry # Sentry
SENTRY_DSN: str | None = None SENTRY_DSN: str | None = None
# User authentication (none, fief)
AUTH_BACKEND: str = "none"
# User authentication using fief
AUTH_FIEF_URL: str | None = None
AUTH_FIEF_CLIENT_ID: str | None = None
AUTH_FIEF_CLIENT_SECRET: str | None = None
settings = Settings() settings = Settings()

View File

@@ -4,6 +4,7 @@ from fastapi import (
Request, Request,
WebSocket, WebSocket,
WebSocketDisconnect, WebSocketDisconnect,
Depends,
) )
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from starlette.concurrency import run_in_threadpool from starlette.concurrency import run_in_threadpool
@@ -14,8 +15,9 @@ from fastapi_pagination import Page, paginate
from reflector.logger import logger from reflector.logger import logger
from reflector.db import database, transcripts from reflector.db import database, transcripts
from reflector.settings import settings from reflector.settings import settings
import reflector.auth as auth
from .rtc_offer import rtc_offer_base, RtcOffer, PipelineEvent from .rtc_offer import rtc_offer_base, RtcOffer, PipelineEvent
from typing import Optional from typing import Annotated, Optional
from pathlib import Path from pathlib import Path
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
import av import av
@@ -60,6 +62,7 @@ class TranscriptEvent(BaseModel):
class Transcript(BaseModel): class Transcript(BaseModel):
id: str = Field(default_factory=generate_uuid4) id: str = Field(default_factory=generate_uuid4)
user_id: str | None = None
name: str = Field(default_factory=generate_transcript_name) name: str = Field(default_factory=generate_transcript_name)
status: str = "idle" status: str = "idle"
locked: bool = False locked: bool = False
@@ -127,20 +130,28 @@ class Transcript(BaseModel):
class TranscriptController: class TranscriptController:
async def get_all(self) -> list[Transcript]: async def get_all(self, user_id: str | None = None) -> list[Transcript]:
query = transcripts.select() query = transcripts.select()
if user_id is not None:
query = query.where(transcripts.c.user_id == user_id)
print(query)
results = await database.fetch_all(query) results = await database.fetch_all(query)
print(results)
return results return results
async def get_by_id(self, transcript_id: str) -> Transcript | None: async def get_by_id(
self, transcript_id: str, user_id: str | None = None
) -> Transcript | None:
query = transcripts.select().where(transcripts.c.id == transcript_id) query = transcripts.select().where(transcripts.c.id == transcript_id)
result = await database.fetch_one(query) result = await database.fetch_one(query)
if not result: if not result:
return None return None
if user_id is not None and result["user_id"] != user_id:
return None
return Transcript(**result) return Transcript(**result)
async def add(self, name: str): async def add(self, name: str, user_id: str | None = None):
transcript = Transcript(name=name) transcript = Transcript(name=name, user_id=user_id)
query = transcripts.insert().values(**transcript.model_dump()) query = transcripts.insert().values(**transcript.model_dump())
await database.execute(query) await database.execute(query)
return transcript return transcript
@@ -155,10 +166,14 @@ class TranscriptController:
for key, value in values.items(): for key, value in values.items():
setattr(transcript, key, value) setattr(transcript, key, value)
async def remove_by_id(self, transcript_id: str) -> None: async def remove_by_id(
self, transcript_id: str, user_id: str | None = None
) -> None:
transcript = await self.get_by_id(transcript_id) transcript = await self.get_by_id(transcript_id)
if not transcript: if not transcript:
return return
if user_id is not None and transcript.user_id != user_id:
return
transcript.unlink() transcript.unlink()
query = transcripts.delete().where(transcripts.c.id == transcript_id) query = transcripts.delete().where(transcripts.c.id == transcript_id)
await database.execute(query) await database.execute(query)
@@ -199,13 +214,19 @@ class DeletionStatus(BaseModel):
@router.get("/transcripts", response_model=Page[GetTranscript]) @router.get("/transcripts", response_model=Page[GetTranscript])
async def transcripts_list(): async def transcripts_list(
return paginate(await transcripts_controller.get_all()) user: auth.UserInfo = Depends(auth.current_user),
):
return paginate(await transcripts_controller.get_all(user_id=user["sub"]))
@router.post("/transcripts", response_model=GetTranscript) @router.post("/transcripts", response_model=GetTranscript)
async def transcripts_create(info: CreateTranscript): async def transcripts_create(
return await transcripts_controller.add(info.name) info: CreateTranscript,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
return await transcripts_controller.add(info.name, user_id=user_id)
# ============================================================== # ==============================================================
@@ -214,16 +235,25 @@ async def transcripts_create(info: CreateTranscript):
@router.get("/transcripts/{transcript_id}", response_model=GetTranscript) @router.get("/transcripts/{transcript_id}", response_model=GetTranscript)
async def transcript_get(transcript_id: str): async def transcript_get(
transcript = await transcripts_controller.get_by_id(transcript_id) transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript: if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found") raise HTTPException(status_code=404, detail="Transcript not found")
return transcript return transcript
@router.patch("/transcripts/{transcript_id}", response_model=GetTranscript) @router.patch("/transcripts/{transcript_id}", response_model=GetTranscript)
async def transcript_update(transcript_id: str, info: UpdateTranscript): async def transcript_update(
transcript = await transcripts_controller.get_by_id(transcript_id) transcript_id: str,
info: UpdateTranscript,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript: if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found") raise HTTPException(status_code=404, detail="Transcript not found")
values = {} values = {}
@@ -236,11 +266,15 @@ async def transcript_update(transcript_id: str, info: UpdateTranscript):
@router.delete("/transcripts/{transcript_id}", response_model=DeletionStatus) @router.delete("/transcripts/{transcript_id}", response_model=DeletionStatus)
async def transcript_delete(transcript_id: str): async def transcript_delete(
transcript = await transcripts_controller.get_by_id(transcript_id) transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript: if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found") raise HTTPException(status_code=404, detail="Transcript not found")
await transcripts_controller.remove_by_id(transcript.id) await transcripts_controller.remove_by_id(transcript.id, user_id=user_id)
return DeletionStatus(status="ok") return DeletionStatus(status="ok")

Binary file not shown.

16
server/tests/conftest.py Normal file
View File

@@ -0,0 +1,16 @@
import pytest
@pytest.fixture(scope="function", autouse=True)
@pytest.mark.asyncio
async def setup_database():
from reflector.settings import settings
from tempfile import NamedTemporaryFile
with NamedTemporaryFile() as f:
settings.DATABASE_URL = f"sqlite:///{f.name}"
from reflector.db import engine, metadata
metadata.create_all(bind=engine)
yield

View File

@@ -1,10 +1,11 @@
import pytest import pytest
from httpx import AsyncClient from httpx import AsyncClient
from reflector.app import app
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transcript_create(): async def test_transcript_create():
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac: async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"}) response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200 assert response.status_code == 200
@@ -21,6 +22,8 @@ async def test_transcript_create():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transcript_get_update_name(): async def test_transcript_get_update_name():
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac: async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"}) response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200 assert response.status_code == 200
@@ -42,9 +45,35 @@ async def test_transcript_get_update_name():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transcripts_list(): async def test_transcripts_list_anonymous():
# XXX this test is a bit fragile, as it depends on the storage which # XXX this test is a bit fragile, as it depends on the storage which
# is shared between tests # is shared between tests
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.get("/transcripts")
assert response.status_code == 401
@pytest.fixture
@pytest.mark.asyncio
async def authenticated_client():
from reflector.app import app
from reflector.auth import current_user, current_user_optional
app.dependency_overrides[current_user] = lambda: {"sub": "randomuserid"}
app.dependency_overrides[current_user_optional] = lambda: {"sub": "randomuserid"}
yield
del app.dependency_overrides[current_user]
del app.dependency_overrides[current_user_optional]
@pytest.mark.asyncio
async def test_transcripts_list_authenticated(authenticated_client):
# XXX this test is a bit fragile, as it depends on the storage which
# is shared between tests
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac: async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "testxx1"}) response = await ac.post("/transcripts", json={"name": "testxx1"})
assert response.status_code == 200 assert response.status_code == 200
@@ -64,6 +93,8 @@ async def test_transcripts_list():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transcript_delete(): async def test_transcript_delete():
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac: async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "testdel1"}) response = await ac.post("/transcripts", json={"name": "testdel1"})
assert response.status_code == 200 assert response.status_code == 200

View File

@@ -8,7 +8,6 @@ import json
from unittest.mock import patch from unittest.mock import patch
from httpx import AsyncClient from httpx import AsyncClient
from reflector.app import app
from uvicorn import Config, Server from uvicorn import Config, Server
import threading import threading
import asyncio import asyncio
@@ -76,6 +75,7 @@ async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm)
# to be able to connect with aiortc # to be able to connect with aiortc
from reflector.settings import settings from reflector.settings import settings
from reflector.app import app
settings.DATA_DIR = Path(tmpdir) settings.DATA_DIR = Path(tmpdir)