From e12f9afe7b4283a7f11e95d39845d406ef14eb43 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Wed, 16 Aug 2023 17:24:05 +0200 Subject: [PATCH] server: implement user authentication (none by default) --- server/env.example | 14 +++++ server/poetry.lock | 66 ++++++++++++++++++++++- server/pyproject.toml | 1 + server/reflector/app.py | 1 + server/reflector/auth/__init__.py | 13 +++++ server/reflector/auth/auth_fief.py | 25 +++++++++ server/reflector/auth/auth_none.py | 35 ++++++++++++ server/reflector/settings.py | 8 +++ server/reflector/views/transcripts.py | 68 ++++++++++++++++++------ server/test.db | Bin 20480 -> 0 bytes server/tests/conftest.py | 16 ++++++ server/tests/test_transcripts.py | 35 +++++++++++- server/tests/test_transcripts_rtc_ws.py | 2 +- 13 files changed, 263 insertions(+), 21 deletions(-) create mode 100644 server/reflector/auth/__init__.py create mode 100644 server/reflector/auth/auth_fief.py create mode 100644 server/reflector/auth/auth_none.py delete mode 100644 server/test.db create mode 100644 server/tests/conftest.py diff --git a/server/env.example b/server/env.example index 11e0927b..a8fd4128 100644 --- a/server/env.example +++ b/server/env.example @@ -11,6 +11,20 @@ #DATABASE_URL=postgresql://reflector:reflector@localhost:5432/reflector +## ======================================================= +## User authentication +## ======================================================= + +## No authentication +#AUTH_BACKEND=none + +## Using fief (fief.dev) +#AUTH_BACKEND=fief +#AUTH_FIEF_URL=https://your-fief-instance.... +#AUTH_FIEF_CLIENT_ID=xxx +#AUTH_FIEF_CLIENT_SECRET=xxx + + ## ======================================================= ## Transcription backend ## diff --git a/server/poetry.lock b/server/poetry.lock index 9ad03bcf..b3b122f4 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -930,6 +930,23 @@ mysql = ["aiomysql"] postgresql = ["asyncpg"] sqlite = ["aiosqlite"] +[[package]] +name = "deprecated" +version = "1.2.14" +description = "Python @deprecated decorator to deprecate old python classes, functions or methods." +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "Deprecated-1.2.14-py2.py3-none-any.whl", hash = "sha256:6fac8b097794a90302bdbb17b9b815e732d3c4720583ff1b198499d78470466c"}, + {file = "Deprecated-1.2.14.tar.gz", hash = "sha256:e5323eb936458dccc2582dc6f9c322c852a775a27065ff2b0c4970b9d53d01b3"}, +] + +[package.dependencies] +wrapt = ">=1.10,<2" + +[package.extras] +dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"] + [[package]] name = "dnspython" version = "2.4.1" @@ -1022,6 +1039,28 @@ tokenizers = "==0.13.*" conversion = ["transformers[torch] (>=4.23)"] dev = ["black (==23.*)", "flake8 (==6.*)", "isort (==5.*)", "pytest (==7.*)"] +[[package]] +name = "fief-client" +version = "0.17.0" +description = "Fief Client for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "fief_client-0.17.0-py3-none-any.whl", hash = "sha256:ecc8674ecaf58fc7d2926f5a0f49fabd3a1a03e278f030977a97ecb716b8884d"}, + {file = "fief_client-0.17.0.tar.gz", hash = "sha256:f1f9a10c760c29811a8cce2c1d58938090901772826dda973b67dde1bce3bafd"}, +] + +[package.dependencies] +fastapi = {version = "*", optional = true, markers = "extra == \"fastapi\""} +httpx = ">=0.21.3,<0.25.0" +jwcrypto = ">=1.4,<2.0.0" +makefun = {version = ">=1.14.0,<2.0.0", optional = true, markers = "extra == \"fastapi\""} + +[package.extras] +cli = ["halo"] +fastapi = ["fastapi", "makefun (>=1.14.0,<2.0.0)"] +flask = ["flask"] + [[package]] name = "filelock" version = "3.12.2" @@ -1530,6 +1569,20 @@ files = [ {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, ] +[[package]] +name = "jwcrypto" +version = "1.5.0" +description = "Implementation of JOSE Web standards" +optional = false +python-versions = ">= 3.6" +files = [ + {file = "jwcrypto-1.5.0.tar.gz", hash = "sha256:2c1dc51cf8e38ddf324795dfe9426dee9dd46caf47f535ccbc18781fba810b8d"}, +] + +[package.dependencies] +cryptography = ">=3.4" +deprecated = "*" + [[package]] name = "levenshtein" version = "0.21.1" @@ -1662,6 +1715,17 @@ win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""} [package.extras] dev = ["Sphinx (==5.3.0)", "colorama (==0.4.5)", "colorama (==0.4.6)", "freezegun (==1.1.0)", "freezegun (==1.2.2)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v0.990)", "pre-commit (==3.2.1)", "pytest (==6.1.2)", "pytest (==7.2.1)", "pytest-cov (==2.12.1)", "pytest-cov (==4.0.0)", "pytest-mypy-plugins (==1.10.1)", "pytest-mypy-plugins (==1.9.3)", "sphinx-autobuild (==2021.3.14)", "sphinx-rtd-theme (==1.2.0)", "tox (==3.27.1)", "tox (==4.4.6)"] +[[package]] +name = "makefun" +version = "1.15.1" +description = "Small library to dynamically create python functions." +optional = false +python-versions = "*" +files = [ + {file = "makefun-1.15.1-py2.py3-none-any.whl", hash = "sha256:a63cfc7b47a539c76d97bd4fdb833c7d0461e759fd1225f580cb4be6200294d4"}, + {file = "makefun-1.15.1.tar.gz", hash = "sha256:40b0f118b6ded0d8d78c78f1eb679b8b6b2462e3c1b3e05fb1b2da8cd46b48a5"}, +] + [[package]] name = "mpmath" version = "1.3.0" @@ -3234,4 +3298,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "ea523f9b74581a7867097a6249d416d8836f4daaf33fde65ea343e4d3502c71c" +content-hash = "d84edfea8ac7a849340af8eb5db47df9c13a7cc1c640062ebedb2a808be0de4e" diff --git a/server/pyproject.toml b/server/pyproject.toml index e3e75843..895be79d 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -25,6 +25,7 @@ httpx = "^0.24.1" fastapi-pagination = "^0.12.6" databases = {extras = ["aiosqlite", "asyncpg"], version = "^0.7.0"} sqlalchemy = "<1.5" +fief-client = {extras = ["fastapi"], version = "^0.17.0"} [tool.poetry.group.dev.dependencies] diff --git a/server/reflector/app.py b/server/reflector/app.py index 8383bf32..fa148240 100644 --- a/server/reflector/app.py +++ b/server/reflector/app.py @@ -3,6 +3,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi_pagination import add_pagination from fastapi.routing import APIRoute import reflector.db # noqa +import reflector.auth # noqa from reflector.views.rtc_offer import router as rtc_offer_router from reflector.views.transcripts import router as transcripts_router from reflector.events import subscribers_startup, subscribers_shutdown diff --git a/server/reflector/auth/__init__.py b/server/reflector/auth/__init__.py new file mode 100644 index 00000000..65e75d9b --- /dev/null +++ b/server/reflector/auth/__init__.py @@ -0,0 +1,13 @@ +from reflector.settings import settings +from reflector.logger import logger +import importlib + +logger.info(f"User authentication using {settings.AUTH_BACKEND}") +module_name = f"reflector.auth.auth_{settings.AUTH_BACKEND}" +auth_module = importlib.import_module(module_name) + +UserInfo = auth_module.UserInfo +AccessTokenInfo = auth_module.AccessTokenInfo +authenticated = auth_module.authenticated +current_user = auth_module.current_user +current_user_optional = auth_module.current_user_optional diff --git a/server/reflector/auth/auth_fief.py b/server/reflector/auth/auth_fief.py new file mode 100644 index 00000000..0b363fc0 --- /dev/null +++ b/server/reflector/auth/auth_fief.py @@ -0,0 +1,25 @@ +from fastapi.security import OAuth2AuthorizationCodeBearer +from fief_client import FiefAccessTokenInfo, FiefAsync, FiefUserInfo +from fief_client.integrations.fastapi import FiefAuth +from reflector.settings import settings + +fief = FiefAsync( + settings.AUTH_FIEF_URL, + settings.AUTH_FIEF_CLIENT_ID, + settings.AUTH_FIEF_CLIENT_SECRET, +) + +scheme = OAuth2AuthorizationCodeBearer( + f"{settings.AUTH_FIEF_URL}/authorize", + f"{settings.AUTH_FIEF_URL}/api/token", + scopes={"openid": "openid", "offline_access": "offline_access"}, + auto_error=False, +) + +auth = FiefAuth(fief, scheme) + +UserInfo = FiefUserInfo +AccessTokenInfo = FiefAccessTokenInfo +authenticated = auth.authenticated() +current_user = auth.current_user() +current_user_optional = auth.current_user(optional=True) diff --git a/server/reflector/auth/auth_none.py b/server/reflector/auth/auth_none.py new file mode 100644 index 00000000..3959c739 --- /dev/null +++ b/server/reflector/auth/auth_none.py @@ -0,0 +1,35 @@ +from pydantic import BaseModel +from typing import Annotated +from fastapi import Depends +from fastapi.security import OAuth2PasswordBearer + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + + +class UserInfo(BaseModel): + sub: str + + +class AccessTokenInfo(BaseModel): + pass + + +def authenticated(token: Annotated[str, Depends(oauth2_scheme)]): + def _authenticated(): + return None + + return _authenticated + + +def current_user(token: Annotated[str, Depends(oauth2_scheme)]): + def _current_user(): + return None + + return _current_user + + +def current_user_optional(token: Annotated[str, Depends(oauth2_scheme)]): + def _current_user_optional(): + return None + + return _current_user_optional diff --git a/server/reflector/settings.py b/server/reflector/settings.py index 0787b466..2add7448 100644 --- a/server/reflector/settings.py +++ b/server/reflector/settings.py @@ -79,5 +79,13 @@ class Settings(BaseSettings): # Sentry SENTRY_DSN: str | None = None + # User authentication (none, fief) + AUTH_BACKEND: str = "none" + + # User authentication using fief + AUTH_FIEF_URL: str | None = None + AUTH_FIEF_CLIENT_ID: str | None = None + AUTH_FIEF_CLIENT_SECRET: str | None = None + settings = Settings() diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 6f952938..778c47d7 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -4,6 +4,7 @@ from fastapi import ( Request, WebSocket, WebSocketDisconnect, + Depends, ) from fastapi.responses import FileResponse from starlette.concurrency import run_in_threadpool @@ -14,8 +15,9 @@ from fastapi_pagination import Page, paginate from reflector.logger import logger from reflector.db import database, transcripts from reflector.settings import settings +import reflector.auth as auth from .rtc_offer import rtc_offer_base, RtcOffer, PipelineEvent -from typing import Optional +from typing import Annotated, Optional from pathlib import Path from tempfile import NamedTemporaryFile import av @@ -60,6 +62,7 @@ class TranscriptEvent(BaseModel): class Transcript(BaseModel): id: str = Field(default_factory=generate_uuid4) + user_id: str | None = None name: str = Field(default_factory=generate_transcript_name) status: str = "idle" locked: bool = False @@ -127,20 +130,28 @@ class Transcript(BaseModel): class TranscriptController: - async def get_all(self) -> list[Transcript]: + async def get_all(self, user_id: str | None = None) -> list[Transcript]: query = transcripts.select() + if user_id is not None: + query = query.where(transcripts.c.user_id == user_id) + print(query) results = await database.fetch_all(query) + print(results) return results - async def get_by_id(self, transcript_id: str) -> Transcript | None: + async def get_by_id( + self, transcript_id: str, user_id: str | None = None + ) -> Transcript | None: query = transcripts.select().where(transcripts.c.id == transcript_id) result = await database.fetch_one(query) if not result: return None + if user_id is not None and result["user_id"] != user_id: + return None return Transcript(**result) - async def add(self, name: str): - transcript = Transcript(name=name) + async def add(self, name: str, user_id: str | None = None): + transcript = Transcript(name=name, user_id=user_id) query = transcripts.insert().values(**transcript.model_dump()) await database.execute(query) return transcript @@ -155,10 +166,14 @@ class TranscriptController: for key, value in values.items(): setattr(transcript, key, value) - async def remove_by_id(self, transcript_id: str) -> None: + async def remove_by_id( + self, transcript_id: str, user_id: str | None = None + ) -> None: transcript = await self.get_by_id(transcript_id) if not transcript: return + if user_id is not None and transcript.user_id != user_id: + return transcript.unlink() query = transcripts.delete().where(transcripts.c.id == transcript_id) await database.execute(query) @@ -199,13 +214,19 @@ class DeletionStatus(BaseModel): @router.get("/transcripts", response_model=Page[GetTranscript]) -async def transcripts_list(): - return paginate(await transcripts_controller.get_all()) +async def transcripts_list( + user: auth.UserInfo = Depends(auth.current_user), +): + return paginate(await transcripts_controller.get_all(user_id=user["sub"])) @router.post("/transcripts", response_model=GetTranscript) -async def transcripts_create(info: CreateTranscript): - return await transcripts_controller.add(info.name) +async def transcripts_create( + info: CreateTranscript, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +): + user_id = user["sub"] if user else None + return await transcripts_controller.add(info.name, user_id=user_id) # ============================================================== @@ -214,16 +235,25 @@ async def transcripts_create(info: CreateTranscript): @router.get("/transcripts/{transcript_id}", response_model=GetTranscript) -async def transcript_get(transcript_id: str): - transcript = await transcripts_controller.get_by_id(transcript_id) +async def transcript_get( + transcript_id: str, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +): + user_id = user["sub"] if user else None + transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id) if not transcript: raise HTTPException(status_code=404, detail="Transcript not found") return transcript @router.patch("/transcripts/{transcript_id}", response_model=GetTranscript) -async def transcript_update(transcript_id: str, info: UpdateTranscript): - transcript = await transcripts_controller.get_by_id(transcript_id) +async def transcript_update( + transcript_id: str, + info: UpdateTranscript, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +): + user_id = user["sub"] if user else None + transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id) if not transcript: raise HTTPException(status_code=404, detail="Transcript not found") values = {} @@ -236,11 +266,15 @@ async def transcript_update(transcript_id: str, info: UpdateTranscript): @router.delete("/transcripts/{transcript_id}", response_model=DeletionStatus) -async def transcript_delete(transcript_id: str): - transcript = await transcripts_controller.get_by_id(transcript_id) +async def transcript_delete( + transcript_id: str, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +): + user_id = user["sub"] if user else None + transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id) if not transcript: raise HTTPException(status_code=404, detail="Transcript not found") - await transcripts_controller.remove_by_id(transcript.id) + await transcripts_controller.remove_by_id(transcript.id, user_id=user_id) return DeletionStatus(status="ok") diff --git a/server/test.db b/server/test.db deleted file mode 100644 index 974a6a14e9005e55574c4cea2c0b7f5b04ae13d6..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 20480 zcmeI2OKerWq?jB#3YHy&qm6oM5#9Cjr+J)e2Gcy)H83(RwCT z5Qs&qvhEHEv8<5T01}%bv8hrEfNb>Jt%FTZiK6pQb~2h-!H zk0$Z(?y;{)^=lWOj&M?EI9Y_6T1|D; zOAn{gL%wAKnLs9x31kA9Kqin0WCEE$CXfka0-3B#?gp-&p!VYw(l7ZvS8X5Bt}9pY(pzyVCtz_k*tK{JHaf z=X32pw7=W_%-ZkP-dVe_`kU4Jt7lezwGvj^Cx3o&y!`LwA207L{R1TOEfdHDGJ#Kn z!2U*OeE#g}>Q?LY=@xOE7gj0cyt9Z)rjgZ{BcoWbP6$Pq*^86uaCh(OLAVzwrVL?? zh!_&SN#!ObBdn;_{PvyOcitR!#^;Wtu~J@$AW>nhMw}^&v|Jkm;mN~wIIrmO@Z!hk5_HXt6HLRJ!tY>YYzLM0EnQYQ12sUW5Dtyeqa^&_7( zmW*QUkQJmr+(?UzGZay(gTyg<;$%{u?=k0`(d6;?d}n<2NS9mVuwf2GY6r93Nol`_1!DwS`&REwU)P6B<87_lC43QQ~oI>r;Pm3M`p zbE8-4Dr1CTW%ORGMbBvBA_k8G=Mi*+Lt1--xG+Rp5lj){oaj}uiff`v#OpN>M1V^i zJH(^`5QPC*R|KRKp;k*ve3fjo3xsSEGtxpDDof=TYRZ&Gp{yr}D+z8!Cy^mVL`*2g zA>bm!Vm{(0%8;*RBHub&L#zv{nKuQpOaTfKpq^t0sEE!LMU=c?rShz#sB{^o+?0r) zuOYUA1;P|~TxfW3fY@OikidjTLlsnEnQk)*q>?3s@-mBWovFLaioi5<3IfT6fOV=n zERluwCYhztbCN*JahHuS<3d#KK38)WdmRX2khvTKU?-vQw6};wE-Yu>hV-AlGwT^| z3OXW`LGF|kPS<3uaF{US5Wy7OHBf#=Q;9&@2ZABFjhQV=13_0jgNH`I}YbiW_GVJF2Or_za&*M&wa)E7BFxJ3H4n9W{pr1iv2nZWDE(k8m zVfVk4Q8~;p)3TCn*JM2un}MM;LVQGR8 z5v3prbxOC{!7G)S%iZ6}(Cr z8WpK3?_`ZKbT6uuFeo?)lmRTXH6@^`8MahP=lQ%6DP^G>m?zzGxh5-ARKP-^U=6e% z4gJ_@NDR&jD9C6q^(EyQIZ?@Cf^pd`$4fQYXn5gh(8zlWGaphJS?>tKm|7U0)F@nb z%h|GIBnZ5F)zubjvV{W*$y*&fKnPwy)p@v#Hf&kGsZ#Abvg#=7)2{r%^?Wjf<s zq^~5yCQHwA$*{=|^W|jNWF`4h`PL0~j4vj`CR@YxWY}a0_(C#lveP@844W+Jo-Y*} ztliEe!zO#Q=c>1U`ey8OGHkNzI+YBYY_A5%u*p8EpA4I9mwKg-4VFdSWY}c)(@BO+ zwmI!&*knJmmJFM0TUL`{ljX@uGHkL7IhhQbY&Vum9~