mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-03-21 22:56:47 +00:00
refactor(auth): consolidate PUBLIC_MODE and mutation guards into reusable helpers (#909)
* refactor(auth): consolidate PUBLIC_MODE and mutation guards into reusable helpers * fix: fix websocket test override
This commit is contained in:
committed by
GitHub
parent
cf6e867cf1
commit
4ae56b730a
@@ -12,6 +12,7 @@ AccessTokenInfo = auth_module.AccessTokenInfo
|
|||||||
authenticated = auth_module.authenticated
|
authenticated = auth_module.authenticated
|
||||||
current_user = auth_module.current_user
|
current_user = auth_module.current_user
|
||||||
current_user_optional = auth_module.current_user_optional
|
current_user_optional = auth_module.current_user_optional
|
||||||
|
current_user_optional_if_public_mode = auth_module.current_user_optional_if_public_mode
|
||||||
parse_ws_bearer_token = auth_module.parse_ws_bearer_token
|
parse_ws_bearer_token = auth_module.parse_ws_bearer_token
|
||||||
current_user_ws_optional = auth_module.current_user_ws_optional
|
current_user_ws_optional = auth_module.current_user_ws_optional
|
||||||
verify_raw_token = auth_module.verify_raw_token
|
verify_raw_token = auth_module.verify_raw_token
|
||||||
|
|||||||
@@ -129,6 +129,17 @@ async def current_user_optional(
|
|||||||
return await _authenticate_user(jwt_token, api_key, jwtauth)
|
return await _authenticate_user(jwt_token, api_key, jwtauth)
|
||||||
|
|
||||||
|
|
||||||
|
async def current_user_optional_if_public_mode(
|
||||||
|
jwt_token: Annotated[Optional[str], Depends(oauth2_scheme)],
|
||||||
|
api_key: Annotated[Optional[str], Depends(api_key_header)],
|
||||||
|
jwtauth: JWTAuth = Depends(),
|
||||||
|
) -> Optional[UserInfo]:
|
||||||
|
user = await _authenticate_user(jwt_token, api_key, jwtauth)
|
||||||
|
if user is None and not settings.PUBLIC_MODE:
|
||||||
|
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
def parse_ws_bearer_token(
|
def parse_ws_bearer_token(
|
||||||
websocket: "WebSocket",
|
websocket: "WebSocket",
|
||||||
) -> tuple[Optional[str], Optional[str]]:
|
) -> tuple[Optional[str], Optional[str]]:
|
||||||
|
|||||||
@@ -21,6 +21,11 @@ def current_user_optional():
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def current_user_optional_if_public_mode():
|
||||||
|
# auth_none means no authentication at all — always public
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def parse_ws_bearer_token(websocket):
|
def parse_ws_bearer_token(websocket):
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
|||||||
@@ -150,6 +150,16 @@ async def current_user_optional(
|
|||||||
return await _authenticate_user(jwt_token, api_key)
|
return await _authenticate_user(jwt_token, api_key)
|
||||||
|
|
||||||
|
|
||||||
|
async def current_user_optional_if_public_mode(
|
||||||
|
jwt_token: Annotated[Optional[str], Depends(oauth2_scheme)],
|
||||||
|
api_key: Annotated[Optional[str], Depends(api_key_header)],
|
||||||
|
) -> Optional[UserInfo]:
|
||||||
|
user = await _authenticate_user(jwt_token, api_key)
|
||||||
|
if user is None and not settings.PUBLIC_MODE:
|
||||||
|
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
# --- WebSocket auth (same pattern as auth_jwt.py) ---
|
# --- WebSocket auth (same pattern as auth_jwt.py) ---
|
||||||
def parse_ws_bearer_token(
|
def parse_ws_bearer_token(
|
||||||
websocket: "WebSocket",
|
websocket: "WebSocket",
|
||||||
|
|||||||
@@ -697,6 +697,18 @@ class TranscriptController:
|
|||||||
return False
|
return False
|
||||||
return user_id and transcript.user_id == user_id
|
return user_id and transcript.user_id == user_id
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def check_can_mutate(transcript: Transcript, user_id: str | None) -> None:
|
||||||
|
"""
|
||||||
|
Raises HTTP 403 if the user cannot mutate the transcript.
|
||||||
|
|
||||||
|
Policy:
|
||||||
|
- Anonymous transcripts (user_id is None) are editable by anyone
|
||||||
|
- Owned transcripts can only be mutated by their owner
|
||||||
|
"""
|
||||||
|
if transcript.user_id is not None and transcript.user_id != user_id:
|
||||||
|
raise HTTPException(status_code=403, detail="Not authorized")
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def transaction(self):
|
async def transaction(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ from reflector.db.meetings import (
|
|||||||
)
|
)
|
||||||
from reflector.db.rooms import rooms_controller
|
from reflector.db.rooms import rooms_controller
|
||||||
from reflector.logger import logger
|
from reflector.logger import logger
|
||||||
from reflector.settings import settings
|
|
||||||
from reflector.utils.string import NonEmptyString
|
from reflector.utils.string import NonEmptyString
|
||||||
from reflector.video_platforms.factory import create_platform_client
|
from reflector.video_platforms.factory import create_platform_client
|
||||||
|
|
||||||
@@ -92,15 +91,15 @@ class StartRecordingRequest(BaseModel):
|
|||||||
async def start_recording(
|
async def start_recording(
|
||||||
meeting_id: NonEmptyString,
|
meeting_id: NonEmptyString,
|
||||||
body: StartRecordingRequest,
|
body: StartRecordingRequest,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[
|
||||||
|
Optional[auth.UserInfo], Depends(auth.current_user_optional_if_public_mode)
|
||||||
|
],
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Start cloud or raw-tracks recording via Daily.co REST API.
|
"""Start cloud or raw-tracks recording via Daily.co REST API.
|
||||||
|
|
||||||
Both cloud and raw-tracks are started via REST API to bypass enable_recording limitation of allowing only 1 recording at a time.
|
Both cloud and raw-tracks are started via REST API to bypass enable_recording limitation of allowing only 1 recording at a time.
|
||||||
Uses different instanceIds for cloud vs raw-tracks (same won't work)
|
Uses different instanceIds for cloud vs raw-tracks (same won't work)
|
||||||
"""
|
"""
|
||||||
if not user and not settings.PUBLIC_MODE:
|
|
||||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
|
||||||
meeting = await meetings_controller.get_by_id(meeting_id)
|
meeting = await meetings_controller.get_by_id(meeting_id)
|
||||||
if not meeting:
|
if not meeting:
|
||||||
raise HTTPException(status_code=404, detail="Meeting not found")
|
raise HTTPException(status_code=404, detail="Meeting not found")
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ from reflector.db.rooms import rooms_controller
|
|||||||
from reflector.redis_cache import RedisAsyncLock
|
from reflector.redis_cache import RedisAsyncLock
|
||||||
from reflector.schemas.platform import Platform
|
from reflector.schemas.platform import Platform
|
||||||
from reflector.services.ics_sync import ics_sync_service
|
from reflector.services.ics_sync import ics_sync_service
|
||||||
from reflector.settings import settings
|
|
||||||
from reflector.utils.url import add_query_param
|
from reflector.utils.url import add_query_param
|
||||||
from reflector.video_platforms.factory import create_platform_client
|
from reflector.video_platforms.factory import create_platform_client
|
||||||
from reflector.worker.webhook import test_webhook
|
from reflector.worker.webhook import test_webhook
|
||||||
@@ -178,11 +177,10 @@ router = APIRouter()
|
|||||||
|
|
||||||
@router.get("/rooms", response_model=Page[RoomDetails])
|
@router.get("/rooms", response_model=Page[RoomDetails])
|
||||||
async def rooms_list(
|
async def rooms_list(
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[
|
||||||
|
Optional[auth.UserInfo], Depends(auth.current_user_optional_if_public_mode)
|
||||||
|
],
|
||||||
) -> list[RoomDetails]:
|
) -> list[RoomDetails]:
|
||||||
if not user and not settings.PUBLIC_MODE:
|
|
||||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
|
||||||
|
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
|
|
||||||
paginated = await apaginate(
|
paginated = await apaginate(
|
||||||
|
|||||||
@@ -263,16 +263,15 @@ class SearchResponse(BaseModel):
|
|||||||
|
|
||||||
@router.get("/transcripts", response_model=Page[GetTranscriptMinimal])
|
@router.get("/transcripts", response_model=Page[GetTranscriptMinimal])
|
||||||
async def transcripts_list(
|
async def transcripts_list(
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[
|
||||||
|
Optional[auth.UserInfo], Depends(auth.current_user_optional_if_public_mode)
|
||||||
|
],
|
||||||
source_kind: SourceKind | None = None,
|
source_kind: SourceKind | None = None,
|
||||||
room_id: str | None = None,
|
room_id: str | None = None,
|
||||||
search_term: str | None = None,
|
search_term: str | None = None,
|
||||||
change_seq_from: int | None = None,
|
change_seq_from: int | None = None,
|
||||||
sort_by: Literal["created_at", "change_seq"] | None = None,
|
sort_by: Literal["created_at", "change_seq"] | None = None,
|
||||||
):
|
):
|
||||||
if not user and not settings.PUBLIC_MODE:
|
|
||||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
|
||||||
|
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
|
|
||||||
# Default behavior preserved: sort_by=None → "-created_at"
|
# Default behavior preserved: sort_by=None → "-created_at"
|
||||||
@@ -307,13 +306,10 @@ async def transcripts_search(
|
|||||||
from_datetime: SearchFromDatetimeParam = None,
|
from_datetime: SearchFromDatetimeParam = None,
|
||||||
to_datetime: SearchToDatetimeParam = None,
|
to_datetime: SearchToDatetimeParam = None,
|
||||||
user: Annotated[
|
user: Annotated[
|
||||||
Optional[auth.UserInfo], Depends(auth.current_user_optional)
|
Optional[auth.UserInfo], Depends(auth.current_user_optional_if_public_mode)
|
||||||
] = None,
|
] = None,
|
||||||
):
|
):
|
||||||
"""Full-text search across transcript titles and content."""
|
"""Full-text search across transcript titles and content."""
|
||||||
if not user and not settings.PUBLIC_MODE:
|
|
||||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
|
||||||
|
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
|
|
||||||
if from_datetime and to_datetime and from_datetime > to_datetime:
|
if from_datetime and to_datetime and from_datetime > to_datetime:
|
||||||
@@ -346,11 +342,10 @@ async def transcripts_search(
|
|||||||
@router.post("/transcripts", response_model=GetTranscriptWithParticipants)
|
@router.post("/transcripts", response_model=GetTranscriptWithParticipants)
|
||||||
async def transcripts_create(
|
async def transcripts_create(
|
||||||
info: CreateTranscript,
|
info: CreateTranscript,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[
|
||||||
|
Optional[auth.UserInfo], Depends(auth.current_user_optional_if_public_mode)
|
||||||
|
],
|
||||||
):
|
):
|
||||||
if not user and not settings.PUBLIC_MODE:
|
|
||||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
|
||||||
|
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
transcript = await transcripts_controller.add(
|
transcript = await transcripts_controller.add(
|
||||||
info.name,
|
info.name,
|
||||||
|
|||||||
@@ -62,8 +62,7 @@ async def transcript_add_participant(
|
|||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
transcript_id, user_id=user_id
|
transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
if transcript.user_id is not None and transcript.user_id != user_id:
|
transcripts_controller.check_can_mutate(transcript, user_id)
|
||||||
raise HTTPException(status_code=403, detail="Not authorized")
|
|
||||||
|
|
||||||
# ensure the speaker is unique
|
# ensure the speaker is unique
|
||||||
if participant.speaker is not None and transcript.participants is not None:
|
if participant.speaker is not None and transcript.participants is not None:
|
||||||
@@ -109,8 +108,7 @@ async def transcript_update_participant(
|
|||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
transcript_id, user_id=user_id
|
transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
if transcript.user_id is not None and transcript.user_id != user_id:
|
transcripts_controller.check_can_mutate(transcript, user_id)
|
||||||
raise HTTPException(status_code=403, detail="Not authorized")
|
|
||||||
|
|
||||||
# ensure the speaker is unique
|
# ensure the speaker is unique
|
||||||
for p in transcript.participants:
|
for p in transcript.participants:
|
||||||
@@ -148,7 +146,6 @@ async def transcript_delete_participant(
|
|||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
transcript_id, user_id=user_id
|
transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
if transcript.user_id is not None and transcript.user_id != user_id:
|
transcripts_controller.check_can_mutate(transcript, user_id)
|
||||||
raise HTTPException(status_code=403, detail="Not authorized")
|
|
||||||
await transcripts_controller.delete_participant(transcript, participant_id)
|
await transcripts_controller.delete_participant(transcript, participant_id)
|
||||||
return DeletionStatus(status="ok")
|
return DeletionStatus(status="ok")
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ from reflector.services.transcript_process import (
|
|||||||
prepare_transcript_processing,
|
prepare_transcript_processing,
|
||||||
validate_transcript_for_processing,
|
validate_transcript_for_processing,
|
||||||
)
|
)
|
||||||
from reflector.settings import settings
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@@ -27,11 +26,10 @@ class ProcessStatus(BaseModel):
|
|||||||
@router.post("/transcripts/{transcript_id}/process")
|
@router.post("/transcripts/{transcript_id}/process")
|
||||||
async def transcript_process(
|
async def transcript_process(
|
||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[
|
||||||
|
Optional[auth.UserInfo], Depends(auth.current_user_optional_if_public_mode)
|
||||||
|
],
|
||||||
) -> ProcessStatus:
|
) -> ProcessStatus:
|
||||||
if not user and not settings.PUBLIC_MODE:
|
|
||||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
|
||||||
|
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
transcript_id, user_id=user_id
|
transcript_id, user_id=user_id
|
||||||
|
|||||||
@@ -41,8 +41,7 @@ async def transcript_assign_speaker(
|
|||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
transcript_id, user_id=user_id
|
transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
if transcript.user_id is not None and transcript.user_id != user_id:
|
transcripts_controller.check_can_mutate(transcript, user_id)
|
||||||
raise HTTPException(status_code=403, detail="Not authorized")
|
|
||||||
|
|
||||||
if not transcript:
|
if not transcript:
|
||||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||||
@@ -121,8 +120,7 @@ async def transcript_merge_speaker(
|
|||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
transcript_id, user_id=user_id
|
transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
if transcript.user_id is not None and transcript.user_id != user_id:
|
transcripts_controller.check_can_mutate(transcript, user_id)
|
||||||
raise HTTPException(status_code=403, detail="Not authorized")
|
|
||||||
|
|
||||||
if not transcript:
|
if not transcript:
|
||||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from pydantic import BaseModel
|
|||||||
import reflector.auth as auth
|
import reflector.auth as auth
|
||||||
from reflector.db.transcripts import SourceKind, transcripts_controller
|
from reflector.db.transcripts import SourceKind, transcripts_controller
|
||||||
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
|
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
|
||||||
from reflector.settings import settings
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@@ -22,11 +21,10 @@ async def transcript_record_upload(
|
|||||||
chunk_number: int,
|
chunk_number: int,
|
||||||
total_chunks: int,
|
total_chunks: int,
|
||||||
chunk: UploadFile,
|
chunk: UploadFile,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[
|
||||||
|
Optional[auth.UserInfo], Depends(auth.current_user_optional_if_public_mode)
|
||||||
|
],
|
||||||
):
|
):
|
||||||
if not user and not settings.PUBLIC_MODE:
|
|
||||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
|
||||||
|
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
transcript_id, user_id=user_id
|
transcript_id, user_id=user_id
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ from fastapi import APIRouter, Depends, HTTPException, Request
|
|||||||
|
|
||||||
import reflector.auth as auth
|
import reflector.auth as auth
|
||||||
from reflector.db.transcripts import transcripts_controller
|
from reflector.db.transcripts import transcripts_controller
|
||||||
from reflector.settings import settings
|
|
||||||
|
|
||||||
from .rtc_offer import RtcOffer, rtc_offer_base
|
from .rtc_offer import RtcOffer, rtc_offer_base
|
||||||
|
|
||||||
@@ -16,11 +15,10 @@ async def transcript_record_webrtc(
|
|||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
params: RtcOffer,
|
params: RtcOffer,
|
||||||
request: Request,
|
request: Request,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[
|
||||||
|
Optional[auth.UserInfo], Depends(auth.current_user_optional_if_public_mode)
|
||||||
|
],
|
||||||
):
|
):
|
||||||
if not user and not settings.PUBLIC_MODE:
|
|
||||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
|
||||||
|
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
transcript_id, user_id=user_id
|
transcript_id, user_id=user_id
|
||||||
|
|||||||
@@ -437,6 +437,8 @@ async def ws_manager_in_memory(monkeypatch):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
fastapi_app.dependency_overrides[auth.current_user_optional] = lambda: None
|
fastapi_app.dependency_overrides[auth.current_user_optional] = lambda: None
|
||||||
|
# current_user_optional_if_public_mode is NOT overridden here so the real
|
||||||
|
# implementation runs and enforces the PUBLIC_MODE check during tests.
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -491,37 +493,39 @@ async def authenticated_client2():
|
|||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def authenticated_client_ctx():
|
async def authenticated_client_ctx():
|
||||||
from reflector.app import app
|
from reflector.app import app
|
||||||
from reflector.auth import current_user, current_user_optional
|
from reflector.auth import (
|
||||||
|
current_user,
|
||||||
|
current_user_optional,
|
||||||
|
current_user_optional_if_public_mode,
|
||||||
|
)
|
||||||
|
|
||||||
app.dependency_overrides[current_user] = lambda: {
|
_user = lambda: {"sub": "randomuserid", "email": "test@mail.com"}
|
||||||
"sub": "randomuserid",
|
app.dependency_overrides[current_user] = _user
|
||||||
"email": "test@mail.com",
|
app.dependency_overrides[current_user_optional] = _user
|
||||||
}
|
app.dependency_overrides[current_user_optional_if_public_mode] = _user
|
||||||
app.dependency_overrides[current_user_optional] = lambda: {
|
|
||||||
"sub": "randomuserid",
|
|
||||||
"email": "test@mail.com",
|
|
||||||
}
|
|
||||||
yield
|
yield
|
||||||
del app.dependency_overrides[current_user]
|
del app.dependency_overrides[current_user]
|
||||||
del app.dependency_overrides[current_user_optional]
|
del app.dependency_overrides[current_user_optional]
|
||||||
|
del app.dependency_overrides[current_user_optional_if_public_mode]
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def authenticated_client2_ctx():
|
async def authenticated_client2_ctx():
|
||||||
from reflector.app import app
|
from reflector.app import app
|
||||||
from reflector.auth import current_user, current_user_optional
|
from reflector.auth import (
|
||||||
|
current_user,
|
||||||
|
current_user_optional,
|
||||||
|
current_user_optional_if_public_mode,
|
||||||
|
)
|
||||||
|
|
||||||
app.dependency_overrides[current_user] = lambda: {
|
_user = lambda: {"sub": "randomuserid2", "email": "test@mail.com"}
|
||||||
"sub": "randomuserid2",
|
app.dependency_overrides[current_user] = _user
|
||||||
"email": "test@mail.com",
|
app.dependency_overrides[current_user_optional] = _user
|
||||||
}
|
app.dependency_overrides[current_user_optional_if_public_mode] = _user
|
||||||
app.dependency_overrides[current_user_optional] = lambda: {
|
|
||||||
"sub": "randomuserid2",
|
|
||||||
"email": "test@mail.com",
|
|
||||||
}
|
|
||||||
yield
|
yield
|
||||||
del app.dependency_overrides[current_user]
|
del app.dependency_overrides[current_user]
|
||||||
del app.dependency_overrides[current_user_optional]
|
del app.dependency_overrides[current_user_optional]
|
||||||
|
del app.dependency_overrides[current_user_optional_if_public_mode]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
|
|||||||
@@ -141,33 +141,19 @@ async def test_user_ws_accepts_valid_token_and_receives_events(appserver_ws_user
|
|||||||
await asyncio.sleep(0.2)
|
await asyncio.sleep(0.2)
|
||||||
|
|
||||||
# Emit an event to the user's room via a standard HTTP action
|
# Emit an event to the user's room via a standard HTTP action
|
||||||
|
# Use a real HTTP request to the server with the JWT token so that
|
||||||
|
# current_user_optional_if_public_mode is exercised without dependency overrides
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
|
|
||||||
from reflector.app import app
|
async with AsyncClient(base_url=f"http://{host}:{port}/v1") as ac:
|
||||||
from reflector.auth import current_user, current_user_optional
|
resp = await ac.post(
|
||||||
|
"/transcripts",
|
||||||
# Override auth dependencies so HTTP request is performed as the same user
|
json={"name": "WS Test"},
|
||||||
# Use the internal user.id (not the Authentik UID)
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
app.dependency_overrides[current_user] = lambda: {
|
)
|
||||||
"sub": user.id,
|
|
||||||
"email": "user-abc@example.com",
|
|
||||||
}
|
|
||||||
app.dependency_overrides[current_user_optional] = lambda: {
|
|
||||||
"sub": user.id,
|
|
||||||
"email": "user-abc@example.com",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Use in-memory client (global singleton makes it share ws_manager)
|
|
||||||
async with AsyncClient(app=app, base_url=f"http://{host}:{port}/v1") as ac:
|
|
||||||
# Create a transcript as this user so that the server publishes TRANSCRIPT_CREATED to user room
|
|
||||||
resp = await ac.post("/transcripts", json={"name": "WS Test"})
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|
||||||
# Receive the published event
|
# Receive the published event
|
||||||
msg = await ws.receive_json()
|
msg = await ws.receive_json()
|
||||||
assert msg["event"] == "TRANSCRIPT_CREATED"
|
assert msg["event"] == "TRANSCRIPT_CREATED"
|
||||||
assert "id" in msg["data"]
|
assert "id" in msg["data"]
|
||||||
|
|
||||||
# Clean overrides
|
|
||||||
del app.dependency_overrides[current_user]
|
|
||||||
del app.dependency_overrides[current_user_optional]
|
|
||||||
|
|||||||
Reference in New Issue
Block a user