diff --git a/server/env.example b/server/env.example index 11e0927b..a8fd4128 100644 --- a/server/env.example +++ b/server/env.example @@ -11,6 +11,20 @@ #DATABASE_URL=postgresql://reflector:reflector@localhost:5432/reflector +## ======================================================= +## User authentication +## ======================================================= + +## No authentication +#AUTH_BACKEND=none + +## Using fief (fief.dev) +#AUTH_BACKEND=fief +#AUTH_FIEF_URL=https://your-fief-instance.... +#AUTH_FIEF_CLIENT_ID=xxx +#AUTH_FIEF_CLIENT_SECRET=xxx + + ## ======================================================= ## Transcription backend ## diff --git a/server/poetry.lock b/server/poetry.lock index 9ad03bcf..b3b122f4 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -930,6 +930,23 @@ mysql = ["aiomysql"] postgresql = ["asyncpg"] sqlite = ["aiosqlite"] +[[package]] +name = "deprecated" +version = "1.2.14" +description = "Python @deprecated decorator to deprecate old python classes, functions or methods." 
+optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "Deprecated-1.2.14-py2.py3-none-any.whl", hash = "sha256:6fac8b097794a90302bdbb17b9b815e732d3c4720583ff1b198499d78470466c"}, + {file = "Deprecated-1.2.14.tar.gz", hash = "sha256:e5323eb936458dccc2582dc6f9c322c852a775a27065ff2b0c4970b9d53d01b3"}, +] + +[package.dependencies] +wrapt = ">=1.10,<2" + +[package.extras] +dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"] + [[package]] name = "dnspython" version = "2.4.1" @@ -1022,6 +1039,28 @@ tokenizers = "==0.13.*" conversion = ["transformers[torch] (>=4.23)"] dev = ["black (==23.*)", "flake8 (==6.*)", "isort (==5.*)", "pytest (==7.*)"] +[[package]] +name = "fief-client" +version = "0.17.0" +description = "Fief Client for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "fief_client-0.17.0-py3-none-any.whl", hash = "sha256:ecc8674ecaf58fc7d2926f5a0f49fabd3a1a03e278f030977a97ecb716b8884d"}, + {file = "fief_client-0.17.0.tar.gz", hash = "sha256:f1f9a10c760c29811a8cce2c1d58938090901772826dda973b67dde1bce3bafd"}, +] + +[package.dependencies] +fastapi = {version = "*", optional = true, markers = "extra == \"fastapi\""} +httpx = ">=0.21.3,<0.25.0" +jwcrypto = ">=1.4,<2.0.0" +makefun = {version = ">=1.14.0,<2.0.0", optional = true, markers = "extra == \"fastapi\""} + +[package.extras] +cli = ["halo"] +fastapi = ["fastapi", "makefun (>=1.14.0,<2.0.0)"] +flask = ["flask"] + [[package]] name = "filelock" version = "3.12.2" @@ -1530,6 +1569,20 @@ files = [ {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, ] +[[package]] +name = "jwcrypto" +version = "1.5.0" +description = "Implementation of JOSE Web standards" +optional = false +python-versions = ">= 3.6" +files = [ + {file = "jwcrypto-1.5.0.tar.gz", hash = "sha256:2c1dc51cf8e38ddf324795dfe9426dee9dd46caf47f535ccbc18781fba810b8d"}, +] + +[package.dependencies] 
+cryptography = ">=3.4" +deprecated = "*" + [[package]] name = "levenshtein" version = "0.21.1" @@ -1662,6 +1715,17 @@ win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""} [package.extras] dev = ["Sphinx (==5.3.0)", "colorama (==0.4.5)", "colorama (==0.4.6)", "freezegun (==1.1.0)", "freezegun (==1.2.2)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v0.990)", "pre-commit (==3.2.1)", "pytest (==6.1.2)", "pytest (==7.2.1)", "pytest-cov (==2.12.1)", "pytest-cov (==4.0.0)", "pytest-mypy-plugins (==1.10.1)", "pytest-mypy-plugins (==1.9.3)", "sphinx-autobuild (==2021.3.14)", "sphinx-rtd-theme (==1.2.0)", "tox (==3.27.1)", "tox (==4.4.6)"] +[[package]] +name = "makefun" +version = "1.15.1" +description = "Small library to dynamically create python functions." +optional = false +python-versions = "*" +files = [ + {file = "makefun-1.15.1-py2.py3-none-any.whl", hash = "sha256:a63cfc7b47a539c76d97bd4fdb833c7d0461e759fd1225f580cb4be6200294d4"}, + {file = "makefun-1.15.1.tar.gz", hash = "sha256:40b0f118b6ded0d8d78c78f1eb679b8b6b2462e3c1b3e05fb1b2da8cd46b48a5"}, +] + [[package]] name = "mpmath" version = "1.3.0" @@ -3234,4 +3298,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "ea523f9b74581a7867097a6249d416d8836f4daaf33fde65ea343e4d3502c71c" +content-hash = "d84edfea8ac7a849340af8eb5db47df9c13a7cc1c640062ebedb2a808be0de4e" diff --git a/server/pyproject.toml b/server/pyproject.toml index e3e75843..895be79d 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -25,6 +25,7 @@ httpx = "^0.24.1" fastapi-pagination = "^0.12.6" databases = {extras = ["aiosqlite", "asyncpg"], version = "^0.7.0"} sqlalchemy = "<1.5" +fief-client = {extras = ["fastapi"], version = "^0.17.0"} [tool.poetry.group.dev.dependencies] diff --git a/server/reflector/app.py b/server/reflector/app.py index 8383bf32..fa148240 100644 --- a/server/reflector/app.py +++ b/server/reflector/app.py @@ -3,6 +3,7 @@ from 
"""Authentication backend selection.

Imports the module ``reflector.auth.auth_<AUTH_BACKEND>`` and re-exports the
names it provides, so the rest of the application can depend on
``reflector.auth`` without knowing which backend is active.
"""
from reflector.settings import settings
from reflector.logger import logger
import importlib

logger.info(f"User authentication using {settings.AUTH_BACKEND}")
module_name = f"reflector.auth.auth_{settings.AUTH_BACKEND}"
try:
    auth_module = importlib.import_module(module_name)
except ModuleNotFoundError as err:
    # Only rewrite the error when the backend module itself is missing.  A
    # missing dependency of a valid backend must surface unchanged so the
    # real cause is not hidden.
    if err.name != module_name:
        raise
    raise ModuleNotFoundError(
        f"Unknown AUTH_BACKEND {settings.AUTH_BACKEND!r}: "
        f"no module named {module_name!r}",
        name=module_name,
    ) from err

# Names re-exported from the selected backend module.
UserInfo = auth_module.UserInfo
AccessTokenInfo = auth_module.AccessTokenInfo
authenticated = auth_module.authenticated
current_user = auth_module.current_user
current_user_optional = auth_module.current_user_optional
"""Authentication backend used when ``AUTH_BACKEND=none``.

This backend has no user store, so every dependency resolves to ``None``.
The callables are meant to be passed to ``fastapi.Depends``.
"""
from pydantic import BaseModel
from typing import Annotated
from fastapi import Depends
from fastapi.security import OAuth2PasswordBearer

# Strict scheme: with ``auto_error`` left at its default, a request without
# a bearer token is rejected before the dependency body runs.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# Lenient scheme for optional authentication: a missing token is passed
# through as ``None`` instead of aborting the request.
oauth2_scheme_optional = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)


class UserInfo(BaseModel):
    """Minimal user description exposed by this backend."""

    # Subject identifier of the user.
    sub: str


class AccessTokenInfo(BaseModel):
    """Placeholder access-token description; carries no fields."""

    pass


def authenticated(token: Annotated[str, Depends(oauth2_scheme)]) -> None:
    """Dependency for routes that require a bearer token.

    Returns ``None`` directly.  The previous implementation returned an
    inner function, so FastAPI injected a callable instead of the intended
    value.
    """
    return None


def current_user(token: Annotated[str, Depends(oauth2_scheme)]) -> None:
    """Dependency for routes that require a user.

    A bearer token must be present, but no identity can be derived from it
    under this backend, so the resolved user is always ``None``.
    """
    return None


def current_user_optional(
    token: Annotated[str | None, Depends(oauth2_scheme_optional)],
) -> None:
    """Dependency for routes where the user is optional.

    Uses the lenient scheme so anonymous requests are accepted; the
    resolved user is always ``None``.
    """
    return None
class TranscriptController:
    # NOTE: only the methods fully visible in this hunk are reproduced here;
    # the class has further methods (e.g. update, remove_by_id) outside it.

    async def get_all(self, user_id: str | None = None) -> list[Transcript]:
        """Fetch stored transcripts, optionally restricted to one owner.

        :param user_id: when not ``None``, only rows whose ``user_id``
            column equals this value are selected.
        :return: the rows returned by ``database.fetch_all``.
        """
        query = transcripts.select()
        if user_id is not None:
            query = query.where(transcripts.c.user_id == user_id)
        # Debug ``print`` calls removed: they wrote every query and result
        # set to stdout on each request.
        return await database.fetch_all(query)

    async def get_by_id(
        self, transcript_id: str, user_id: str | None = None
    ) -> Transcript | None:
        """Fetch one transcript by id.

        :param transcript_id: primary key to look up.
        :param user_id: when not ``None``, the row is only returned if its
            ``user_id`` matches.
        :return: the transcript, or ``None`` when no row exists or the
            ownership check fails.
        """
        query = transcripts.select().where(transcripts.c.id == transcript_id)
        result = await database.fetch_one(query)
        if not result:
            return None
        # A transcript owned by someone else is reported like a missing one.
        if user_id is not None and result["user_id"] != user_id:
            return None
        return Transcript(**result)

    async def add(self, name: str, user_id: str | None = None):
        """Create and persist a transcript.

        :param name: name of the new transcript.
        :param user_id: owner to record, or ``None`` for no owner.
        :return: the newly created ``Transcript``.
        """
        transcript = Transcript(name=name, user_id=user_id)
        query = transcripts.insert().values(**transcript.model_dump())
        await database.execute(query)
        return transcript
@router.get("/transcripts", response_model=Page[GetTranscript])
async def transcripts_list(
    user: auth.UserInfo = Depends(auth.current_user),
):
    """List the transcripts owned by the requesting user.

    Raises ``HTTPException`` 401 when the authentication dependency resolved
    to no user, instead of failing on ``user["sub"]``.
    """
    if not user:
        raise HTTPException(status_code=401, detail="Not authenticated")
    return paginate(await transcripts_controller.get_all(user_id=user["sub"]))


@router.post("/transcripts", response_model=GetTranscript)
async def transcripts_create(
    info: CreateTranscript,
    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
    """Create a transcript, recording the requesting user as owner if any."""
    user_id = user["sub"] if user else None
    return await transcripts_controller.add(info.name, user_id=user_id)


@router.get("/transcripts/{transcript_id}", response_model=GetTranscript)
async def transcript_get(
    transcript_id: str,
    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
    """Return one transcript.

    Raises ``HTTPException`` 404 when the controller returns no transcript
    for this id and user.
    """
    user_id = user["sub"] if user else None
    transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
    if not transcript:
        raise HTTPException(status_code=404, detail="Transcript not found")
    return transcript
import pytest


@pytest.fixture(scope="function", autouse=True)
def setup_database():
    """Point the application at a temporary SQLite file and create its tables.

    The fixture is synchronous because it awaits nothing.  The previous
    version was an ``async`` generator decorated with
    ``@pytest.mark.asyncio``; marks have no effect on fixtures, and a plain
    ``@pytest.fixture`` does not drive an async generator unless
    pytest-asyncio runs in auto mode.
    """
    # Imported lazily so the settings object can be changed before
    # ``reflector.db`` is imported.
    from reflector.settings import settings
    from tempfile import NamedTemporaryFile

    with NamedTemporaryFile() as f:
        settings.DATABASE_URL = f"sqlite:///{f.name}"
        # NOTE(review): if reflector.db builds its engine at import time,
        # only the first test's URL takes effect — confirm against
        # reflector/db.
        from reflector.db import engine, metadata

        metadata.create_all(bind=engine)

        # The temporary file stays open, and therefore present, for the
        # duration of the test.
        yield
@pytest.mark.asyncio
async def test_transcripts_list_anonymous():
    """Listing transcripts without credentials is rejected with 401."""
    from reflector.app import app

    async with AsyncClient(app=app, base_url="http://test/v1") as ac:
        response = await ac.get("/transcripts")
        assert response.status_code == 401


@pytest.fixture
def authenticated_client():
    """Make the app treat every request as coming from a fixed fake user.

    Overrides both user dependencies for the duration of one test and
    removes the overrides afterwards.  The fixture is synchronous because it
    awaits nothing; ``@pytest.mark.asyncio`` has no effect on fixtures.
    """
    from reflector.app import app
    from reflector.auth import current_user, current_user_optional

    def fake_user():
        return {"sub": "randomuserid"}

    overrides = {current_user: fake_user, current_user_optional: fake_user}
    app.dependency_overrides.update(overrides)
    try:
        yield
    finally:
        # ``pop`` with a default tolerates a test that already removed or
        # replaced an override, where ``del`` would raise KeyError.
        for dependency in overrides:
            app.dependency_overrides.pop(dependency, None)
"testxx1"}) assert response.status_code == 200 @@ -64,6 +93,8 @@ async def test_transcripts_list(): @pytest.mark.asyncio async def test_transcript_delete(): + from reflector.app import app + async with AsyncClient(app=app, base_url="http://test/v1") as ac: response = await ac.post("/transcripts", json={"name": "testdel1"}) assert response.status_code == 200 diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py index f38728c2..0e764cca 100644 --- a/server/tests/test_transcripts_rtc_ws.py +++ b/server/tests/test_transcripts_rtc_ws.py @@ -8,7 +8,6 @@ import json from unittest.mock import patch from httpx import AsyncClient -from reflector.app import app from uvicorn import Config, Server import threading import asyncio @@ -76,6 +75,7 @@ async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm) # to be able to connect with aiortc from reflector.settings import settings + from reflector.app import app settings.DATA_DIR = Path(tmpdir)