refactor(auth): consolidate PUBLIC_MODE and mutation guards into reusable helpers (#909)

* refactor(auth): consolidate PUBLIC_MODE and mutation guards into reusable helpers

* fix: fix websocket test override
This commit is contained in:
Juan Diego García
2026-03-12 10:51:26 -05:00
committed by GitHub
parent cf6e867cf1
commit 4ae56b730a
15 changed files with 96 additions and 86 deletions

View File

@@ -12,6 +12,7 @@ AccessTokenInfo = auth_module.AccessTokenInfo
authenticated = auth_module.authenticated authenticated = auth_module.authenticated
current_user = auth_module.current_user current_user = auth_module.current_user
current_user_optional = auth_module.current_user_optional current_user_optional = auth_module.current_user_optional
current_user_optional_if_public_mode = auth_module.current_user_optional_if_public_mode
parse_ws_bearer_token = auth_module.parse_ws_bearer_token parse_ws_bearer_token = auth_module.parse_ws_bearer_token
current_user_ws_optional = auth_module.current_user_ws_optional current_user_ws_optional = auth_module.current_user_ws_optional
verify_raw_token = auth_module.verify_raw_token verify_raw_token = auth_module.verify_raw_token

View File

@@ -129,6 +129,17 @@ async def current_user_optional(
return await _authenticate_user(jwt_token, api_key, jwtauth) return await _authenticate_user(jwt_token, api_key, jwtauth)
async def current_user_optional_if_public_mode(
jwt_token: Annotated[Optional[str], Depends(oauth2_scheme)],
api_key: Annotated[Optional[str], Depends(api_key_header)],
jwtauth: JWTAuth = Depends(),
) -> Optional[UserInfo]:
user = await _authenticate_user(jwt_token, api_key, jwtauth)
if user is None and not settings.PUBLIC_MODE:
raise HTTPException(status_code=401, detail="Not authenticated")
return user
def parse_ws_bearer_token( def parse_ws_bearer_token(
websocket: "WebSocket", websocket: "WebSocket",
) -> tuple[Optional[str], Optional[str]]: ) -> tuple[Optional[str], Optional[str]]:

View File

@@ -21,6 +21,11 @@ def current_user_optional():
return None return None
def current_user_optional_if_public_mode():
# auth_none means no authentication at all — always public
return None
def parse_ws_bearer_token(websocket): def parse_ws_bearer_token(websocket):
return None, None return None, None

View File

@@ -150,6 +150,16 @@ async def current_user_optional(
return await _authenticate_user(jwt_token, api_key) return await _authenticate_user(jwt_token, api_key)
async def current_user_optional_if_public_mode(
jwt_token: Annotated[Optional[str], Depends(oauth2_scheme)],
api_key: Annotated[Optional[str], Depends(api_key_header)],
) -> Optional[UserInfo]:
user = await _authenticate_user(jwt_token, api_key)
if user is None and not settings.PUBLIC_MODE:
raise HTTPException(status_code=401, detail="Not authenticated")
return user
# --- WebSocket auth (same pattern as auth_jwt.py) --- # --- WebSocket auth (same pattern as auth_jwt.py) ---
def parse_ws_bearer_token( def parse_ws_bearer_token(
websocket: "WebSocket", websocket: "WebSocket",

View File

@@ -697,6 +697,18 @@ class TranscriptController:
return False return False
return user_id and transcript.user_id == user_id return user_id and transcript.user_id == user_id
@staticmethod
def check_can_mutate(transcript: Transcript, user_id: str | None) -> None:
"""
Raises HTTP 403 if the user cannot mutate the transcript.
Policy:
- Anonymous transcripts (user_id is None) are editable by anyone
- Owned transcripts can only be mutated by their owner
"""
if transcript.user_id is not None and transcript.user_id != user_id:
raise HTTPException(status_code=403, detail="Not authorized")
@asynccontextmanager @asynccontextmanager
async def transaction(self): async def transaction(self):
""" """

View File

@@ -16,7 +16,6 @@ from reflector.db.meetings import (
) )
from reflector.db.rooms import rooms_controller from reflector.db.rooms import rooms_controller
from reflector.logger import logger from reflector.logger import logger
from reflector.settings import settings
from reflector.utils.string import NonEmptyString from reflector.utils.string import NonEmptyString
from reflector.video_platforms.factory import create_platform_client from reflector.video_platforms.factory import create_platform_client
@@ -92,15 +91,15 @@ class StartRecordingRequest(BaseModel):
async def start_recording( async def start_recording(
meeting_id: NonEmptyString, meeting_id: NonEmptyString,
body: StartRecordingRequest, body: StartRecordingRequest,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[
Optional[auth.UserInfo], Depends(auth.current_user_optional_if_public_mode)
],
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Start cloud or raw-tracks recording via Daily.co REST API. """Start cloud or raw-tracks recording via Daily.co REST API.
Both cloud and raw-tracks are started via REST API to bypass enable_recording limitation of allowing only 1 recording at a time. Both cloud and raw-tracks are started via REST API to bypass enable_recording limitation of allowing only 1 recording at a time.
Uses different instanceIds for cloud vs raw-tracks (same won't work) Uses different instanceIds for cloud vs raw-tracks (same won't work)
""" """
if not user and not settings.PUBLIC_MODE:
raise HTTPException(status_code=401, detail="Not authenticated")
meeting = await meetings_controller.get_by_id(meeting_id) meeting = await meetings_controller.get_by_id(meeting_id)
if not meeting: if not meeting:
raise HTTPException(status_code=404, detail="Meeting not found") raise HTTPException(status_code=404, detail="Meeting not found")

View File

@@ -17,7 +17,6 @@ from reflector.db.rooms import rooms_controller
from reflector.redis_cache import RedisAsyncLock from reflector.redis_cache import RedisAsyncLock
from reflector.schemas.platform import Platform from reflector.schemas.platform import Platform
from reflector.services.ics_sync import ics_sync_service from reflector.services.ics_sync import ics_sync_service
from reflector.settings import settings
from reflector.utils.url import add_query_param from reflector.utils.url import add_query_param
from reflector.video_platforms.factory import create_platform_client from reflector.video_platforms.factory import create_platform_client
from reflector.worker.webhook import test_webhook from reflector.worker.webhook import test_webhook
@@ -178,11 +177,10 @@ router = APIRouter()
@router.get("/rooms", response_model=Page[RoomDetails]) @router.get("/rooms", response_model=Page[RoomDetails])
async def rooms_list( async def rooms_list(
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[
Optional[auth.UserInfo], Depends(auth.current_user_optional_if_public_mode)
],
) -> list[RoomDetails]: ) -> list[RoomDetails]:
if not user and not settings.PUBLIC_MODE:
raise HTTPException(status_code=401, detail="Not authenticated")
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
paginated = await apaginate( paginated = await apaginate(

View File

@@ -263,16 +263,15 @@ class SearchResponse(BaseModel):
@router.get("/transcripts", response_model=Page[GetTranscriptMinimal]) @router.get("/transcripts", response_model=Page[GetTranscriptMinimal])
async def transcripts_list( async def transcripts_list(
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[
Optional[auth.UserInfo], Depends(auth.current_user_optional_if_public_mode)
],
source_kind: SourceKind | None = None, source_kind: SourceKind | None = None,
room_id: str | None = None, room_id: str | None = None,
search_term: str | None = None, search_term: str | None = None,
change_seq_from: int | None = None, change_seq_from: int | None = None,
sort_by: Literal["created_at", "change_seq"] | None = None, sort_by: Literal["created_at", "change_seq"] | None = None,
): ):
if not user and not settings.PUBLIC_MODE:
raise HTTPException(status_code=401, detail="Not authenticated")
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
# Default behavior preserved: sort_by=None → "-created_at" # Default behavior preserved: sort_by=None → "-created_at"
@@ -307,13 +306,10 @@ async def transcripts_search(
from_datetime: SearchFromDatetimeParam = None, from_datetime: SearchFromDatetimeParam = None,
to_datetime: SearchToDatetimeParam = None, to_datetime: SearchToDatetimeParam = None,
user: Annotated[ user: Annotated[
Optional[auth.UserInfo], Depends(auth.current_user_optional) Optional[auth.UserInfo], Depends(auth.current_user_optional_if_public_mode)
] = None, ] = None,
): ):
"""Full-text search across transcript titles and content.""" """Full-text search across transcript titles and content."""
if not user and not settings.PUBLIC_MODE:
raise HTTPException(status_code=401, detail="Not authenticated")
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
if from_datetime and to_datetime and from_datetime > to_datetime: if from_datetime and to_datetime and from_datetime > to_datetime:
@@ -346,11 +342,10 @@ async def transcripts_search(
@router.post("/transcripts", response_model=GetTranscriptWithParticipants) @router.post("/transcripts", response_model=GetTranscriptWithParticipants)
async def transcripts_create( async def transcripts_create(
info: CreateTranscript, info: CreateTranscript,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[
Optional[auth.UserInfo], Depends(auth.current_user_optional_if_public_mode)
],
): ):
if not user and not settings.PUBLIC_MODE:
raise HTTPException(status_code=401, detail="Not authenticated")
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
transcript = await transcripts_controller.add( transcript = await transcripts_controller.add(
info.name, info.name,

View File

@@ -62,8 +62,7 @@ async def transcript_add_participant(
transcript = await transcripts_controller.get_by_id_for_http( transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id transcript_id, user_id=user_id
) )
if transcript.user_id is not None and transcript.user_id != user_id: transcripts_controller.check_can_mutate(transcript, user_id)
raise HTTPException(status_code=403, detail="Not authorized")
# ensure the speaker is unique # ensure the speaker is unique
if participant.speaker is not None and transcript.participants is not None: if participant.speaker is not None and transcript.participants is not None:
@@ -109,8 +108,7 @@ async def transcript_update_participant(
transcript = await transcripts_controller.get_by_id_for_http( transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id transcript_id, user_id=user_id
) )
if transcript.user_id is not None and transcript.user_id != user_id: transcripts_controller.check_can_mutate(transcript, user_id)
raise HTTPException(status_code=403, detail="Not authorized")
# ensure the speaker is unique # ensure the speaker is unique
for p in transcript.participants: for p in transcript.participants:
@@ -148,7 +146,6 @@ async def transcript_delete_participant(
transcript = await transcripts_controller.get_by_id_for_http( transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id transcript_id, user_id=user_id
) )
if transcript.user_id is not None and transcript.user_id != user_id: transcripts_controller.check_can_mutate(transcript, user_id)
raise HTTPException(status_code=403, detail="Not authorized")
await transcripts_controller.delete_participant(transcript, participant_id) await transcripts_controller.delete_participant(transcript, participant_id)
return DeletionStatus(status="ok") return DeletionStatus(status="ok")

View File

@@ -15,7 +15,6 @@ from reflector.services.transcript_process import (
prepare_transcript_processing, prepare_transcript_processing,
validate_transcript_for_processing, validate_transcript_for_processing,
) )
from reflector.settings import settings
router = APIRouter() router = APIRouter()
@@ -27,11 +26,10 @@ class ProcessStatus(BaseModel):
@router.post("/transcripts/{transcript_id}/process") @router.post("/transcripts/{transcript_id}/process")
async def transcript_process( async def transcript_process(
transcript_id: str, transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[
Optional[auth.UserInfo], Depends(auth.current_user_optional_if_public_mode)
],
) -> ProcessStatus: ) -> ProcessStatus:
if not user and not settings.PUBLIC_MODE:
raise HTTPException(status_code=401, detail="Not authenticated")
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http( transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id transcript_id, user_id=user_id

View File

@@ -41,8 +41,7 @@ async def transcript_assign_speaker(
transcript = await transcripts_controller.get_by_id_for_http( transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id transcript_id, user_id=user_id
) )
if transcript.user_id is not None and transcript.user_id != user_id: transcripts_controller.check_can_mutate(transcript, user_id)
raise HTTPException(status_code=403, detail="Not authorized")
if not transcript: if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found") raise HTTPException(status_code=404, detail="Transcript not found")
@@ -121,8 +120,7 @@ async def transcript_merge_speaker(
transcript = await transcripts_controller.get_by_id_for_http( transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id transcript_id, user_id=user_id
) )
if transcript.user_id is not None and transcript.user_id != user_id: transcripts_controller.check_can_mutate(transcript, user_id)
raise HTTPException(status_code=403, detail="Not authorized")
if not transcript: if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found") raise HTTPException(status_code=404, detail="Transcript not found")

View File

@@ -7,7 +7,6 @@ from pydantic import BaseModel
import reflector.auth as auth import reflector.auth as auth
from reflector.db.transcripts import SourceKind, transcripts_controller from reflector.db.transcripts import SourceKind, transcripts_controller
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
from reflector.settings import settings
router = APIRouter() router = APIRouter()
@@ -22,11 +21,10 @@ async def transcript_record_upload(
chunk_number: int, chunk_number: int,
total_chunks: int, total_chunks: int,
chunk: UploadFile, chunk: UploadFile,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[
Optional[auth.UserInfo], Depends(auth.current_user_optional_if_public_mode)
],
): ):
if not user and not settings.PUBLIC_MODE:
raise HTTPException(status_code=401, detail="Not authenticated")
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http( transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id transcript_id, user_id=user_id

View File

@@ -4,7 +4,6 @@ from fastapi import APIRouter, Depends, HTTPException, Request
import reflector.auth as auth import reflector.auth as auth
from reflector.db.transcripts import transcripts_controller from reflector.db.transcripts import transcripts_controller
from reflector.settings import settings
from .rtc_offer import RtcOffer, rtc_offer_base from .rtc_offer import RtcOffer, rtc_offer_base
@@ -16,11 +15,10 @@ async def transcript_record_webrtc(
transcript_id: str, transcript_id: str,
params: RtcOffer, params: RtcOffer,
request: Request, request: Request,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[
Optional[auth.UserInfo], Depends(auth.current_user_optional_if_public_mode)
],
): ):
if not user and not settings.PUBLIC_MODE:
raise HTTPException(status_code=401, detail="Not authenticated")
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http( transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id transcript_id, user_id=user_id

View File

@@ -437,6 +437,8 @@ async def ws_manager_in_memory(monkeypatch):
try: try:
fastapi_app.dependency_overrides[auth.current_user_optional] = lambda: None fastapi_app.dependency_overrides[auth.current_user_optional] = lambda: None
# current_user_optional_if_public_mode is NOT overridden here so the real
# implementation runs and enforces the PUBLIC_MODE check during tests.
except Exception: except Exception:
pass pass
@@ -491,37 +493,39 @@ async def authenticated_client2():
@asynccontextmanager @asynccontextmanager
async def authenticated_client_ctx(): async def authenticated_client_ctx():
from reflector.app import app from reflector.app import app
from reflector.auth import current_user, current_user_optional from reflector.auth import (
current_user,
current_user_optional,
current_user_optional_if_public_mode,
)
app.dependency_overrides[current_user] = lambda: { _user = lambda: {"sub": "randomuserid", "email": "test@mail.com"}
"sub": "randomuserid", app.dependency_overrides[current_user] = _user
"email": "test@mail.com", app.dependency_overrides[current_user_optional] = _user
} app.dependency_overrides[current_user_optional_if_public_mode] = _user
app.dependency_overrides[current_user_optional] = lambda: {
"sub": "randomuserid",
"email": "test@mail.com",
}
yield yield
del app.dependency_overrides[current_user] del app.dependency_overrides[current_user]
del app.dependency_overrides[current_user_optional] del app.dependency_overrides[current_user_optional]
del app.dependency_overrides[current_user_optional_if_public_mode]
@asynccontextmanager @asynccontextmanager
async def authenticated_client2_ctx(): async def authenticated_client2_ctx():
from reflector.app import app from reflector.app import app
from reflector.auth import current_user, current_user_optional from reflector.auth import (
current_user,
current_user_optional,
current_user_optional_if_public_mode,
)
app.dependency_overrides[current_user] = lambda: { _user = lambda: {"sub": "randomuserid2", "email": "test@mail.com"}
"sub": "randomuserid2", app.dependency_overrides[current_user] = _user
"email": "test@mail.com", app.dependency_overrides[current_user_optional] = _user
} app.dependency_overrides[current_user_optional_if_public_mode] = _user
app.dependency_overrides[current_user_optional] = lambda: {
"sub": "randomuserid2",
"email": "test@mail.com",
}
yield yield
del app.dependency_overrides[current_user] del app.dependency_overrides[current_user]
del app.dependency_overrides[current_user_optional] del app.dependency_overrides[current_user_optional]
del app.dependency_overrides[current_user_optional_if_public_mode]
@pytest.fixture(scope="session") @pytest.fixture(scope="session")

View File

@@ -141,33 +141,19 @@ async def test_user_ws_accepts_valid_token_and_receives_events(appserver_ws_user
await asyncio.sleep(0.2) await asyncio.sleep(0.2)
# Emit an event to the user's room via a standard HTTP action # Emit an event to the user's room via a standard HTTP action
# Use a real HTTP request to the server with the JWT token so that
# current_user_optional_if_public_mode is exercised without dependency overrides
from httpx import AsyncClient from httpx import AsyncClient
from reflector.app import app async with AsyncClient(base_url=f"http://{host}:{port}/v1") as ac:
from reflector.auth import current_user, current_user_optional resp = await ac.post(
"/transcripts",
# Override auth dependencies so HTTP request is performed as the same user json={"name": "WS Test"},
# Use the internal user.id (not the Authentik UID) headers={"Authorization": f"Bearer {token}"},
app.dependency_overrides[current_user] = lambda: { )
"sub": user.id,
"email": "user-abc@example.com",
}
app.dependency_overrides[current_user_optional] = lambda: {
"sub": user.id,
"email": "user-abc@example.com",
}
# Use in-memory client (global singleton makes it share ws_manager)
async with AsyncClient(app=app, base_url=f"http://{host}:{port}/v1") as ac:
# Create a transcript as this user so that the server publishes TRANSCRIPT_CREATED to user room
resp = await ac.post("/transcripts", json={"name": "WS Test"})
assert resp.status_code == 200 assert resp.status_code == 200
# Receive the published event # Receive the published event
msg = await ws.receive_json() msg = await ws.receive_json()
assert msg["event"] == "TRANSCRIPT_CREATED" assert msg["event"] == "TRANSCRIPT_CREATED"
assert "id" in msg["data"] assert "id" in msg["data"]
# Clean overrides
del app.dependency_overrides[current_user]
del app.dependency_overrides[current_user_optional]