mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-03-21 22:56:47 +00:00
feat: Add Single User authentication to Selfhosted (#870)
* Single user/password for selfhosted * fix revision id latest migration
This commit is contained in:
committed by
GitHub
parent
2ba0d965e8
commit
c8db37362b
@@ -73,10 +73,10 @@ TRANSLATE_URL=https://monadical-sas--reflector-translator-web.modal.run
|
||||
## Setup: ./scripts/setup-standalone.sh
|
||||
## Mac: Ollama runs natively (Metal GPU). Containers reach it via host.docker.internal.
|
||||
## Linux: docker compose --profile ollama-gpu up -d (or ollama-cpu for no GPU)
|
||||
LLM_URL=http://host.docker.internal:11434/v1
|
||||
LLM_URL=http://host.docker.internal:11435/v1
|
||||
LLM_MODEL=qwen2.5:14b
|
||||
LLM_API_KEY=not-needed
|
||||
## Linux with containerized Ollama: LLM_URL=http://ollama:11434/v1
|
||||
## Linux with containerized Ollama: LLM_URL=http://ollama:11435/v1
|
||||
|
||||
## --- Option B: Remote/cloud LLM ---
|
||||
#LLM_API_KEY=sk-your-openai-api-key
|
||||
|
||||
@@ -26,6 +26,9 @@ SECRET_KEY=changeme-generate-a-secure-random-string
|
||||
AUTH_BACKEND=none
|
||||
# AUTH_BACKEND=jwt
|
||||
# AUTH_JWT_AUDIENCE=
|
||||
# AUTH_BACKEND=password
|
||||
# ADMIN_EMAIL=admin@localhost
|
||||
# ADMIN_PASSWORD_HASH=pbkdf2:sha256:100000$<salt>$<hash>
|
||||
|
||||
# =======================================================
|
||||
# Specialized Models (Transcription, Diarization, Translation)
|
||||
@@ -64,7 +67,7 @@ TRANSLATE_URL=http://transcription:8000
|
||||
# LLM_MODEL=gpt-4o-mini
|
||||
|
||||
# --- Option B: Local Ollama (auto-set by --ollama-gpu/--ollama-cpu) ---
|
||||
# LLM_URL=http://ollama:11434/v1
|
||||
# LLM_URL=http://ollama:11435/v1
|
||||
# LLM_API_KEY=not-needed
|
||||
# LLM_MODEL=llama3.1
|
||||
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
"""add password_hash to user table
|
||||
|
||||
Revision ID: e1f093f7f124
|
||||
Revises: 623af934249a
|
||||
Create Date: 2026-02-19 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision: str = "e1f093f7f124"
|
||||
down_revision: Union[str, None] = "623af934249a"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("user", sa.Column("password_hash", sa.String(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "password_hash")
|
||||
@@ -8,6 +8,7 @@ from prometheus_fastapi_instrumentator import Instrumentator
|
||||
|
||||
import reflector.auth # noqa
|
||||
import reflector.db # noqa
|
||||
from reflector.auth import router as auth_router
|
||||
from reflector.events import subscribers_shutdown, subscribers_startup
|
||||
from reflector.logger import logger
|
||||
from reflector.metrics import metrics_init
|
||||
@@ -105,6 +106,8 @@ app.include_router(user_ws_router, prefix="/v1")
|
||||
app.include_router(zulip_router, prefix="/v1")
|
||||
app.include_router(whereby_router, prefix="/v1")
|
||||
app.include_router(daily_router, prefix="/v1/daily")
|
||||
if auth_router:
|
||||
app.include_router(auth_router, prefix="/v1")
|
||||
add_pagination(app)
|
||||
|
||||
# prepare celery
|
||||
|
||||
@@ -14,3 +14,6 @@ current_user = auth_module.current_user
|
||||
current_user_optional = auth_module.current_user_optional
|
||||
parse_ws_bearer_token = auth_module.parse_ws_bearer_token
|
||||
current_user_ws_optional = auth_module.current_user_ws_optional
|
||||
|
||||
# Optional router (e.g. for /auth/login in password backend)
|
||||
router = getattr(auth_module, "router", None)
|
||||
|
||||
198
server/reflector/auth/auth_password.py
Normal file
198
server/reflector/auth/auth_password.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""Password-based authentication backend for selfhosted deployments.
|
||||
|
||||
Issues HS256 JWTs signed with settings.SECRET_KEY. Provides a POST /auth/login
|
||||
endpoint for email/password authentication.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, Annotated, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
from pydantic import BaseModel
|
||||
|
||||
from reflector.auth.password_utils import verify_password
|
||||
from reflector.db.user_api_keys import user_api_keys_controller
|
||||
from reflector.db.users import user_controller
|
||||
from reflector.logger import logger
|
||||
from reflector.settings import settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import WebSocket
|
||||
|
||||
# --- FastAPI security schemes (same pattern as auth_jwt.py) ---
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/v1/auth/login", auto_error=False)
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
|
||||
# --- JWT configuration ---
|
||||
JWT_ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 # 24 hours
|
||||
|
||||
# --- Rate limiting (in-memory) ---
|
||||
_login_attempts: dict[str, list[float]] = defaultdict(list)
|
||||
RATE_LIMIT_WINDOW = 300 # 5 minutes
|
||||
RATE_LIMIT_MAX = 10 # max attempts per window
|
||||
|
||||
|
||||
def _check_rate_limit(key: str) -> bool:
|
||||
"""Return True if request is allowed, False if rate-limited."""
|
||||
now = time.monotonic()
|
||||
attempts = _login_attempts[key]
|
||||
_login_attempts[key] = [t for t in attempts if now - t < RATE_LIMIT_WINDOW]
|
||||
if len(_login_attempts[key]) >= RATE_LIMIT_MAX:
|
||||
return False
|
||||
_login_attempts[key].append(now)
|
||||
return True
|
||||
|
||||
|
||||
# --- Pydantic models ---
|
||||
class UserInfo(BaseModel):
|
||||
sub: str
|
||||
email: Optional[str] = None
|
||||
|
||||
def __getitem__(self, key):
|
||||
return getattr(self, key)
|
||||
|
||||
|
||||
class AccessTokenInfo(BaseModel):
|
||||
exp: Optional[int] = None
|
||||
sub: Optional[str] = None
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
email: str
|
||||
password: str
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
access_token: str
|
||||
token_type: str = "bearer"
|
||||
expires_in: int
|
||||
|
||||
|
||||
# --- JWT token creation and verification ---
|
||||
def _create_access_token(user_id: str, email: str) -> tuple[str, int]:
|
||||
"""Create an HS256 JWT. Returns (token, expires_in_seconds)."""
|
||||
expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"email": email,
|
||||
"exp": expire,
|
||||
}
|
||||
token = jwt.encode(payload, settings.SECRET_KEY, algorithm=JWT_ALGORITHM)
|
||||
return token, int(expires_delta.total_seconds())
|
||||
|
||||
|
||||
def _verify_token(token: str) -> dict:
|
||||
"""Verify and decode an HS256 JWT."""
|
||||
return jwt.decode(token, settings.SECRET_KEY, algorithms=[JWT_ALGORITHM])
|
||||
|
||||
|
||||
# --- Authentication logic (mirrors auth_jwt._authenticate_user) ---
|
||||
async def _authenticate_user(
|
||||
jwt_token: Optional[str],
|
||||
api_key: Optional[str],
|
||||
) -> UserInfo | None:
|
||||
user_infos: list[UserInfo] = []
|
||||
|
||||
if api_key:
|
||||
user_api_key = await user_api_keys_controller.verify_key(api_key)
|
||||
if user_api_key:
|
||||
user_infos.append(UserInfo(sub=user_api_key.user_id, email=None))
|
||||
|
||||
if jwt_token:
|
||||
try:
|
||||
payload = _verify_token(jwt_token)
|
||||
user_id = payload["sub"]
|
||||
email = payload.get("email")
|
||||
user_infos.append(UserInfo(sub=user_id, email=email))
|
||||
except JWTError as e:
|
||||
logger.error(f"JWT error: {e}")
|
||||
raise HTTPException(status_code=401, detail="Invalid authentication")
|
||||
|
||||
if len(user_infos) == 0:
|
||||
return None
|
||||
|
||||
if len(set(x.sub for x in user_infos)) > 1:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid authentication: more than one user provided",
|
||||
)
|
||||
|
||||
return user_infos[0]
|
||||
|
||||
|
||||
# --- FastAPI dependencies (exported, required by auth/__init__.py) ---
|
||||
def authenticated(token: Annotated[str, Depends(oauth2_scheme)]):
|
||||
if token is None:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
return None
|
||||
|
||||
|
||||
async def current_user(
|
||||
jwt_token: Annotated[Optional[str], Depends(oauth2_scheme)],
|
||||
api_key: Annotated[Optional[str], Depends(api_key_header)],
|
||||
):
|
||||
user = await _authenticate_user(jwt_token, api_key)
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
return user
|
||||
|
||||
|
||||
async def current_user_optional(
|
||||
jwt_token: Annotated[Optional[str], Depends(oauth2_scheme)],
|
||||
api_key: Annotated[Optional[str], Depends(api_key_header)],
|
||||
):
|
||||
return await _authenticate_user(jwt_token, api_key)
|
||||
|
||||
|
||||
# --- WebSocket auth (same pattern as auth_jwt.py) ---
|
||||
def parse_ws_bearer_token(
|
||||
websocket: "WebSocket",
|
||||
) -> tuple[Optional[str], Optional[str]]:
|
||||
raw = websocket.headers.get("sec-websocket-protocol") or ""
|
||||
parts = [p.strip() for p in raw.split(",") if p.strip()]
|
||||
if len(parts) >= 2 and parts[0].lower() == "bearer":
|
||||
return parts[1], "bearer"
|
||||
return None, None
|
||||
|
||||
|
||||
async def current_user_ws_optional(websocket: "WebSocket") -> Optional[UserInfo]:
|
||||
token, _ = parse_ws_bearer_token(websocket)
|
||||
if not token:
|
||||
return None
|
||||
return await _authenticate_user(token, None)
|
||||
|
||||
|
||||
# --- Login router ---
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
@router.post("/login", response_model=LoginResponse)
|
||||
async def login(request: Request, body: LoginRequest):
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
if not _check_rate_limit(client_ip):
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="Too many login attempts. Try again later.",
|
||||
)
|
||||
|
||||
user = await user_controller.get_by_email(body.email)
|
||||
if not user or not user.password_hash:
|
||||
print("invalid email")
|
||||
raise HTTPException(status_code=401, detail="Invalid email or password")
|
||||
|
||||
if not verify_password(body.password, user.password_hash):
|
||||
print("invalid pass")
|
||||
raise HTTPException(status_code=401, detail="Invalid email or password")
|
||||
|
||||
access_token, expires_in = _create_access_token(user.id, user.email)
|
||||
return LoginResponse(
|
||||
access_token=access_token,
|
||||
token_type="bearer",
|
||||
expires_in=expires_in,
|
||||
)
|
||||
41
server/reflector/auth/password_utils.py
Normal file
41
server/reflector/auth/password_utils.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""Password hashing utilities using PBKDF2-SHA256 (stdlib only)."""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
|
||||
PBKDF2_ITERATIONS = 100_000
|
||||
SALT_LENGTH = 16 # bytes, hex-encoded to 32 chars
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""Hash a password using PBKDF2-SHA256 with a random salt.
|
||||
|
||||
Format: pbkdf2:sha256:<iterations>$<salt_hex>$<hash_hex>
|
||||
"""
|
||||
salt = os.urandom(SALT_LENGTH).hex()
|
||||
dk = hashlib.pbkdf2_hmac(
|
||||
"sha256",
|
||||
password.encode("utf-8"),
|
||||
salt.encode("utf-8"),
|
||||
PBKDF2_ITERATIONS,
|
||||
)
|
||||
return f"pbkdf2:sha256:{PBKDF2_ITERATIONS}${salt}${dk.hex()}"
|
||||
|
||||
|
||||
def verify_password(password: str, password_hash: str) -> bool:
|
||||
"""Verify a password against its hash using constant-time comparison."""
|
||||
try:
|
||||
header, salt, stored_hash = password_hash.split("$", 2)
|
||||
_, algo, iterations_str = header.split(":")
|
||||
iterations = int(iterations_str)
|
||||
|
||||
dk = hashlib.pbkdf2_hmac(
|
||||
algo,
|
||||
password.encode("utf-8"),
|
||||
salt.encode("utf-8"),
|
||||
iterations,
|
||||
)
|
||||
return hmac.compare_digest(dk.hex(), stored_hash)
|
||||
except (ValueError, AttributeError):
|
||||
return False
|
||||
@@ -1,4 +1,4 @@
|
||||
"""User table for storing Authentik user information."""
|
||||
"""User table for storing user information."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
@@ -15,6 +15,7 @@ users = sqlalchemy.Table(
|
||||
sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
|
||||
sqlalchemy.Column("email", sqlalchemy.String, nullable=False),
|
||||
sqlalchemy.Column("authentik_uid", sqlalchemy.String, nullable=False),
|
||||
sqlalchemy.Column("password_hash", sqlalchemy.String, nullable=True),
|
||||
sqlalchemy.Column("created_at", sqlalchemy.DateTime(timezone=True), nullable=False),
|
||||
sqlalchemy.Column("updated_at", sqlalchemy.DateTime(timezone=True), nullable=False),
|
||||
sqlalchemy.Index("idx_user_authentik_uid", "authentik_uid", unique=True),
|
||||
@@ -26,6 +27,7 @@ class User(BaseModel):
|
||||
id: NonEmptyString = Field(default_factory=generate_uuid4)
|
||||
email: NonEmptyString
|
||||
authentik_uid: NonEmptyString
|
||||
password_hash: str | None = None
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
@@ -51,22 +53,29 @@ class UserController:
|
||||
|
||||
@staticmethod
|
||||
async def create_or_update(
|
||||
id: NonEmptyString, authentik_uid: NonEmptyString, email: NonEmptyString
|
||||
id: NonEmptyString,
|
||||
authentik_uid: NonEmptyString,
|
||||
email: NonEmptyString,
|
||||
password_hash: str | None = None,
|
||||
) -> User:
|
||||
existing = await UserController.get_by_authentik_uid(authentik_uid)
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
if existing:
|
||||
update_values: dict = {"email": email, "updated_at": now}
|
||||
if password_hash is not None:
|
||||
update_values["password_hash"] = password_hash
|
||||
query = (
|
||||
users.update()
|
||||
.where(users.c.authentik_uid == authentik_uid)
|
||||
.values(email=email, updated_at=now)
|
||||
.values(**update_values)
|
||||
)
|
||||
await get_database().execute(query)
|
||||
return User(
|
||||
id=existing.id,
|
||||
authentik_uid=authentik_uid,
|
||||
email=email,
|
||||
password_hash=password_hash or existing.password_hash,
|
||||
created_at=existing.created_at,
|
||||
updated_at=now,
|
||||
)
|
||||
@@ -75,6 +84,7 @@ class UserController:
|
||||
id=id,
|
||||
authentik_uid=authentik_uid,
|
||||
email=email,
|
||||
password_hash=password_hash,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
@@ -82,6 +92,16 @@ class UserController:
|
||||
await get_database().execute(query)
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
async def set_password_hash(user_id: NonEmptyString, password_hash: str) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
query = (
|
||||
users.update()
|
||||
.where(users.c.id == user_id)
|
||||
.values(password_hash=password_hash, updated_at=now)
|
||||
)
|
||||
await get_database().execute(query)
|
||||
|
||||
@staticmethod
|
||||
async def list_all() -> list[User]:
|
||||
query = users.select().order_by(users.c.created_at.desc())
|
||||
|
||||
@@ -228,6 +228,7 @@ class LLM:
|
||||
is_function_calling_model=False,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
timeout=self.settings_obj.LLM_REQUEST_TIMEOUT,
|
||||
additional_kwargs={"extra_body": {"litellm_session_id": session_id}},
|
||||
)
|
||||
|
||||
|
||||
@@ -87,6 +87,7 @@ class Settings(BaseSettings):
|
||||
LLM_URL: str | None = None
|
||||
LLM_API_KEY: str | None = None
|
||||
LLM_CONTEXT_WINDOW: int = 16000
|
||||
LLM_REQUEST_TIMEOUT: float = 300.0 # HTTP request timeout for LLM calls (seconds)
|
||||
|
||||
LLM_PARSE_MAX_RETRIES: int = (
|
||||
3 # Max retries for JSON/validation errors (total attempts = retries + 1)
|
||||
@@ -112,7 +113,7 @@ class Settings(BaseSettings):
|
||||
# Sentry
|
||||
SENTRY_DSN: str | None = None
|
||||
|
||||
# User authentication (none, jwt)
|
||||
# User authentication (none, jwt, password)
|
||||
AUTH_BACKEND: str = "none"
|
||||
|
||||
# User authentication using JWT
|
||||
@@ -120,6 +121,10 @@ class Settings(BaseSettings):
|
||||
AUTH_JWT_PUBLIC_KEY: str | None = "authentik.monadical.com_public.pem"
|
||||
AUTH_JWT_AUDIENCE: str | None = None
|
||||
|
||||
# User authentication using password (selfhosted)
|
||||
ADMIN_EMAIL: str | None = None
|
||||
ADMIN_PASSWORD_HASH: str | None = None
|
||||
|
||||
PUBLIC_MODE: bool = False
|
||||
PUBLIC_DATA_RETENTION_DAYS: PositiveInt = 7
|
||||
|
||||
@@ -153,6 +158,9 @@ class Settings(BaseSettings):
|
||||
WHEREBY_WEBHOOK_SECRET: str | None = None
|
||||
AWS_PROCESS_RECORDING_QUEUE_URL: str | None = None
|
||||
SQS_POLLING_TIMEOUT_SECONDS: int = 60
|
||||
CELERY_BEAT_POLL_INTERVAL: int = (
|
||||
0 # 0 = use individual defaults; set e.g. 300 for 5-min polling
|
||||
)
|
||||
|
||||
# Daily.co integration
|
||||
DAILY_API_KEY: str | None = None
|
||||
|
||||
80
server/reflector/tools/create_admin.py
Normal file
80
server/reflector/tools/create_admin.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Create or update an admin user with password authentication.
|
||||
|
||||
Usage:
|
||||
uv run python -m reflector.tools.create_admin --email admin@localhost --password <pass>
|
||||
uv run python -m reflector.tools.create_admin --email admin@localhost # prompts for password
|
||||
uv run python -m reflector.tools.create_admin --hash-only --password <pass> # print hash only
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import getpass
|
||||
import sys
|
||||
|
||||
from reflector.auth.password_utils import hash_password
|
||||
from reflector.db.users import user_controller
|
||||
from reflector.utils import generate_uuid4
|
||||
|
||||
|
||||
async def create_admin(email: str, password: str) -> None:
|
||||
from reflector.db import get_database
|
||||
|
||||
database = get_database()
|
||||
await database.connect()
|
||||
|
||||
try:
|
||||
password_hash = hash_password(password)
|
||||
|
||||
existing = await user_controller.get_by_email(email)
|
||||
if existing:
|
||||
await user_controller.set_password_hash(existing.id, password_hash)
|
||||
print(f"Updated password for existing user: {email} (id={existing.id})")
|
||||
else:
|
||||
user = await user_controller.create_or_update(
|
||||
id=generate_uuid4(),
|
||||
authentik_uid=f"local:{email}",
|
||||
email=email,
|
||||
password_hash=password_hash,
|
||||
)
|
||||
print(f"Created admin user: {email} (id={user.id})")
|
||||
finally:
|
||||
await database.disconnect()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Create or update an admin user")
|
||||
parser.add_argument(
|
||||
"--email", default="admin@localhost", help="Admin email address"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--password",
|
||||
help="Admin password (will prompt if not provided)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hash-only",
|
||||
action="store_true",
|
||||
help="Print the password hash and exit (for ADMIN_PASSWORD_HASH env var)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
password = args.password
|
||||
if not password:
|
||||
password = getpass.getpass("Password: ")
|
||||
confirm = getpass.getpass("Confirm password: ")
|
||||
if password != confirm:
|
||||
print("Passwords do not match", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
if not password:
|
||||
print("Password cannot be empty", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
if args.hash_only:
|
||||
print(hash_password(password))
|
||||
sys.exit(0)
|
||||
|
||||
asyncio.run(create_admin(args.email, password))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
43
server/reflector/tools/provision_admin.py
Normal file
43
server/reflector/tools/provision_admin.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""Provision admin user on server startup using environment variables.
|
||||
|
||||
Reads ADMIN_EMAIL and ADMIN_PASSWORD_HASH from settings and creates or updates
|
||||
the admin user. Intended to be called from runserver.sh on container startup.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from reflector.db.users import user_controller
|
||||
from reflector.settings import settings
|
||||
from reflector.utils import generate_uuid4
|
||||
|
||||
|
||||
async def provision() -> None:
|
||||
if not settings.ADMIN_EMAIL or not settings.ADMIN_PASSWORD_HASH:
|
||||
return
|
||||
|
||||
from reflector.db import get_database
|
||||
|
||||
database = get_database()
|
||||
await database.connect()
|
||||
|
||||
try:
|
||||
existing = await user_controller.get_by_email(settings.ADMIN_EMAIL)
|
||||
if existing:
|
||||
await user_controller.set_password_hash(
|
||||
existing.id, settings.ADMIN_PASSWORD_HASH
|
||||
)
|
||||
print(f"Updated admin user: {settings.ADMIN_EMAIL}")
|
||||
else:
|
||||
await user_controller.create_or_update(
|
||||
id=generate_uuid4(),
|
||||
authentik_uid=f"local:{settings.ADMIN_EMAIL}",
|
||||
email=settings.ADMIN_EMAIL,
|
||||
password_hash=settings.ADMIN_PASSWORD_HASH,
|
||||
)
|
||||
print(f"Created admin user: {settings.ADMIN_EMAIL}")
|
||||
finally:
|
||||
await database.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(provision())
|
||||
@@ -2,8 +2,7 @@ from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
|
||||
from reflector.auth.auth_jwt import JWTAuth # type: ignore
|
||||
from reflector.db.users import user_controller
|
||||
import reflector.auth as auth
|
||||
from reflector.ws_events import UserWsEvent
|
||||
from reflector.ws_manager import get_ws_manager
|
||||
|
||||
@@ -26,42 +25,24 @@ UNAUTHORISED = 4401
|
||||
|
||||
@router.websocket("/events")
|
||||
async def user_events_websocket(websocket: WebSocket):
|
||||
# Browser can't send Authorization header for WS; use subprotocol: ["bearer", token]
|
||||
raw_subprotocol = websocket.headers.get("sec-websocket-protocol") or ""
|
||||
parts = [p.strip() for p in raw_subprotocol.split(",") if p.strip()]
|
||||
token: Optional[str] = None
|
||||
negotiated_subprotocol: Optional[str] = None
|
||||
if len(parts) >= 2 and parts[0].lower() == "bearer":
|
||||
negotiated_subprotocol = "bearer"
|
||||
token = parts[1]
|
||||
token, negotiated_subprotocol = auth.parse_ws_bearer_token(websocket)
|
||||
|
||||
user_id: Optional[str] = None
|
||||
if not token:
|
||||
await websocket.close(code=UNAUTHORISED)
|
||||
return
|
||||
|
||||
try:
|
||||
payload = JWTAuth().verify_token(token)
|
||||
authentik_uid = payload.get("sub")
|
||||
|
||||
if authentik_uid:
|
||||
user = await user_controller.get_by_authentik_uid(authentik_uid)
|
||||
if user:
|
||||
user_id = user.id
|
||||
else:
|
||||
await websocket.close(code=UNAUTHORISED)
|
||||
return
|
||||
else:
|
||||
await websocket.close(code=UNAUTHORISED)
|
||||
return
|
||||
user = await auth.current_user_ws_optional(websocket)
|
||||
except Exception:
|
||||
await websocket.close(code=UNAUTHORISED)
|
||||
return
|
||||
|
||||
if not user_id:
|
||||
if not user:
|
||||
await websocket.close(code=UNAUTHORISED)
|
||||
return
|
||||
|
||||
user_id: Optional[str] = user.sub if hasattr(user, "sub") else user["sub"]
|
||||
|
||||
room_id = f"user:{user_id}"
|
||||
ws_manager = get_ws_manager()
|
||||
|
||||
|
||||
@@ -8,8 +8,21 @@ from reflector.settings import settings
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
# Polling intervals (seconds)
|
||||
# CELERY_BEAT_POLL_INTERVAL overrides all sub-5-min intervals (e.g. 300 for selfhosted)
|
||||
_override = (
|
||||
float(settings.CELERY_BEAT_POLL_INTERVAL)
|
||||
if settings.CELERY_BEAT_POLL_INTERVAL > 0
|
||||
else 0
|
||||
)
|
||||
|
||||
# Webhook-aware: 180s when webhook configured (backup mode), 15s when no webhook (primary discovery)
|
||||
POLL_DAILY_RECORDINGS_INTERVAL_SEC = 180.0 if settings.DAILY_WEBHOOK_SECRET else 15.0
|
||||
POLL_DAILY_RECORDINGS_INTERVAL_SEC = _override or (
|
||||
180.0 if settings.DAILY_WEBHOOK_SECRET else 15.0
|
||||
)
|
||||
SQS_POLL_INTERVAL = _override or float(settings.SQS_POLLING_TIMEOUT_SECONDS)
|
||||
RECONCILIATION_INTERVAL = _override or 30.0
|
||||
ICS_SYNC_INTERVAL = _override or 60.0
|
||||
UPCOMING_MEETINGS_INTERVAL = _override or 30.0
|
||||
|
||||
if celery.current_app.main != "default":
|
||||
logger.info(f"Celery already configured ({celery.current_app})")
|
||||
@@ -33,11 +46,11 @@ else:
|
||||
app.conf.beat_schedule = {
|
||||
"process_messages": {
|
||||
"task": "reflector.worker.process.process_messages",
|
||||
"schedule": float(settings.SQS_POLLING_TIMEOUT_SECONDS),
|
||||
"schedule": SQS_POLL_INTERVAL,
|
||||
},
|
||||
"process_meetings": {
|
||||
"task": "reflector.worker.process.process_meetings",
|
||||
"schedule": float(settings.SQS_POLLING_TIMEOUT_SECONDS),
|
||||
"schedule": SQS_POLL_INTERVAL,
|
||||
},
|
||||
"reprocess_failed_recordings": {
|
||||
"task": "reflector.worker.process.reprocess_failed_recordings",
|
||||
@@ -53,15 +66,15 @@ else:
|
||||
},
|
||||
"trigger_daily_reconciliation": {
|
||||
"task": "reflector.worker.process.trigger_daily_reconciliation",
|
||||
"schedule": 30.0, # Every 30 seconds (queues poll tasks for all active meetings)
|
||||
"schedule": RECONCILIATION_INTERVAL,
|
||||
},
|
||||
"sync_all_ics_calendars": {
|
||||
"task": "reflector.worker.ics_sync.sync_all_ics_calendars",
|
||||
"schedule": 60.0, # Run every minute to check which rooms need sync
|
||||
"schedule": ICS_SYNC_INTERVAL,
|
||||
},
|
||||
"create_upcoming_meetings": {
|
||||
"task": "reflector.worker.ics_sync.create_upcoming_meetings",
|
||||
"schedule": 30.0, # Run every 30 seconds to create upcoming meetings
|
||||
"schedule": UPCOMING_MEETINGS_INTERVAL,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,10 @@
|
||||
|
||||
if [ "${ENTRYPOINT}" = "server" ]; then
|
||||
uv run alembic upgrade head
|
||||
# Provision admin user if password auth is configured
|
||||
if [ -n "${ADMIN_EMAIL:-}" ] && [ -n "${ADMIN_PASSWORD_HASH:-}" ]; then
|
||||
uv run python -m reflector.tools.provision_admin
|
||||
fi
|
||||
uv run uvicorn reflector.app:app --host 0.0.0.0 --port 1250
|
||||
elif [ "${ENTRYPOINT}" = "worker" ]; then
|
||||
uv run celery -A reflector.worker.app worker --loglevel=info
|
||||
|
||||
201
server/tests/test_auth_password.py
Normal file
201
server/tests/test_auth_password.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""Tests for the password auth backend."""
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from jose import jwt
|
||||
|
||||
from reflector.auth.password_utils import hash_password
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def password_app():
|
||||
"""Create a minimal FastAPI app with the password auth router."""
|
||||
from fastapi import FastAPI
|
||||
|
||||
from reflector.auth import auth_password
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(auth_password.router, prefix="/v1")
|
||||
# Reset rate limiter between tests
|
||||
auth_password._login_attempts.clear()
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def password_client(password_app):
|
||||
"""Create a test client for the password auth app."""
|
||||
async with AsyncClient(app=password_app, base_url="http://test/v1") as client:
|
||||
yield client
|
||||
|
||||
|
||||
async def _create_user_with_password(email: str, password: str):
|
||||
"""Helper to create a user with a password hash in the DB."""
|
||||
from reflector.db.users import user_controller
|
||||
from reflector.utils import generate_uuid4
|
||||
|
||||
pw_hash = hash_password(password)
|
||||
return await user_controller.create_or_update(
|
||||
id=generate_uuid4(),
|
||||
authentik_uid=f"local:{email}",
|
||||
email=email,
|
||||
password_hash=pw_hash,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_success(password_client, setup_database):
|
||||
await _create_user_with_password("admin@test.com", "testpass123")
|
||||
|
||||
response = await password_client.post(
|
||||
"/auth/login",
|
||||
json={"email": "admin@test.com", "password": "testpass123"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "access_token" in data
|
||||
assert data["token_type"] == "bearer"
|
||||
assert data["expires_in"] > 0
|
||||
|
||||
# Verify the JWT is valid
|
||||
payload = jwt.decode(
|
||||
data["access_token"],
|
||||
settings.SECRET_KEY,
|
||||
algorithms=["HS256"],
|
||||
)
|
||||
assert payload["email"] == "admin@test.com"
|
||||
assert "sub" in payload
|
||||
assert "exp" in payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_wrong_password(password_client, setup_database):
|
||||
await _create_user_with_password("user@test.com", "correctpassword")
|
||||
|
||||
response = await password_client.post(
|
||||
"/auth/login",
|
||||
json={"email": "user@test.com", "password": "wrongpassword"},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_nonexistent_user(password_client, setup_database):
|
||||
response = await password_client.post(
|
||||
"/auth/login",
|
||||
json={"email": "nobody@test.com", "password": "anything"},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_user_without_password_hash(password_client, setup_database):
|
||||
"""User exists but has no password_hash (e.g. Authentik user)."""
|
||||
from reflector.db.users import user_controller
|
||||
from reflector.utils import generate_uuid4
|
||||
|
||||
await user_controller.create_or_update(
|
||||
id=generate_uuid4(),
|
||||
authentik_uid="authentik:abc123",
|
||||
email="oidc@test.com",
|
||||
)
|
||||
|
||||
response = await password_client.post(
|
||||
"/auth/login",
|
||||
json={"email": "oidc@test.com", "password": "anything"},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_rate_limiting(password_client, setup_database):
|
||||
from reflector.auth import auth_password
|
||||
|
||||
# Reset rate limiter
|
||||
auth_password._login_attempts.clear()
|
||||
|
||||
for _ in range(10):
|
||||
await password_client.post(
|
||||
"/auth/login",
|
||||
json={"email": "fake@test.com", "password": "wrong"},
|
||||
)
|
||||
|
||||
# 11th attempt should be rate-limited
|
||||
response = await password_client.post(
|
||||
"/auth/login",
|
||||
json={"email": "fake@test.com", "password": "wrong"},
|
||||
)
|
||||
|
||||
assert response.status_code == 429
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_jwt_create_and_verify():
|
||||
from reflector.auth.auth_password import _create_access_token, _verify_token
|
||||
|
||||
token, expires_in = _create_access_token("user-123", "test@example.com")
|
||||
assert expires_in > 0
|
||||
|
||||
payload = _verify_token(token)
|
||||
assert payload["sub"] == "user-123"
|
||||
assert payload["email"] == "test@example.com"
|
||||
assert "exp" in payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_user_with_jwt():
|
||||
from reflector.auth.auth_password import (
|
||||
_authenticate_user,
|
||||
_create_access_token,
|
||||
)
|
||||
|
||||
token, _ = _create_access_token("user-abc", "abc@test.com")
|
||||
user = await _authenticate_user(token, None)
|
||||
|
||||
assert user is not None
|
||||
assert user.sub == "user-abc"
|
||||
assert user.email == "abc@test.com"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_user_invalid_jwt():
|
||||
from fastapi import HTTPException
|
||||
|
||||
from reflector.auth.auth_password import _authenticate_user
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await _authenticate_user("invalid.jwt.token", None)
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_user_no_credentials():
|
||||
from reflector.auth.auth_password import _authenticate_user
|
||||
|
||||
user = await _authenticate_user(None, None)
|
||||
assert user is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_current_user_raises_without_token():
|
||||
"""Verify that current_user dependency raises 401 without token."""
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from reflector.auth import auth_password
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/test")
|
||||
async def test_endpoint(user=Depends(auth_password.current_user)):
|
||||
return {"user": user.sub}
|
||||
|
||||
# Use sync TestClient for simplicity
|
||||
client = TestClient(app)
|
||||
response = client.get("/test")
|
||||
# OAuth2PasswordBearer with auto_error=False returns None, then current_user raises 401
|
||||
assert response.status_code == 401
|
||||
97
server/tests/test_create_admin.py
Normal file
97
server/tests/test_create_admin.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""Tests for admin user creation logic (used by create_admin CLI tool)."""
|
||||
|
||||
import pytest
|
||||
|
||||
from reflector.auth.password_utils import hash_password, verify_password
|
||||
from reflector.db.users import user_controller
|
||||
from reflector.utils import generate_uuid4
|
||||
|
||||
|
||||
async def _provision_admin(email: str, password: str):
|
||||
"""Mirrors the logic in create_admin.create_admin() without managing DB connections."""
|
||||
password_hash = hash_password(password)
|
||||
|
||||
existing = await user_controller.get_by_email(email)
|
||||
if existing:
|
||||
await user_controller.set_password_hash(existing.id, password_hash)
|
||||
else:
|
||||
await user_controller.create_or_update(
|
||||
id=generate_uuid4(),
|
||||
authentik_uid=f"local:{email}",
|
||||
email=email,
|
||||
password_hash=password_hash,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_admin_new_user(setup_database):
|
||||
await _provision_admin("newadmin@test.com", "password123")
|
||||
|
||||
user = await user_controller.get_by_email("newadmin@test.com")
|
||||
assert user is not None
|
||||
assert user.email == "newadmin@test.com"
|
||||
assert user.authentik_uid == "local:newadmin@test.com"
|
||||
assert user.password_hash is not None
|
||||
assert verify_password("password123", user.password_hash)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_admin_updates_existing(setup_database):
|
||||
# Create first
|
||||
await _provision_admin("admin@test.com", "oldpassword")
|
||||
user1 = await user_controller.get_by_email("admin@test.com")
|
||||
|
||||
# Update password
|
||||
await _provision_admin("admin@test.com", "newpassword")
|
||||
user2 = await user_controller.get_by_email("admin@test.com")
|
||||
|
||||
assert user1.id == user2.id # same user, not duplicated
|
||||
assert verify_password("newpassword", user2.password_hash)
|
||||
assert not verify_password("oldpassword", user2.password_hash)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_admin_idempotent(setup_database):
|
||||
await _provision_admin("admin@test.com", "samepassword")
|
||||
await _provision_admin("admin@test.com", "samepassword")
|
||||
|
||||
# Should only have one user
|
||||
users = await user_controller.list_all()
|
||||
admin_users = [u for u in users if u.email == "admin@test.com"]
|
||||
assert len(admin_users) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_or_update_with_password_hash(setup_database):
|
||||
"""Test the extended create_or_update method with password_hash parameter."""
|
||||
pw_hash = hash_password("test123")
|
||||
user = await user_controller.create_or_update(
|
||||
id=generate_uuid4(),
|
||||
authentik_uid="local:test@example.com",
|
||||
email="test@example.com",
|
||||
password_hash=pw_hash,
|
||||
)
|
||||
|
||||
assert user.password_hash == pw_hash
|
||||
|
||||
fetched = await user_controller.get_by_email("test@example.com")
|
||||
assert fetched is not None
|
||||
assert verify_password("test123", fetched.password_hash)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_password_hash(setup_database):
|
||||
"""Test the set_password_hash method."""
|
||||
user = await user_controller.create_or_update(
|
||||
id=generate_uuid4(),
|
||||
authentik_uid="local:pw@test.com",
|
||||
email="pw@test.com",
|
||||
)
|
||||
assert user.password_hash is None
|
||||
|
||||
pw_hash = hash_password("newpass")
|
||||
await user_controller.set_password_hash(user.id, pw_hash)
|
||||
|
||||
updated = await user_controller.get_by_email("pw@test.com")
|
||||
assert updated is not None
|
||||
assert verify_password("newpass", updated.password_hash)
|
||||
58
server/tests/test_password_utils.py
Normal file
58
server/tests/test_password_utils.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Tests for password hashing utilities."""
|
||||
|
||||
from reflector.auth.password_utils import hash_password, verify_password
|
||||
|
||||
|
||||
def test_hash_and_verify():
|
||||
pw = "my-secret-password"
|
||||
h = hash_password(pw)
|
||||
assert verify_password(pw, h) is True
|
||||
|
||||
|
||||
def test_wrong_password():
|
||||
h = hash_password("correct")
|
||||
assert verify_password("wrong", h) is False
|
||||
|
||||
|
||||
def test_hash_format():
|
||||
h = hash_password("test")
|
||||
parts = h.split("$")
|
||||
assert len(parts) == 3
|
||||
assert parts[0] == "pbkdf2:sha256:100000"
|
||||
assert len(parts[1]) == 32 # 16 bytes hex = 32 chars
|
||||
assert len(parts[2]) == 64 # sha256 hex = 64 chars
|
||||
|
||||
|
||||
def test_different_salts():
|
||||
h1 = hash_password("same")
|
||||
h2 = hash_password("same")
|
||||
assert h1 != h2 # different salts produce different hashes
|
||||
assert verify_password("same", h1) is True
|
||||
assert verify_password("same", h2) is True
|
||||
|
||||
|
||||
def test_malformed_hash():
|
||||
assert verify_password("test", "garbage") is False
|
||||
assert verify_password("test", "") is False
|
||||
assert verify_password("test", "pbkdf2:sha256:100000$short") is False
|
||||
|
||||
|
||||
def test_empty_password():
|
||||
h = hash_password("")
|
||||
assert verify_password("", h) is True
|
||||
assert verify_password("notempty", h) is False
|
||||
|
||||
|
||||
def test_unicode_password():
|
||||
pw = "p\u00e4ssw\u00f6rd\U0001f512"
|
||||
h = hash_password(pw)
|
||||
assert verify_password(pw, h) is True
|
||||
assert verify_password("password", h) is False
|
||||
|
||||
|
||||
def test_constant_time_comparison():
|
||||
"""Verify that hmac.compare_digest is used (structural test)."""
|
||||
import inspect
|
||||
|
||||
source = inspect.getsource(verify_password)
|
||||
assert "hmac.compare_digest" in source
|
||||
Reference in New Issue
Block a user