diff --git a/.env.example b/.env.example index e972e7b4..8bf1758a 100644 --- a/.env.example +++ b/.env.example @@ -6,6 +6,8 @@ AUTH0_AUDIENCE=https://audience.com/api # JWT secret key: used to provide some protection around registration # Generate with: python -c "import secrets; print(secrets.token_urlsafe(32))" JWT_SECRET_KEY=secret-key +# Note the list syntax pydantic-settings uses +ADMIN_ROLES='["Admin", "GalaxyAdmin"]' # Comma-separated list of allowed origins. Note we # don't process this with pydantic-settings as it needs # to be used before the FastAPI app loads diff --git a/auth/config.py b/auth/config.py index f9682bca..26f6e50b 100644 --- a/auth/config.py +++ b/auth/config.py @@ -11,6 +11,7 @@ class Settings(BaseSettings): auth0_audience: str jwt_secret_key: str auth0_algorithms: list[str] = ["RS256"] + admin_roles: list[str] = [] # Note we process this separately in app startup as it needs # to be available before the app starts cors_allowed_origins: str diff --git a/auth/validator.py b/auth/validator.py index 46ebacff..e8a6bac6 100644 --- a/auth/validator.py +++ b/auth/validator.py @@ -1,10 +1,12 @@ +from typing import Annotated + import httpx -from fastapi import Depends, HTTPException +from fastapi import Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer from jose import jwk, jwt from jose.exceptions import JWTError -from auth.config import Settings +from auth.config import Settings, get_settings from schemas.tokens import AccessTokenPayload from schemas.user import User @@ -33,20 +35,12 @@ def verify_jwt(token: str, settings: Settings) -> AccessTokenPayload: except JWTError as e: raise HTTPException(status_code=401, detail=f"Invalid token: {e}") - roles_claim = "biocommons.org.au/roles" + roles_claim = "https://biocommons.org.au/roles" if roles_claim not in payload: raise HTTPException( status_code=403, detail=f"Missing required claim: {roles_claim}" ) - roles = payload[roles_claim] - if not isinstance(roles, list) or not any( - "admin" in role.lower() for role in roles - ): - raise HTTPException( - status_code=403, detail="Access denied: Insufficient permissions" - ) - return AccessTokenPayload(**payload) @@ -63,6 +57,17 @@ def get_rsa_key(token: str, settings: Settings) -> jwk.RSAKey | None: # type: i return None -def get_current_user(token: str = Depends(oauth2_scheme)) -> User: - access_token = verify_jwt(token) +def get_current_user(token: str = Depends(oauth2_scheme), + settings: Settings = Depends(get_settings)) -> User: + access_token = verify_jwt(token, settings=settings) return User(access_token=access_token) + + +def user_is_admin(current_user: Annotated[User, Depends(get_current_user)], + settings: Annotated[Settings, Depends(get_settings)]) -> User: + if not current_user.is_admin(settings=settings): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You must be an admin to access this endpoint." + ) + return current_user diff --git a/auth0/__init__.py b/auth0/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/auth0/client.py b/auth0/client.py new file mode 100644 index 00000000..902244b6 --- /dev/null +++ b/auth0/client.py @@ -0,0 +1,17 @@ +__all__ = ["Auth0Client"] + +import httpx + +from auth0.schemas import Auth0UserResponse + + +class Auth0Client: + + def __init__(self, domain: str): + self.domain = domain + + def get_users(self, access_token: str) -> list[Auth0UserResponse]: + url = f"https://{self.domain}/api/v2/users" + headers = {"Authorization": f"Bearer {access_token}"} + resp = httpx.get(url, headers=headers) + return resp.json() diff --git a/auth0/schemas.py b/auth0/schemas.py new file mode 100644 index 00000000..a0313c79 --- /dev/null +++ b/auth0/schemas.py @@ -0,0 +1,34 @@ +from datetime import datetime +from typing import List, Optional + +from pydantic import BaseModel, ConfigDict, EmailStr + + +class Auth0UserResponse(BaseModel): + """ + Response returned by Auth0's /users endpoint. + Note we have our own Auth0User model + that includes specifying the metadata fields we use. + """ + user_id: str + email: EmailStr + email_verified: bool + username: Optional[str] = None + phone_number: Optional[str] = None + phone_verified: Optional[bool] = None + created_at: datetime + updated_at: datetime + identities: List[dict] + app_metadata: Optional[dict] = None + user_metadata: Optional[dict] = None + picture: Optional[str] = None + name: Optional[str] = None + nickname: Optional[str] = None + last_ip: Optional[str] = None + last_login: Optional[datetime] = None + logins_count: Optional[int] = None + blocked: Optional[bool] = None + given_name: Optional[str] = None + family_name: Optional[str] = None + + model_config = ConfigDict(extra="allow") diff --git a/main.py b/main.py index a9d29a0c..6156d1e6 100644 --- a/main.py +++ b/main.py @@ -1,18 +1,17 @@ -import os -from dotenv import load_dotenv +from dotenv import dotenv_values from fastapi import FastAPI from starlette.middleware.cors import CORSMiddleware -from routers import bpa_register, galaxy_register, user +from routers import admin, bpa_register, galaxy_register, user # Load .env to get CORS_ALLOWED_ORIGINS. # Note that for most env variables, we use pydantic-settings # and load them via auth.config. But we need the # allowed_origins before we load the app -load_dotenv() +env_values = dotenv_values(".env") ALLOWED_ORIGINS = [ - origin.strip() for origin in os.getenv("CORS_ALLOWED_ORIGINS", "").split(",") + origin.strip() for origin in env_values.get("CORS_ALLOWED_ORIGINS", "").split(",") ] app = FastAPI() @@ -30,6 +29,7 @@ def public_route(): return {"message": "AAI Backend API"} +app.include_router(admin.router) app.include_router(user.router) app.include_router(bpa_register.router) app.include_router(galaxy_register.router) diff --git a/routers/admin.py b/routers/admin.py new file mode 100644 index 00000000..5381e659 --- /dev/null +++ b/routers/admin.py @@ -0,0 +1,23 @@ +from fastapi import APIRouter, Depends + +from auth.config import Settings, get_settings +from auth.management import get_management_token +from auth.validator import user_is_admin +from auth0.client import Auth0Client +from auth0.schemas import Auth0UserResponse + +router = APIRouter(prefix="/admin", tags=["admin"], + dependencies=[Depends(user_is_admin)]) + + +def get_auth0_client(settings: Settings = Depends(get_settings)): + return Auth0Client(settings.auth0_domain) + + +@router.get("/users", + response_model=list[Auth0UserResponse]) +def get_users(settings: Settings = Depends(get_settings), + client: Auth0Client = Depends(get_auth0_client)): + token = get_management_token(settings=settings) + resp = client.get_users(token) + return resp diff --git a/schemas/tokens.py b/schemas/tokens.py index a605580f..35cee689 100644 --- a/schemas/tokens.py +++ b/schemas/tokens.py @@ -1,6 +1,6 @@ from typing import Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field class AccessTokenPayload(BaseModel): @@ -9,7 +9,7 @@ class AccessTokenPayload(BaseModel): """ biocommons_roles: list[str] = Field( - alias="biocommons.org.au/roles", + alias="https://biocommons.org.au/roles", description="BioCommons-specific roles assigned to the user", ) email: Optional[str] = Field(None, description="Email address") @@ -20,3 +20,6 @@ class AccessTokenPayload(BaseModel): iat: int = Field(description="Issued at time (as Unix timestamp)") azp: Optional[str] = Field(None, description="Authorized party") permissions: list[str] = Field(description="Permissions granted to the user") + + # Set populate_by_name so we can specify biocommons_roles as an argument + model_config = ConfigDict(populate_by_name=True) diff --git a/schemas/user.py b/schemas/user.py index ec6434fe..68f24e3e 100644 --- a/schemas/user.py +++ b/schemas/user.py @@ -1,5 +1,7 @@ from pydantic import BaseModel +from auth.config import Settings + from .tokens import AccessTokenPayload @@ -12,13 +14,13 @@ class User(BaseModel): access_token: AccessTokenPayload - def is_admin(self) -> bool: + def is_admin(self, settings: Settings) -> bool: """ Checks if the user has an admin role. """ # TODO: Need to finalize exactly what roles make # a user an admin for role in self.access_token.biocommons_roles: - if "admin" in role.lower(): + if role in settings.admin_roles: return True return False diff --git a/tests/conftest.py b/tests/conftest.py index a4689e9b..7d589985 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,20 @@ from fastapi.testclient import TestClient from auth.config import Settings, get_settings +from auth.validator import get_current_user from main import app +from tests.datagen import AccessTokenPayloadFactory, UserFactory + + +@pytest.fixture(autouse=True) +def ignore_env_file(): + """ + Always ignore the .env file when running tests, + so we get the same behaviour when the .env file is present or not. + """ + def get_settings_no_env_file(): + return Settings(_env_file=None) + app.dependency_overrides[get_settings] = get_settings_no_env_file @pytest.fixture @@ -15,12 +28,15 @@ def mock_settings(): auth0_audience="mock-audience", jwt_secret_key="mock-secret-key", cors_allowed_origins="https://test", + admin_roles=["Admin"], auth0_algorithms=["HS256"] ) - @pytest.fixture -def client_with_settings_override(mock_settings): +def test_client(mock_settings): + """ + Override the get_settings dependency to return a mocked Settings object. + """ # Define override def override_settings(): return mock_settings @@ -34,3 +50,18 @@ def override_settings(): # Reset override app.dependency_overrides.clear() + + +@pytest.fixture +def as_admin_user(): + """ + Override the get_current_user dependency to return a User object with admin role, + so admin check will pass. + """ + def override_user(): + token = AccessTokenPayloadFactory.build(biocommons_roles=["Admin"]) + return UserFactory.build(access_token=token) + + app.dependency_overrides[get_current_user] = override_user + yield + app.dependency_overrides.clear() diff --git a/tests/datagen.py b/tests/datagen.py index fd283ff2..87e965a7 100644 --- a/tests/datagen.py +++ b/tests/datagen.py @@ -1,6 +1,7 @@ from polyfactory.decorators import post_generated from polyfactory.factories.pydantic_factory import ModelFactory +from auth0.schemas import Auth0UserResponse from routers.bpa_register import BPARegistrationRequest from schemas.galaxy import GalaxyRegistrationData from schemas.service import Auth0User @@ -11,6 +12,9 @@ class AccessTokenPayloadFactory(ModelFactory[AccessTokenPayload]): ... +class Auth0UserResponseFactory(ModelFactory[Auth0UserResponse]): ... + + class UserFactory(ModelFactory[User]): ... diff --git a/tests/test_admin.py b/tests/test_admin.py new file mode 100644 index 00000000..90041b7e --- /dev/null +++ b/tests/test_admin.py @@ -0,0 +1,46 @@ +import pytest +from fastapi import HTTPException + +from auth.validator import get_current_user, user_is_admin +from main import app +from tests.datagen import ( + AccessTokenPayloadFactory, + Auth0UserResponseFactory, + UserFactory, +) + + +def test_get_users_requires_admin_unauthorized(test_client, mocker): + def get_nonadmin_user(): + payload = AccessTokenPayloadFactory.build(biocommons_roles=["User"]) + return UserFactory.build(access_token=payload) + + app.dependency_overrides[get_current_user] = get_nonadmin_user + mocker.patch("routers.admin.get_management_token", return_value="mock_token") + resp = test_client.get("/admin/users") + assert resp.status_code == 403 + assert resp.json() == {"detail": "You must be an admin to access this endpoint."} + app.dependency_overrides.clear() + + +def test_user_is_admin(mock_settings): + payload = AccessTokenPayloadFactory.build(biocommons_roles=["Admin"]) + admin_user = UserFactory.build(access_token=payload) + assert user_is_admin(current_user=admin_user, settings=mock_settings) + + +def test_user_is_admin_nonadmin_user(mock_settings): + payload = AccessTokenPayloadFactory.build(biocommons_roles=["User"]) + user = UserFactory.build(access_token=payload) + with pytest.raises(HTTPException, match="You must be an admin to access this endpoint."): + user_is_admin(current_user=user, settings=mock_settings) + + +def test_get_users(mocker, test_client, as_admin_user): + mocker.patch("routers.admin.get_management_token", return_value="mock_token") + mock_client = mocker.patch("routers.admin.Auth0Client") + users = Auth0UserResponseFactory.batch(3) + mock_client().get_users.return_value = users + resp = test_client.get("/admin/users") + assert resp.status_code == 200 + assert len(resp.json()) == 3 diff --git a/tests/test_bpa_register.py b/tests/test_bpa_register.py index a494cf5a..08c15ad1 100644 --- a/tests/test_bpa_register.py +++ b/tests/test_bpa_register.py @@ -31,7 +31,7 @@ def mock_auth_token(mocker): def test_successful_registration( - client_with_settings_override, mock_auth_token, mocker, valid_registration_data + test_client, mock_auth_token, mocker, valid_registration_data ): """Test successful user registration with BPA service""" mock_response = MagicMock() @@ -40,7 +40,7 @@ def test_successful_registration( mock_post = mocker.patch("httpx.AsyncClient.post", return_value=mock_response) - response = client_with_settings_override.post( + response = test_client.post( "/bpa/register", json=valid_registration_data ) @@ -66,14 +66,14 @@ def test_successful_registration( def test_registration_duplicate_user( - client_with_settings_override, mock_auth_token, mocker, valid_registration_data + test_client, mock_auth_token, mocker, valid_registration_data ): """Test registration with duplicate user""" mock_response = MagicMock() mock_response.status_code = 409 mocker.patch("httpx.AsyncClient.post", return_value=mock_response) - response = client_with_settings_override.post( + response = test_client.post( "/bpa/register", json=valid_registration_data ) @@ -82,7 +82,7 @@ def test_registration_duplicate_user( def test_registration_auth0_error( - client_with_settings_override, mock_auth_token, mocker, valid_registration_data + test_client, mock_auth_token, mocker, valid_registration_data ): """Test registration with Auth0 API error""" mock_response = MagicMock() @@ -90,7 +90,7 @@ def test_registration_auth0_error( mock_response.text = "Invalid request" mocker.patch("httpx.AsyncClient.post", return_value=mock_response) - response = client_with_settings_override.post( + response = test_client.post( "/bpa/register", json=valid_registration_data ) @@ -99,19 +99,19 @@ def test_registration_auth0_error( def test_registration_with_invalid_organization( - client_with_settings_override, mock_auth_token, mocker, valid_registration_data + test_client, mock_auth_token, mocker, valid_registration_data ): """Test registration with invalid organization ID""" data = valid_registration_data.copy() data["organizations"] = {"invalid-org-id": True} - response = client_with_settings_override.post("/bpa/register", json=data) + response = test_client.post("/bpa/register", json=data) assert response.status_code == 400 assert "Invalid organization ID" in response.json()["detail"] -def test_registration_request_validation(client_with_settings_override): +def test_registration_request_validation(test_client): """Test request validation""" invalid_data = { "username": "testuser", @@ -119,13 +119,13 @@ def test_registration_request_validation(client_with_settings_override): "organizations": {}, } - response = client_with_settings_override.post("/bpa/register", json=invalid_data) + response = test_client.post("/bpa/register", json=invalid_data) assert response.status_code == 422 def test_no_selected_organizations( - client_with_settings_override, mock_auth_token, mocker, valid_registration_data + test_client, mock_auth_token, mocker, valid_registration_data ): """Test registration with no organizations selected""" data = valid_registration_data.copy() @@ -141,7 +141,7 @@ def test_no_selected_organizations( mock_post = mocker.patch("httpx.AsyncClient.post", return_value=mock_response) - response = client_with_settings_override.post("/bpa/register", json=data) + response = test_client.post("/bpa/register", json=data) assert response.status_code == 200 called_data = mock_post.call_args[1]["json"] @@ -150,7 +150,7 @@ def test_no_selected_organizations( def test_empty_organizations_dict( - client_with_settings_override, mock_auth_token, mocker, valid_registration_data + test_client, mock_auth_token, mocker, valid_registration_data ): """Test registration with empty organizations dictionary""" data = valid_registration_data.copy() @@ -162,7 +162,7 @@ def test_empty_organizations_dict( mock_post = mocker.patch("httpx.AsyncClient.post", return_value=mock_response) - response = client_with_settings_override.post("/bpa/register", json=data) + response = test_client.post("/bpa/register", json=data) assert response.status_code == 200 called_data = mock_post.call_args[1]["json"] @@ -171,20 +171,20 @@ def test_empty_organizations_dict( def test_registration_email_format( - client_with_settings_override, valid_registration_data + test_client, valid_registration_data ): """Test email format validation""" data = valid_registration_data.copy() data["email"] = "invalid-email" - response = client_with_settings_override.post("/bpa/register", json=data) + response = test_client.post("/bpa/register", json=data) assert response.status_code == 422 assert "email" in response.json()["detail"][0]["loc"] def test_all_organizations_selected( - client_with_settings_override, + test_client, mock_auth_token, mock_settings, mocker, @@ -200,7 +200,7 @@ def test_all_organizations_selected( mock_post = mocker.patch("httpx.AsyncClient.post", return_value=mock_response) - response = client_with_settings_override.post("/bpa/register", json=data) + response = test_client.post("/bpa/register", json=data) assert response.status_code == 200 called_data = mock_post.call_args[1]["json"] diff --git a/tests/test_galaxy.py b/tests/test_galaxy.py index d4755a79..11ce3898 100644 --- a/tests/test_galaxy.py +++ b/tests/test_galaxy.py @@ -33,11 +33,11 @@ def test_galaxy_registration_data_password_match(): public_name="valid_username") -def test_get_registration_token(client_with_settings_override, mock_settings): +def test_get_registration_token(test_client, mock_settings): """ Test get-registration-token endpoint returns a valid JWT token. """ - response = client_with_settings_override.get("/galaxy/get-registration-token") + response = test_client.get("/galaxy/get-registration-token") assert response.status_code == 200 jwt.decode(response.json()["token"], mock_settings.jwt_secret_key, algorithms=mock_settings.auth0_algorithms) @@ -78,7 +78,7 @@ def test_to_auth0_create_user_data_valid(): assert auth0_data.user_metadata.galaxy_username == "valid_username" -def test_register(mocker, mock_auth_token, mock_settings, client_with_settings_override): +def test_register(mocker, mock_auth_token, mock_settings, test_client): """ Try to test our register endpoint. Since we don't want to call an actual Auth0 API, test that: @@ -92,9 +92,9 @@ def test_register(mocker, mock_auth_token, mock_settings, client_with_settings_o mock_resp.status_code = 201 mock_post = mocker.patch("httpx.post", return_value=mock_resp) user_data = GalaxyRegistrationDataFactory.build() - token_resp = client_with_settings_override.get("/galaxy/get-registration-token") + token_resp = test_client.get("/galaxy/get-registration-token") headers = {"registration-token": token_resp.json()["token"]} - resp = client_with_settings_override.post("/galaxy/register", json=user_data.model_dump(), headers=headers) + resp = test_client.post("/galaxy/register", json=user_data.model_dump(), headers=headers) assert resp.status_code == 200 assert resp.json()["message"] == "User registered successfully" assert resp.json()["user"] == {"user_id": "abc123"} @@ -108,8 +108,8 @@ def test_register(mocker, mock_auth_token, mock_settings, client_with_settings_o ) -def test_register_requires_token(client_with_settings_override): +def test_register_requires_token(test_client): user_data = GalaxyRegistrationDataFactory.build() - resp = client_with_settings_override.post("/galaxy/register", json=user_data.model_dump()) + resp = test_client.post("/galaxy/register", json=user_data.model_dump()) assert resp.status_code == 400 assert resp.json()["detail"] == "Missing registration token" diff --git a/tests/test_main.py b/tests/test_main.py index edf35b97..2ede746c 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -5,7 +5,7 @@ client = TestClient(app) -def test_root(client_with_settings_override): - response = client_with_settings_override.get("/") +def test_root(test_client): + response = test_client.get("/") assert response.status_code == 200 assert response.json() == {"message": "AAI Backend API"} diff --git a/tests/test_user.py b/tests/test_user.py index dbed25dd..8cc22be6 100644 --- a/tests/test_user.py +++ b/tests/test_user.py @@ -2,16 +2,10 @@ import pytest from fastapi import HTTPException -from fastapi.testclient import TestClient -from main import app from schemas.service import AppMetadata, Group, Resource, Service -from schemas.tokens import AccessTokenPayload -from schemas.user import User from tests.datagen import AccessTokenPayloadFactory, Auth0UserFactory -client = TestClient(app) - # --- Test Fixtures --- @pytest.fixture @@ -78,16 +72,17 @@ def mock_user_data(): "/me/all/pending", ], ) -def test_endpoints_require_auth(endpoint): +def test_endpoints_require_auth(endpoint, test_client): """Test that all endpoints require authentication""" - response = client.get(endpoint) + response = test_client.get(endpoint) assert response.status_code == 401 assert response.json() == {"detail": "Not authenticated"} # --- Service Endpoints (GET) --- def test_get_all_services( - mock_auth_token, auth_headers, mock_user_data, mocker + mock_auth_token, auth_headers, mock_user_data, mocker, + test_client ): """Test getting all services""" mocker.patch( @@ -98,7 +93,7 @@ def test_get_all_services( "routers.user.get_management_token", return_value="mock_management_token" ) - response = client.get("/me/services", headers=auth_headers) + response = test_client.get("/me/services", headers=auth_headers) assert response.status_code == 200 expected_services = [ @@ -108,7 +103,8 @@ def test_get_all_services( def test_get_approved_services( - mock_auth_token, auth_headers, mock_user_data, mocker + mock_auth_token, auth_headers, mock_user_data, mocker, + test_client ): """Test getting approved services""" mocker.patch( @@ -119,7 +115,7 @@ def test_get_approved_services( "routers.user.get_management_token", return_value="mock_management_token" ) - response = client.get("/me/services/approved", headers=auth_headers) + response = test_client.get("/me/services/approved", headers=auth_headers) assert response.status_code == 200 approved_services = [ @@ -131,7 +127,8 @@ def test_get_approved_services( def test_get_pending_services( - mock_auth_token, auth_headers, mock_user_data, mocker + mock_auth_token, auth_headers, mock_user_data, mocker, + test_client ): """Test getting pending services""" mocker.patch( @@ -142,7 +139,7 @@ def test_get_pending_services( "routers.user.get_management_token", return_value="mock_management_token" ) - response = client.get("/me/services/pending", headers=auth_headers) + response = test_client.get("/me/services/pending", headers=auth_headers) assert response.status_code == 200 pending_services = [ @@ -154,7 +151,8 @@ def test_get_pending_services( def test_get_services_failed_fetch( - mock_auth_token, auth_headers, mocker + mock_auth_token, auth_headers, mocker, + test_client ): """Test handling of failed API calls""" mocker.patch( @@ -165,13 +163,13 @@ def test_get_services_failed_fetch( "routers.user.get_management_token", return_value="mock_management_token" ) - response = client.get("/me/services", headers=auth_headers) + response = test_client.get("/me/services", headers=auth_headers) assert response.status_code == 403 assert response.json() == {"detail": "Failed to fetch user data"} def test_get_services_empty_metadata( - mock_auth_token, auth_headers, mocker + mock_auth_token, auth_headers, mocker, test_client ): """Test handling of empty metadata""" empty_user = Auth0UserFactory.build( @@ -182,13 +180,14 @@ def test_get_services_empty_metadata( "routers.user.get_management_token", return_value="mock_management_token" ) - response = client.get("/me/services", headers=auth_headers) + response = test_client.get("/me/services", headers=auth_headers) assert response.status_code == 200 assert response.json() == {"services": []} def test_get_services_no_metadata( - mock_auth_token, auth_headers, mocker + mock_auth_token, auth_headers, mocker, + test_client ): """Test handling of missing metadata""" no_metadata_user = Auth0UserFactory.build(app_metadata=AppMetadata()) @@ -197,14 +196,15 @@ def test_get_services_no_metadata( "routers.user.get_management_token", return_value="mock_management_token" ) - response = client.get("/me/services", headers=auth_headers) + response = test_client.get("/me/services", headers=auth_headers) assert response.status_code == 200 assert response.json() == {"services": []} # --- Resource Endpoints (GET) --- def test_get_all_resources( - mock_auth_token, auth_headers, mock_user_data, mocker + mock_auth_token, auth_headers, mock_user_data, mocker, + test_client ): """Test getting all resources""" mocker.patch( @@ -215,7 +215,7 @@ def test_get_all_resources( "routers.user.get_management_token", return_value="mock_management_token" ) - response = client.get("/me/resources", headers=auth_headers) + response = test_client.get("/me/resources", headers=auth_headers) assert response.status_code == 200 all_resources = [ r.model_dump() @@ -226,7 +226,8 @@ def test_get_all_resources( def test_get_approved_resources( - mock_auth_token, auth_headers, mock_user_data, mocker + mock_auth_token, auth_headers, mock_user_data, mocker, + test_client ): """Test getting approved resources""" mocker.patch( @@ -237,7 +238,7 @@ def test_get_approved_resources( "routers.user.get_management_token", return_value="mock_management_token" ) - response = client.get("/me/resources/approved", headers=auth_headers) + response = test_client.get("/me/resources/approved", headers=auth_headers) assert response.status_code == 200 approved_resources = [ r.model_dump() @@ -249,7 +250,7 @@ def test_get_approved_resources( def test_get_resources_empty_metadata( - mock_auth_token, auth_headers, mocker + mock_auth_token, auth_headers, mocker, test_client ): """Test handling of empty resource metadata""" empty_user = Auth0UserFactory.build(app_metadata=AppMetadata(services=[], groups=[]), @@ -259,13 +260,14 @@ def test_get_resources_empty_metadata( "routers.user.get_management_token", return_value="mock_management_token" ) - response = client.get("/me/resources", headers=auth_headers) + response = test_client.get("/me/resources", headers=auth_headers) assert response.status_code == 200 assert response.json() == {"resources": []} def test_get_resources_no_metadata( - mock_auth_token, auth_headers, mocker + mock_auth_token, auth_headers, mocker, + test_client ): """Test handling of missing resource metadata""" no_metadata_user = Auth0UserFactory.build(app_metadata=AppMetadata()) @@ -274,7 +276,7 @@ def test_get_resources_no_metadata( "routers.user.get_management_token", return_value="mock_management_token" ) - response = client.get("/me/resources", headers=auth_headers) + response = test_client.get("/me/resources", headers=auth_headers) assert response.status_code == 200 assert response.json() == {"resources": []} @@ -282,7 +284,7 @@ def test_get_resources_no_metadata( # --- Service Request Endpoints (POST) --- def test_request_service_success( mock_auth_token, auth_headers, mock_user_data, mocker, - client_with_settings_override + test_client ): """Test successful service request""" mocker.patch("routers.user.get_user_data", return_value=mock_user_data) @@ -297,7 +299,7 @@ def test_request_service_success( "user_id": mock_auth_token.sub, } - response = client_with_settings_override.post( + response = test_client.post( "/me/request/service", json=new_service, headers=auth_headers ) assert response.status_code == 200 @@ -307,7 +309,7 @@ def test_request_service_success( def test_request_service_duplicate( mock_auth_token, auth_headers, mock_user_data, mocker, - client_with_settings_override + test_client ): """Test duplicate service request""" mocker.patch("routers.user.get_user_data", return_value=mock_user_data) @@ -321,7 +323,7 @@ def test_request_service_duplicate( "user_id": mock_auth_token.sub, } - response = client_with_settings_override.post( + response = test_client.post( "/me/request/service", json=existing_service, headers=auth_headers ) assert response.status_code == 400 @@ -332,7 +334,7 @@ def test_request_service_duplicate( def test_request_service_user_mismatch( mock_auth_token, auth_headers, mock_user_data, - client_with_settings_override + test_client ): """Test service request with mismatched user""" request_payload = { @@ -341,7 +343,7 @@ def test_request_service_user_mismatch( "user_id": "auth0|WRONG_USER", } - response = client_with_settings_override.post( + response = test_client.post( "/me/request/service", json=request_payload, headers=auth_headers ) assert response.status_code == 403 @@ -354,7 +356,7 @@ def test_request_service_user_mismatch( # --- Resource Request Endpoints (POST) --- def test_request_resource_success( mock_auth_token, auth_headers, mock_user_data, mocker, - client_with_settings_override + test_client ): """Test successful resource request""" mocker.patch("routers.user.get_user_data", return_value=mock_user_data) @@ -370,7 +372,7 @@ def test_request_resource_success( "service_id": "service1", } - response = client_with_settings_override.post( + response = test_client.post( "/me/request/service1/resource-new", json=request_payload, headers=auth_headers ) assert response.status_code == 200 @@ -379,7 +381,7 @@ def test_request_resource_success( def test_request_resource_user_mismatch( mock_auth_token, auth_headers, mock_user_data, - client_with_settings_override + test_client ): """Test resource request with mismatched user""" request_payload = { @@ -389,7 +391,7 @@ def test_request_resource_user_mismatch( "service_id": "service1", } - response = client_with_settings_override.post( + response = test_client.post( "/me/request/service1/res-invalid", json=request_payload, headers=auth_headers ) assert response.status_code == 403 @@ -401,7 +403,7 @@ def test_request_resource_user_mismatch( def test_request_resource_non_approved_service( mock_auth_token, auth_headers, mock_user_data, mocker, - client_with_settings_override + test_client ): """Test resource request for non-approved service""" mocker.patch("routers.user.get_user_data", return_value=mock_user_data) @@ -416,7 +418,7 @@ def test_request_resource_non_approved_service( "service_id": "service2", } - response = client_with_settings_override.post( + response = test_client.post( "/me/request/service2/blocked-resource", json=request_payload, headers=auth_headers, @@ -430,7 +432,7 @@ def test_request_resource_non_approved_service( def test_request_resource_duplicate( mock_auth_token, auth_headers, mock_user_data, mocker, - client_with_settings_override + test_client ): """Test duplicate resource request""" mocker.patch("routers.user.get_user_data", return_value=mock_user_data) @@ -445,7 +447,7 @@ def test_request_resource_duplicate( "service_id": "service1", } - response = client_with_settings_override.post( + response = test_client.post( "/me/request/service1/resource1", json=existing_resource, headers=auth_headers ) assert response.status_code == 400 @@ -456,7 +458,7 @@ def test_request_resource_duplicate( def test_request_resource_invalid_service( mock_auth_token, auth_headers, mock_user_data, mocker, - client_with_settings_override + test_client ): """Test resource request for non-existent service""" mocker.patch("routers.user.get_user_data", return_value=mock_user_data) @@ -471,74 +473,10 @@ def test_request_resource_invalid_service( "service_id": "non-existent-service", } - response = client_with_settings_override.post( + response = test_client.post( "/me/request/non-existent-service/resource-invalid", json=request_payload, headers=auth_headers, ) assert response.status_code == 404 assert response.json()["detail"] == "Service with ID non-existent-service not found" - -def test_is_admin_true(): - payload = AccessTokenPayload( - exp=9999999999, - iat=9999999000, - iss="https://example.com/", - sub="abc123", - aud=["client_id"], - scope="read:all", - **{ - "biocommons.org.au/roles": ["PlatformAdmin"], - "permissions": ["read"] - } - ) - user = User(access_token=payload) - assert user.is_admin() is True - -def test_is_admin_false(): - payload = AccessTokenPayload( - exp=9999999999, - iat=9999999000, - iss="https://example.com/", - sub="abc123", - aud=["client_id"], - scope="read:all", - **{ - "biocommons.org.au/roles": ["guest"], - "permissions": ["read"] - } - ) - user = User(access_token=payload) - assert user.is_admin() is False - -def test_is_admin_empty_roles(): - payload = AccessTokenPayload( - exp=9999999999, - iat=9999999000, - iss="https://example.com/", - sub="abc123", - aud=["client_id"], - scope="read:all", - **{ - "biocommons.org.au/roles": [], - "permissions": ["read"] - } - ) - user = User(access_token=payload) - assert user.is_admin() is False - -def test_is_admin_multiple_roles_with_admin(): - payload = AccessTokenPayload( - exp=9999999999, - iat=9999999000, - iss="https://example.com/", - sub="abc123", - aud=["client_id"], - scope="read:all", - **{ - "biocommons.org.au/roles": ["editor", "admin_viewer"], - "permissions": ["read"] - } - ) - user = User(access_token=payload) - assert user.is_admin() is True diff --git a/tests/test_user_schema.py b/tests/test_user_schema.py index b8d83a69..16047fa2 100644 --- a/tests/test_user_schema.py +++ b/tests/test_user_schema.py @@ -1,35 +1,26 @@ -from schemas.tokens import AccessTokenPayload from schemas.user import User +from tests.datagen import AccessTokenPayloadFactory -def test_is_admin_true(): - payload = AccessTokenPayload( - exp=9999999999, - iat=9999999000, - iss="https://example.com/", - sub="abc123", - aud=["client_id"], - scope="read:all", - **{ - "biocommons.org.au/roles": ["PlatformAdmin"], # ✅ match schema key - "permissions": ["read"] - } - ) +def test_is_admin_true(mock_settings): + payload = AccessTokenPayloadFactory.build(biocommons_roles=["Admin"]) user = User(access_token=payload) - assert user.is_admin() is True - -def test_is_admin_false(): - payload = AccessTokenPayload( - exp=9999999999, - iat=9999999000, - iss="https://example.com/", - sub="abc123", - aud=["client_id"], - scope="read:all", - **{ - "biocommons.org.au/roles": ["guest"], # ✅ match schema key - "permissions": ["read"] - } - ) + assert user.is_admin(settings=mock_settings) is True + + +def test_is_admin_false(mock_settings): + payload = AccessTokenPayloadFactory.build(biocommons_roles=["User"]) + user = User(access_token=payload) + assert user.is_admin(settings=mock_settings) is False + + +def test_is_admin_empty_roles(mock_settings): + payload = AccessTokenPayloadFactory.build(biocommons_roles=[]) + user = User(access_token=payload) + assert user.is_admin(settings=mock_settings) is False + + +def test_is_admin_multiple_roles_with_admin(mock_settings): + payload = AccessTokenPayloadFactory.build(biocommons_roles=["Admin", "Editor"]) user = User(access_token=payload) - assert user.is_admin() is False + assert user.is_admin(settings=mock_settings) is True