mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
server: implement user authentication (none by default)
This commit is contained in:
@@ -11,6 +11,20 @@
|
||||
#DATABASE_URL=postgresql://reflector:reflector@localhost:5432/reflector
|
||||
|
||||
|
||||
## =======================================================
|
||||
## User authentication
|
||||
## =======================================================
|
||||
|
||||
## No authentication
|
||||
#AUTH_BACKEND=none
|
||||
|
||||
## Using fief (fief.dev)
|
||||
#AUTH_BACKEND=fief
|
||||
#AUTH_FIEF_URL=https://your-fief-instance....
|
||||
#AUTH_FIEF_CLIENT_ID=xxx
|
||||
#AUTH_FIEF_CLIENT_SECRET=xxx
|
||||
|
||||
|
||||
## =======================================================
|
||||
## Transcription backend
|
||||
##
|
||||
|
||||
66
server/poetry.lock
generated
66
server/poetry.lock
generated
@@ -930,6 +930,23 @@ mysql = ["aiomysql"]
|
||||
postgresql = ["asyncpg"]
|
||||
sqlite = ["aiosqlite"]
|
||||
|
||||
[[package]]
|
||||
name = "deprecated"
|
||||
version = "1.2.14"
|
||||
description = "Python @deprecated decorator to deprecate old python classes, functions or methods."
|
||||
optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
|
||||
files = [
|
||||
{file = "Deprecated-1.2.14-py2.py3-none-any.whl", hash = "sha256:6fac8b097794a90302bdbb17b9b815e732d3c4720583ff1b198499d78470466c"},
|
||||
{file = "Deprecated-1.2.14.tar.gz", hash = "sha256:e5323eb936458dccc2582dc6f9c322c852a775a27065ff2b0c4970b9d53d01b3"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
wrapt = ">=1.10,<2"
|
||||
|
||||
[package.extras]
|
||||
dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"]
|
||||
|
||||
[[package]]
|
||||
name = "dnspython"
|
||||
version = "2.4.1"
|
||||
@@ -1022,6 +1039,28 @@ tokenizers = "==0.13.*"
|
||||
conversion = ["transformers[torch] (>=4.23)"]
|
||||
dev = ["black (==23.*)", "flake8 (==6.*)", "isort (==5.*)", "pytest (==7.*)"]
|
||||
|
||||
[[package]]
|
||||
name = "fief-client"
|
||||
version = "0.17.0"
|
||||
description = "Fief Client for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "fief_client-0.17.0-py3-none-any.whl", hash = "sha256:ecc8674ecaf58fc7d2926f5a0f49fabd3a1a03e278f030977a97ecb716b8884d"},
|
||||
{file = "fief_client-0.17.0.tar.gz", hash = "sha256:f1f9a10c760c29811a8cce2c1d58938090901772826dda973b67dde1bce3bafd"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
fastapi = {version = "*", optional = true, markers = "extra == \"fastapi\""}
|
||||
httpx = ">=0.21.3,<0.25.0"
|
||||
jwcrypto = ">=1.4,<2.0.0"
|
||||
makefun = {version = ">=1.14.0,<2.0.0", optional = true, markers = "extra == \"fastapi\""}
|
||||
|
||||
[package.extras]
|
||||
cli = ["halo"]
|
||||
fastapi = ["fastapi", "makefun (>=1.14.0,<2.0.0)"]
|
||||
flask = ["flask"]
|
||||
|
||||
[[package]]
|
||||
name = "filelock"
|
||||
version = "3.12.2"
|
||||
@@ -1530,6 +1569,20 @@ files = [
|
||||
{file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jwcrypto"
|
||||
version = "1.5.0"
|
||||
description = "Implementation of JOSE Web standards"
|
||||
optional = false
|
||||
python-versions = ">= 3.6"
|
||||
files = [
|
||||
{file = "jwcrypto-1.5.0.tar.gz", hash = "sha256:2c1dc51cf8e38ddf324795dfe9426dee9dd46caf47f535ccbc18781fba810b8d"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
cryptography = ">=3.4"
|
||||
deprecated = "*"
|
||||
|
||||
[[package]]
|
||||
name = "levenshtein"
|
||||
version = "0.21.1"
|
||||
@@ -1662,6 +1715,17 @@ win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""}
|
||||
[package.extras]
|
||||
dev = ["Sphinx (==5.3.0)", "colorama (==0.4.5)", "colorama (==0.4.6)", "freezegun (==1.1.0)", "freezegun (==1.2.2)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v0.990)", "pre-commit (==3.2.1)", "pytest (==6.1.2)", "pytest (==7.2.1)", "pytest-cov (==2.12.1)", "pytest-cov (==4.0.0)", "pytest-mypy-plugins (==1.10.1)", "pytest-mypy-plugins (==1.9.3)", "sphinx-autobuild (==2021.3.14)", "sphinx-rtd-theme (==1.2.0)", "tox (==3.27.1)", "tox (==4.4.6)"]
|
||||
|
||||
[[package]]
|
||||
name = "makefun"
|
||||
version = "1.15.1"
|
||||
description = "Small library to dynamically create python functions."
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "makefun-1.15.1-py2.py3-none-any.whl", hash = "sha256:a63cfc7b47a539c76d97bd4fdb833c7d0461e759fd1225f580cb4be6200294d4"},
|
||||
{file = "makefun-1.15.1.tar.gz", hash = "sha256:40b0f118b6ded0d8d78c78f1eb679b8b6b2462e3c1b3e05fb1b2da8cd46b48a5"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mpmath"
|
||||
version = "1.3.0"
|
||||
@@ -3234,4 +3298,4 @@ multidict = ">=4.0"
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.11"
|
||||
content-hash = "ea523f9b74581a7867097a6249d416d8836f4daaf33fde65ea343e4d3502c71c"
|
||||
content-hash = "d84edfea8ac7a849340af8eb5db47df9c13a7cc1c640062ebedb2a808be0de4e"
|
||||
|
||||
@@ -25,6 +25,7 @@ httpx = "^0.24.1"
|
||||
fastapi-pagination = "^0.12.6"
|
||||
databases = {extras = ["aiosqlite", "asyncpg"], version = "^0.7.0"}
|
||||
sqlalchemy = "<1.5"
|
||||
fief-client = {extras = ["fastapi"], version = "^0.17.0"}
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
|
||||
@@ -3,6 +3,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi_pagination import add_pagination
|
||||
from fastapi.routing import APIRoute
|
||||
import reflector.db # noqa
|
||||
import reflector.auth # noqa
|
||||
from reflector.views.rtc_offer import router as rtc_offer_router
|
||||
from reflector.views.transcripts import router as transcripts_router
|
||||
from reflector.events import subscribers_startup, subscribers_shutdown
|
||||
|
||||
13
server/reflector/auth/__init__.py
Normal file
13
server/reflector/auth/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from reflector.settings import settings
|
||||
from reflector.logger import logger
|
||||
import importlib
|
||||
|
||||
logger.info(f"User authentication using {settings.AUTH_BACKEND}")
|
||||
module_name = f"reflector.auth.auth_{settings.AUTH_BACKEND}"
|
||||
auth_module = importlib.import_module(module_name)
|
||||
|
||||
UserInfo = auth_module.UserInfo
|
||||
AccessTokenInfo = auth_module.AccessTokenInfo
|
||||
authenticated = auth_module.authenticated
|
||||
current_user = auth_module.current_user
|
||||
current_user_optional = auth_module.current_user_optional
|
||||
25
server/reflector/auth/auth_fief.py
Normal file
25
server/reflector/auth/auth_fief.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from fastapi.security import OAuth2AuthorizationCodeBearer
|
||||
from fief_client import FiefAccessTokenInfo, FiefAsync, FiefUserInfo
|
||||
from fief_client.integrations.fastapi import FiefAuth
|
||||
from reflector.settings import settings
|
||||
|
||||
fief = FiefAsync(
|
||||
settings.AUTH_FIEF_URL,
|
||||
settings.AUTH_FIEF_CLIENT_ID,
|
||||
settings.AUTH_FIEF_CLIENT_SECRET,
|
||||
)
|
||||
|
||||
scheme = OAuth2AuthorizationCodeBearer(
|
||||
f"{settings.AUTH_FIEF_URL}/authorize",
|
||||
f"{settings.AUTH_FIEF_URL}/api/token",
|
||||
scopes={"openid": "openid", "offline_access": "offline_access"},
|
||||
auto_error=False,
|
||||
)
|
||||
|
||||
auth = FiefAuth(fief, scheme)
|
||||
|
||||
UserInfo = FiefUserInfo
|
||||
AccessTokenInfo = FiefAccessTokenInfo
|
||||
authenticated = auth.authenticated()
|
||||
current_user = auth.current_user()
|
||||
current_user_optional = auth.current_user(optional=True)
|
||||
35
server/reflector/auth/auth_none.py
Normal file
35
server/reflector/auth/auth_none.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Annotated
|
||||
from fastapi import Depends
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
|
||||
class UserInfo(BaseModel):
|
||||
sub: str
|
||||
|
||||
|
||||
class AccessTokenInfo(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
def authenticated(token: Annotated[str, Depends(oauth2_scheme)]):
|
||||
def _authenticated():
|
||||
return None
|
||||
|
||||
return _authenticated
|
||||
|
||||
|
||||
def current_user(token: Annotated[str, Depends(oauth2_scheme)]):
|
||||
def _current_user():
|
||||
return None
|
||||
|
||||
return _current_user
|
||||
|
||||
|
||||
def current_user_optional(token: Annotated[str, Depends(oauth2_scheme)]):
|
||||
def _current_user_optional():
|
||||
return None
|
||||
|
||||
return _current_user_optional
|
||||
@@ -79,5 +79,13 @@ class Settings(BaseSettings):
|
||||
# Sentry
|
||||
SENTRY_DSN: str | None = None
|
||||
|
||||
# User authentication (none, fief)
|
||||
AUTH_BACKEND: str = "none"
|
||||
|
||||
# User authentication using fief
|
||||
AUTH_FIEF_URL: str | None = None
|
||||
AUTH_FIEF_CLIENT_ID: str | None = None
|
||||
AUTH_FIEF_CLIENT_SECRET: str | None = None
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
@@ -4,6 +4,7 @@ from fastapi import (
|
||||
Request,
|
||||
WebSocket,
|
||||
WebSocketDisconnect,
|
||||
Depends,
|
||||
)
|
||||
from fastapi.responses import FileResponse
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
@@ -14,8 +15,9 @@ from fastapi_pagination import Page, paginate
|
||||
from reflector.logger import logger
|
||||
from reflector.db import database, transcripts
|
||||
from reflector.settings import settings
|
||||
import reflector.auth as auth
|
||||
from .rtc_offer import rtc_offer_base, RtcOffer, PipelineEvent
|
||||
from typing import Optional
|
||||
from typing import Annotated, Optional
|
||||
from pathlib import Path
|
||||
from tempfile import NamedTemporaryFile
|
||||
import av
|
||||
@@ -60,6 +62,7 @@ class TranscriptEvent(BaseModel):
|
||||
|
||||
class Transcript(BaseModel):
|
||||
id: str = Field(default_factory=generate_uuid4)
|
||||
user_id: str | None = None
|
||||
name: str = Field(default_factory=generate_transcript_name)
|
||||
status: str = "idle"
|
||||
locked: bool = False
|
||||
@@ -127,20 +130,28 @@ class Transcript(BaseModel):
|
||||
|
||||
|
||||
class TranscriptController:
|
||||
async def get_all(self) -> list[Transcript]:
|
||||
async def get_all(self, user_id: str | None = None) -> list[Transcript]:
|
||||
query = transcripts.select()
|
||||
if user_id is not None:
|
||||
query = query.where(transcripts.c.user_id == user_id)
|
||||
print(query)
|
||||
results = await database.fetch_all(query)
|
||||
print(results)
|
||||
return results
|
||||
|
||||
async def get_by_id(self, transcript_id: str) -> Transcript | None:
|
||||
async def get_by_id(
|
||||
self, transcript_id: str, user_id: str | None = None
|
||||
) -> Transcript | None:
|
||||
query = transcripts.select().where(transcripts.c.id == transcript_id)
|
||||
result = await database.fetch_one(query)
|
||||
if not result:
|
||||
return None
|
||||
if user_id is not None and result["user_id"] != user_id:
|
||||
return None
|
||||
return Transcript(**result)
|
||||
|
||||
async def add(self, name: str):
|
||||
transcript = Transcript(name=name)
|
||||
async def add(self, name: str, user_id: str | None = None):
|
||||
transcript = Transcript(name=name, user_id=user_id)
|
||||
query = transcripts.insert().values(**transcript.model_dump())
|
||||
await database.execute(query)
|
||||
return transcript
|
||||
@@ -155,10 +166,14 @@ class TranscriptController:
|
||||
for key, value in values.items():
|
||||
setattr(transcript, key, value)
|
||||
|
||||
async def remove_by_id(self, transcript_id: str) -> None:
|
||||
async def remove_by_id(
|
||||
self, transcript_id: str, user_id: str | None = None
|
||||
) -> None:
|
||||
transcript = await self.get_by_id(transcript_id)
|
||||
if not transcript:
|
||||
return
|
||||
if user_id is not None and transcript.user_id != user_id:
|
||||
return
|
||||
transcript.unlink()
|
||||
query = transcripts.delete().where(transcripts.c.id == transcript_id)
|
||||
await database.execute(query)
|
||||
@@ -199,13 +214,19 @@ class DeletionStatus(BaseModel):
|
||||
|
||||
|
||||
@router.get("/transcripts", response_model=Page[GetTranscript])
|
||||
async def transcripts_list():
|
||||
return paginate(await transcripts_controller.get_all())
|
||||
async def transcripts_list(
|
||||
user: auth.UserInfo = Depends(auth.current_user),
|
||||
):
|
||||
return paginate(await transcripts_controller.get_all(user_id=user["sub"]))
|
||||
|
||||
|
||||
@router.post("/transcripts", response_model=GetTranscript)
|
||||
async def transcripts_create(info: CreateTranscript):
|
||||
return await transcripts_controller.add(info.name)
|
||||
async def transcripts_create(
|
||||
info: CreateTranscript,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
return await transcripts_controller.add(info.name, user_id=user_id)
|
||||
|
||||
|
||||
# ==============================================================
|
||||
@@ -214,16 +235,25 @@ async def transcripts_create(info: CreateTranscript):
|
||||
|
||||
|
||||
@router.get("/transcripts/{transcript_id}", response_model=GetTranscript)
|
||||
async def transcript_get(transcript_id: str):
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
async def transcript_get(
|
||||
transcript_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
return transcript
|
||||
|
||||
|
||||
@router.patch("/transcripts/{transcript_id}", response_model=GetTranscript)
|
||||
async def transcript_update(transcript_id: str, info: UpdateTranscript):
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
async def transcript_update(
|
||||
transcript_id: str,
|
||||
info: UpdateTranscript,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
values = {}
|
||||
@@ -236,11 +266,15 @@ async def transcript_update(transcript_id: str, info: UpdateTranscript):
|
||||
|
||||
|
||||
@router.delete("/transcripts/{transcript_id}", response_model=DeletionStatus)
|
||||
async def transcript_delete(transcript_id: str):
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||
async def transcript_delete(
|
||||
transcript_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
await transcripts_controller.remove_by_id(transcript.id)
|
||||
await transcripts_controller.remove_by_id(transcript.id, user_id=user_id)
|
||||
return DeletionStatus(status="ok")
|
||||
|
||||
|
||||
|
||||
BIN
server/test.db
BIN
server/test.db
Binary file not shown.
16
server/tests/conftest.py
Normal file
16
server/tests/conftest.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
@pytest.mark.asyncio
|
||||
async def setup_database():
|
||||
from reflector.settings import settings
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
with NamedTemporaryFile() as f:
|
||||
settings.DATABASE_URL = f"sqlite:///{f.name}"
|
||||
from reflector.db import engine, metadata
|
||||
|
||||
metadata.create_all(bind=engine)
|
||||
|
||||
yield
|
||||
@@ -1,10 +1,11 @@
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from reflector.app import app
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_create():
|
||||
from reflector.app import app
|
||||
|
||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
||||
response = await ac.post("/transcripts", json={"name": "test"})
|
||||
assert response.status_code == 200
|
||||
@@ -21,6 +22,8 @@ async def test_transcript_create():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_get_update_name():
|
||||
from reflector.app import app
|
||||
|
||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
||||
response = await ac.post("/transcripts", json={"name": "test"})
|
||||
assert response.status_code == 200
|
||||
@@ -42,9 +45,35 @@ async def test_transcript_get_update_name():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcripts_list():
|
||||
async def test_transcripts_list_anonymous():
|
||||
# XXX this test is a bit fragile, as it depends on the storage which
|
||||
# is shared between tests
|
||||
from reflector.app import app
|
||||
|
||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
||||
response = await ac.get("/transcripts")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@pytest.mark.asyncio
|
||||
async def authenticated_client():
|
||||
from reflector.app import app
|
||||
from reflector.auth import current_user, current_user_optional
|
||||
|
||||
app.dependency_overrides[current_user] = lambda: {"sub": "randomuserid"}
|
||||
app.dependency_overrides[current_user_optional] = lambda: {"sub": "randomuserid"}
|
||||
yield
|
||||
del app.dependency_overrides[current_user]
|
||||
del app.dependency_overrides[current_user_optional]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcripts_list_authenticated(authenticated_client):
|
||||
# XXX this test is a bit fragile, as it depends on the storage which
|
||||
# is shared between tests
|
||||
from reflector.app import app
|
||||
|
||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
||||
response = await ac.post("/transcripts", json={"name": "testxx1"})
|
||||
assert response.status_code == 200
|
||||
@@ -64,6 +93,8 @@ async def test_transcripts_list():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_delete():
|
||||
from reflector.app import app
|
||||
|
||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
||||
response = await ac.post("/transcripts", json={"name": "testdel1"})
|
||||
assert response.status_code == 200
|
||||
|
||||
@@ -8,7 +8,6 @@ import json
|
||||
from unittest.mock import patch
|
||||
from httpx import AsyncClient
|
||||
|
||||
from reflector.app import app
|
||||
from uvicorn import Config, Server
|
||||
import threading
|
||||
import asyncio
|
||||
@@ -76,6 +75,7 @@ async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm)
|
||||
# to be able to connect with aiortc
|
||||
|
||||
from reflector.settings import settings
|
||||
from reflector.app import app
|
||||
|
||||
settings.DATA_DIR = Path(tmpdir)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user