Skip to content
Merged
51 changes: 37 additions & 14 deletions auth0/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
__all__ = ["Auth0Client"]

from typing import Optional

import httpx

from auth0.schemas import Auth0UserResponse
Expand All @@ -17,33 +19,54 @@ def _convert_users(resp: httpx.Response):
"""Convert a list of Auth0UserResponse objects from a response."""
return [Auth0UserResponse(**user) for user in resp.json()]

def get_users(self) -> list[Auth0UserResponse]:
def get_users(self, page: Optional[int] = None, per_page: Optional[int] = None) -> list[Auth0UserResponse]:
params = {}
if page is not None:
# Convert from 1-based pagination to 0-based.
page = page - 1
params["page"] = page
if per_page is not None:
params["per_page"] = per_page
url = f"https://{self.domain}/api/v2/users"
resp = self._client.get(url)
resp = self._client.get(url, params=params)
return self._convert_users(resp)

def get_user(self, user_id: str) -> Auth0UserResponse:
url = f"https://{self.domain}/api/v2/users/{user_id}"
resp = self._client.get(url)
return Auth0UserResponse(**resp.json())

def get_approved_users(self) -> list[Auth0UserResponse]:
# TODO: also search for approved resources? (with OR)
approved_query = 'app_metadata.services.status:"approved"'
def _search_users(self, query: str, page: Optional[int] = None, per_page: Optional[int] = None) -> list[Auth0UserResponse]:
params = {"q": query, "search_engine": "v3"}
if page is not None:
# Convert from 1-based pagination to 0-based.
page = page - 1
params["page"] = page
if per_page is not None:
params["per_page"] = per_page
url = f"https://{self.domain}/api/v2/users"
# TODO: set primary_order=false for faster search?
# https://auth0.com/docs/manage-users/user-search/user-search-best-practices
resp = self._client.get(url, params={"q": approved_query, "search_engine": "v3"})
resp = self._client.get(
url,
params={
"q": query,
"search_engine": "v3",
"page": page,
"per_page": per_page
}
)
return self._convert_users(resp)

def get_pending_users(self) -> list[Auth0UserResponse]:
def get_approved_users(self, page: Optional[int] = None, per_page: Optional[int] = None) -> list[Auth0UserResponse]:
# TODO: also search for approved resources? (with OR)
approved_query = 'app_metadata.services.status:"approved"'
return self._search_users(approved_query, page, per_page)

def get_pending_users(self, page: Optional[int] = None, per_page: Optional[int] = None) -> list[Auth0UserResponse]:
pending_query = 'app_metadata.services.status:"pending"'
url = f"https://{self.domain}/api/v2/users"
resp = self._client.get(url, params={"q": pending_query, "search_engine": "v3"})
return [Auth0UserResponse(**user) for user in resp.json()]
return self._search_users(pending_query, page, per_page)

def get_revoked_users(self) -> list[Auth0UserResponse]:
def get_revoked_users(self, page: Optional[int] = None, per_page: Optional[int] = None) -> list[Auth0UserResponse]:
revoked_query = 'app_metadata.services.status:"revoked"'
url = f"https://{self.domain}/api/v2/users"
resp = self._client.get(url, params={"q": revoked_query, "search_engine": "v3"})
return [Auth0UserResponse(**user) for user in resp.json()]
return self._search_users(revoked_query, page, per_page)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ dev = [
"polyfactory>=2.21.0",
"pre-commit>=3.7.0",
"freezegun>=1.5.2",
"respx>=0.22.0",
]

[tool.pytest.ini_options]
Expand Down
47 changes: 38 additions & 9 deletions routers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import logging
from typing import Annotated

from fastapi import APIRouter, Depends, Path
from fastapi import APIRouter, Depends, HTTPException, Path
from fastapi.params import Query
from pydantic import BaseModel, ValidationError

from auth.config import Settings, get_settings
from auth.management import get_management_token
Expand All @@ -19,6 +21,29 @@
ServiceIdParam = Path(..., pattern=r"^[-a-zA-Z0-9_]+$")
ResourceIdParam = Path(..., pattern=r"^[-a-zA-Z0-9_]+$")


class PaginationParams(BaseModel):
"""
Query parameters for paginated endpoints. Page starts at 1.
"""
page: int = Query(1, ge=1)
per_page: int = Query(100, ge=1, le=100)

@property
def start_index(self):
return (self.page - 1) * self.per_page


def get_pagination_params(page: int = 1, per_page: int = 100):
try:
return PaginationParams(page=page, per_page=per_page)
except ValidationError:
raise HTTPException(
status_code=422,
detail="Invalid page params: page should be >= 1, per_page should be >= 1 and <= 100"
)


router = APIRouter(prefix="/admin", tags=["admin"],
dependencies=[Depends(user_is_admin)])

Expand All @@ -32,27 +57,31 @@ def get_auth0_client(settings: Settings = Depends(get_settings),
# of them
@router.get("/users",
response_model=list[Auth0UserResponse])
def get_users(client: Auth0Client = Depends(get_auth0_client)):
resp = client.get_users()
def get_users(client: Annotated[Auth0Client, Depends(get_auth0_client)],
pagination: Annotated[PaginationParams, Depends(get_pagination_params)]):
resp = client.get_users(page=pagination.page, per_page=pagination.per_page)
return resp


# NOTE: This must appear before /users/{user_id} so it takes precedence
@router.get("/users/approved")
def get_approved_users(client: Annotated[Auth0Client, Depends(get_auth0_client)]):
resp = client.get_approved_users()
def get_approved_users(client: Annotated[Auth0Client, Depends(get_auth0_client)],
pagination: Annotated[PaginationParams, Depends(get_pagination_params)]):
resp = client.get_approved_users(page=pagination.page, per_page=pagination.per_page)
return resp


@router.get("/users/pending")
def get_pending_users(client: Annotated[Auth0Client, Depends(get_auth0_client)]):
resp = client.get_pending_users()
def get_pending_users(client: Annotated[Auth0Client, Depends(get_auth0_client)],
pagination: Annotated[PaginationParams, Depends(get_pagination_params)]):
resp = client.get_pending_users(page=pagination.page, per_page=pagination.per_page)
return resp


@router.get("/users/revoked")
def get_revoked_users(client: Annotated[Auth0Client, Depends(get_auth0_client)]):
resp = client.get_revoked_users()
def get_revoked_users(client: Annotated[Auth0Client, Depends(get_auth0_client)],
pagination: Annotated[PaginationParams, Depends(get_pagination_params)]):
resp = client.get_revoked_users(page=pagination.page, per_page=pagination.per_page)
return resp


Expand Down
35 changes: 33 additions & 2 deletions tests/test_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from auth.validator import get_current_user, user_is_admin
from main import app
from routers.admin import PaginationParams
from schemas import Resource, Service
from tests.datagen import (
AccessTokenPayloadFactory,
Expand All @@ -33,6 +34,15 @@ def mock_auth0_client(mocker):
return mock_client()


def test_pagination_params_start_index():
"""
Test we can get the current start index given the page number and per_page.
"""
params = PaginationParams(page=2, per_page=10)
# start index for page 1 is 0, for page 2 is 0 + per_page = 10
assert params.start_index == 10


def test_get_users_requires_admin_unauthorized(test_client, mocker):
def get_nonadmin_user():
payload = AccessTokenPayloadFactory.build(biocommons_roles=["User"])
Expand Down Expand Up @@ -67,6 +77,23 @@ def test_get_users(test_client, as_admin_user, mock_auth0_client):
assert len(resp.json()) == 3


def test_get_users_pagination_params(test_client, as_admin_user, mock_auth0_client):
users = Auth0UserResponseFactory.batch(3)
mock_auth0_client.get_users.return_value = users
resp = test_client.get("/admin/users?page=2&per_page=10")
assert resp.status_code == 200
assert len(resp.json()) == 3


def test_get_users_invalid_params(test_client, as_admin_user, mock_auth0_client):
users = Auth0UserResponseFactory.batch(3)
mock_auth0_client.get_users.return_value = users
resp = test_client.get("/admin/users?page=0&per_page=500")
assert resp.status_code == 422
error_msg = resp.json()["detail"]
assert "Invalid page params" in error_msg


def test_get_user(test_client, as_admin_user, mock_auth0_client):
user = Auth0UserResponseFactory.build()
mock_auth0_client.get_user.return_value = user
Expand Down Expand Up @@ -175,13 +202,15 @@ def test_revoke_service(test_client, as_admin_user, mock_auth0_client, mocker):

Note this is currently pretty clunky due to the need to mock out asyncio.run.
"""
# Build test user and metadata
resource1 = Resource(name="Test Resource", id="resource1", status="approved")
resource2 = Resource(name="Test Resource", id="resource2", status="approved")
service = Service(
name="Test Service",
id="service1",
status="approved",
last_updated=FROZEN_TIME - timedelta(hours=1),
updated_by=""
updated_by="",
resources=[resource1, resource2]
)
app_metadata = AppMetadataFactory.build(services=[service])
user = Auth0UserResponseFactory.build(app_metadata=app_metadata.model_dump(mode="json"))
Expand Down Expand Up @@ -215,6 +244,8 @@ def test_revoke_service(test_client, as_admin_user, mock_auth0_client, mocker):
assert service_data["status"] == "revoked"
assert service_data["id"] == service.id
assert service_data["updated_by"] == revoking_user.email
for resource in service_data["resources"]:
assert resource["status"] == "revoked"


def test_approve_resource(test_client, as_admin_user, mock_auth0_client, mocker):
Expand Down
81 changes: 81 additions & 0 deletions tests/test_auth0_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import pytest
import respx
from httpx import Response

from auth0.client import Auth0Client
from tests.datagen import Auth0UserResponseFactory


@pytest.fixture
def auth0_client():
return Auth0Client(domain="example.auth0.com", management_token="dummy-token")


@respx.mock
def test_get_users_no_pagination(auth0_client):
user = Auth0UserResponseFactory.build()
route = respx.get("https://example.auth0.com/api/v2/users").mock(
return_value=Response(200, json=[user.model_dump(mode="json")])
)

result = auth0_client.get_users()

assert route.called
assert result[0].model_dump(mode="json") == user.model_dump(mode="json")


@respx.mock
def test_get_users_with_pagination(auth0_client):
user = Auth0UserResponseFactory.build()
route = respx.get("https://example.auth0.com/api/v2/users").respond(
200, json=[user.model_dump(mode="json")]
)

result = auth0_client.get_users(page=2, per_page=25)

# Validate the actual request
request = route.calls[0].request
assert route.called
assert request.url.params["page"] == "1"
assert request.url.params["per_page"] == "25"
assert result[0].model_dump(mode="json") == user.model_dump(mode="json")


@respx.mock
def test_get_user_by_id(auth0_client):
user_id = "auth0|789"
user = Auth0UserResponseFactory.build(user_id=user_id)
route = respx.get(f"https://example.auth0.com/api/v2/users/{user_id}").mock(
return_value=Response(200, json=user.model_dump(mode="json"))
)

result = auth0_client.get_user(user_id)

assert route.called
assert result.model_dump(mode="json") == user.model_dump(mode="json")


@pytest.mark.parametrize(
"method,query",
[
("get_approved_users", 'app_metadata.services.status:"approved"'),
("get_pending_users", 'app_metadata.services.status:"pending"'),
("get_revoked_users", 'app_metadata.services.status:"revoked"'),
]
)
@respx.mock
def test_search_users_methods(auth0_client, method, query):
user = Auth0UserResponseFactory.build()
route = respx.get("https://example.auth0.com/api/v2/users").respond(
200, json=[user.model_dump(mode="json")]
)

result = getattr(auth0_client, method)(page=3, per_page=50)

assert route.called
request = route.calls[0].request
assert request.url.params["q"] == query
assert request.url.params["search_engine"] == "v3"
assert request.url.params["page"] == "2"
assert request.url.params["per_page"] == "50"
assert result[0].model_dump(mode="json") == user.model_dump(mode="json")
14 changes: 14 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.