diff --git a/auth/validator.py b/auth/validator.py index e8a6bac6..695cef5d 100644 --- a/auth/validator.py +++ b/auth/validator.py @@ -8,7 +8,7 @@ from auth.config import Settings, get_settings from schemas.tokens import AccessTokenPayload -from schemas.user import User +from schemas.user import SessionUser oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") @@ -58,13 +58,13 @@ def get_rsa_key(token: str, settings: Settings) -> jwk.RSAKey | None: # type: i def get_current_user(token: str = Depends(oauth2_scheme), - settings: Settings = Depends(get_settings)) -> User: + settings: Settings = Depends(get_settings)) -> SessionUser: access_token = verify_jwt(token, settings=settings) - return User(access_token=access_token) + return SessionUser(access_token=access_token) -def user_is_admin(current_user: Annotated[User, Depends(get_current_user)], - settings: Annotated[Settings, Depends(get_settings)]) -> User: +def user_is_admin(current_user: Annotated[SessionUser, Depends(get_current_user)], + settings: Annotated[Settings, Depends(get_settings)]) -> SessionUser: if not current_user.is_admin(settings=settings): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, diff --git a/auth0/client.py b/auth0/client.py index 59eb5b14..fbfaeec5 100644 --- a/auth0/client.py +++ b/auth0/client.py @@ -4,7 +4,7 @@ import httpx -from auth0.schemas import Auth0UserResponse +from schemas.biocommons import BiocommonsAuth0User class Auth0Client: @@ -17,9 +17,9 @@ def __init__(self, domain: str, management_token: str): @staticmethod def _convert_users(resp: httpx.Response): """Convert a list of Auth0UserResponse objects from a response.""" - return [Auth0UserResponse(**user) for user in resp.json()] + return [BiocommonsAuth0User(**user) for user in resp.json()] - def get_users(self, page: Optional[int] = None, per_page: Optional[int] = None) -> list[Auth0UserResponse]: + def get_users(self, page: Optional[int] = None, per_page: Optional[int] = None) -> list[BiocommonsAuth0User]: params = {} if page is not None: # Convert from 1-based pagination to 0-based. @@ -31,12 +31,12 @@ def get_users(self, page: Optional[int] = None, per_page: Optional[int] = None) resp = self._client.get(url, params=params) return self._convert_users(resp) - def get_user(self, user_id: str) -> Auth0UserResponse: + def get_user(self, user_id: str) -> BiocommonsAuth0User: url = f"https://{self.domain}/api/v2/users/{user_id}" resp = self._client.get(url) - return Auth0UserResponse(**resp.json()) + return BiocommonsAuth0User(**resp.json()) - def _search_users(self, query: str, page: Optional[int] = None, per_page: Optional[int] = None) -> list[Auth0UserResponse]: + def _search_users(self, query: str, page: Optional[int] = None, per_page: Optional[int] = None) -> list[BiocommonsAuth0User]: params = {"q": query, "search_engine": "v3"} if page is not None: # Convert from 1-based pagination to 0-based. @@ -58,15 +58,15 @@ def _search_users(self, query: str, page: Optional[int] = None, per_page: Option ) return self._convert_users(resp) - def get_approved_users(self, page: Optional[int] = None, per_page: Optional[int] = None) -> list[Auth0UserResponse]: + def get_approved_users(self, page: Optional[int] = None, per_page: Optional[int] = None) -> list[BiocommonsAuth0User]: # TODO: also search for approved resources? (with OR) approved_query = 'app_metadata.services.status:"approved"' return self._search_users(approved_query, page, per_page) - def get_pending_users(self, page: Optional[int] = None, per_page: Optional[int] = None) -> list[Auth0UserResponse]: + def get_pending_users(self, page: Optional[int] = None, per_page: Optional[int] = None) -> list[BiocommonsAuth0User]: pending_query = 'app_metadata.services.status:"pending"' return self._search_users(pending_query, page, per_page) - def get_revoked_users(self, page: Optional[int] = None, per_page: Optional[int] = None) -> list[Auth0UserResponse]: + def get_revoked_users(self, page: Optional[int] = None, per_page: Optional[int] = None) -> list[BiocommonsAuth0User]: revoked_query = 'app_metadata.services.status:"revoked"' return self._search_users(revoked_query, page, per_page) diff --git a/auth0/schemas.py b/auth0/schemas.py deleted file mode 100644 index 16eda409..00000000 --- a/auth0/schemas.py +++ /dev/null @@ -1,36 +0,0 @@ -from datetime import datetime -from typing import List, Optional - -from pydantic import BaseModel, ConfigDict, EmailStr, Field - -from schemas.service import AppMetadata - - -class Auth0UserResponse(BaseModel): - """ - Response returned by Auth0's /users endpoint. - Note we have our own Auth0User model - that includes specifying the metadata fields we use. - """ - user_id: str - email: EmailStr - email_verified: bool - username: Optional[str] = None - phone_number: Optional[str] = None - phone_verified: Optional[bool] = None - created_at: datetime - updated_at: datetime - identities: List[dict] - app_metadata: AppMetadata = Field(default_factory=AppMetadata) - user_metadata: Optional[dict] = None - picture: Optional[str] = None - name: Optional[str] = None - nickname: Optional[str] = None - last_ip: Optional[str] = None - last_login: Optional[datetime] = None - logins_count: Optional[int] = None - blocked: Optional[bool] = None - given_name: Optional[str] = None - family_name: Optional[str] = None - - model_config = ConfigDict(extra="allow") diff --git a/routers/admin.py b/routers/admin.py index f411ed13..a55a8738 100644 --- a/routers/admin.py +++ b/routers/admin.py @@ -10,9 +10,9 @@ from auth.management import get_management_token from auth.validator import get_current_user, user_is_admin from auth0.client import Auth0Client -from auth0.schemas import Auth0UserResponse from routers.user import update_user_metadata -from schemas.user import User +from schemas.biocommons import BiocommonsAuth0User +from schemas.user import SessionUser logger = logging.getLogger('uvicorn.error') @@ -53,10 +53,8 @@ def get_auth0_client(settings: Settings = Depends(get_settings), return Auth0Client(settings.auth0_domain, management_token=management_token) -# TODO: May need to paginate this response to make sure we get all -# of them @router.get("/users", - response_model=list[Auth0UserResponse]) + response_model=list[BiocommonsAuth0User],) def get_users(client: Annotated[Auth0Client, Depends(get_auth0_client)], pagination: Annotated[PaginationParams, Depends(get_pagination_params)]): resp = client.get_users(page=pagination.page, per_page=pagination.per_page) @@ -86,7 +84,7 @@ def get_revoked_users(client: Annotated[Auth0Client, Depends(get_auth0_client)], @router.get("/users/{user_id}", - response_model=Auth0UserResponse) + response_model=BiocommonsAuth0User) def get_user(user_id: Annotated[str, UserIdParam], client: Annotated[Auth0Client, Depends(get_auth0_client)]): return client.get_user(user_id) @@ -96,7 +94,7 @@ def get_user(user_id: Annotated[str, UserIdParam], def approve_service(user_id: Annotated[str, UserIdParam], service_id: Annotated[str, ServiceIdParam], client: Annotated[Auth0Client, Depends(get_auth0_client)], - approving_user: Annotated[User, Depends(get_current_user)]): + approving_user: Annotated[SessionUser, Depends(get_current_user)]): user = client.get_user(user_id=user_id) # Need to fetch full user info currently to get email address, not in access token approving_user_data = client.get_user(user_id=approving_user.access_token.sub) @@ -118,7 +116,7 @@ def approve_service(user_id: Annotated[str, UserIdParam], def revoke_service(user_id: Annotated[str, UserIdParam], service_id: Annotated[str, ServiceIdParam], client: Annotated[Auth0Client, Depends(get_auth0_client)], - revoking_user: Annotated[User, Depends(get_current_user)]): + revoking_user: Annotated[SessionUser, Depends(get_current_user)]): """ Revoke a service and all associated resources for a user. """ diff --git a/routers/bpa_register.py b/routers/bpa_register.py index e1585e9b..21214411 100644 --- a/routers/bpa_register.py +++ b/routers/bpa_register.py @@ -3,25 +3,16 @@ from fastapi import APIRouter, Depends, HTTPException from httpx import AsyncClient -from pydantic import BaseModel, EmailStr from auth.config import Settings, get_settings from auth.management import get_management_token -from schemas.bpa import BPARegisterData +from schemas.biocommons import BiocommonsRegisterData +from schemas.bpa import BPARegistrationRequest from schemas.service import Resource, Service router = APIRouter(prefix="/bpa", tags=["bpa", "registration"]) -class BPARegistrationRequest(BaseModel): - username: str - fullname: str - email: EmailStr - reason: str - password: str - organizations: Dict[str, bool] - - @router.post( "/register", response_model=Dict[str, Any], @@ -64,7 +55,7 @@ async def register_bpa_user( ) # Create Auth0 user data - user_data = BPARegisterData.from_registration( + user_data = BiocommonsRegisterData.from_bpa_registration( registration=registration, bpa_service=bpa_service ) diff --git a/routers/galaxy_register.py b/routers/galaxy_register.py index 61080763..a2cbd7fd 100644 --- a/routers/galaxy_register.py +++ b/routers/galaxy_register.py @@ -8,6 +8,7 @@ from auth.config import Settings, get_settings from auth.management import get_management_token from register.tokens import create_registration_token, verify_registration_token +from schemas.biocommons import BiocommonsRegisterData from schemas.galaxy import GalaxyRegistrationData logger = logging.getLogger(__name__) @@ -38,7 +39,7 @@ def register( logger.debug("Getting management token.") management_token = get_management_token(settings=settings) headers = {"Authorization": f"Bearer {management_token}"} - user_data = registration_data.to_auth0_create_user_data() + user_data = BiocommonsRegisterData.from_galaxy_registration(registration_data) logger.debug("Registering with Auth0 management API") resp = httpx.post(url, json=user_data.model_dump(), headers=headers) if resp.status_code != 201: diff --git a/routers/user.py b/routers/user.py index 5aed3508..36dfede2 100644 --- a/routers/user.py +++ b/routers/user.py @@ -1,5 +1,5 @@ from datetime import datetime, timezone -from typing import Any, Dict, List +from typing import Annotated, Any, Dict, List from fastapi import APIRouter, Depends, HTTPException from httpx import AsyncClient @@ -7,16 +7,17 @@ from auth.config import Settings, get_settings from auth.management import get_management_token from auth.validator import get_current_user +from schemas.biocommons import BiocommonsAuth0User from schemas.requests import ResourceRequest, ServiceRequest -from schemas.service import Auth0User, Resource, Service -from schemas.user import User +from schemas.service import Resource, Service +from schemas.user import SessionUser router = APIRouter( prefix="/me", tags=["user"], responses={401: {"description": "Unauthorized"}} ) -async def get_user_data(user: User, settings: Settings) -> Auth0User: +async def get_user_data(user: SessionUser, settings: Annotated[Settings, Depends(get_settings)]) -> BiocommonsAuth0User: """Fetch and return user data from Auth0.""" url = f"https://{settings.auth0_domain}/api/v2/users/{user.access_token.sub}" token = get_management_token(settings=settings) @@ -30,7 +31,7 @@ async def get_user_data(user: User, settings: Settings) -> Auth0User: status_code=403, detail="Failed to fetch user data", ) - return Auth0User(**response.json()) + return BiocommonsAuth0User(**response.json()) except HTTPException: raise except Exception as e: @@ -69,49 +70,49 @@ async def update_user_metadata( @router.get("/services", response_model=Dict[str, List[Service]]) -async def get_services(user: User = Depends(get_current_user)): +async def get_services(user: Annotated[SessionUser, Depends(get_current_user)]): """Get all services for the authenticated user.""" user_data = await get_user_data(user) return {"services": user_data.app_metadata.services} @router.get("/services/approved", response_model=Dict[str, List[Service]]) -async def get_approved_services(user: User = Depends(get_current_user)): +async def get_approved_services(user: Annotated[SessionUser, Depends(get_current_user)]): """Get approved services for the authenticated user.""" user_data = await get_user_data(user) return {"approved_services": user_data.approved_services} @router.get("/services/pending", response_model=Dict[str, List[Service]]) -async def get_pending_services(user: User = Depends(get_current_user)): +async def get_pending_services(user: Annotated[SessionUser, Depends(get_current_user)]): """Get pending services for the authenticated user.""" user_data = await get_user_data(user) return {"pending_services": user_data.pending_services} @router.get("/resources", response_model=Dict[str, List[Resource]]) -async def get_resources(user: User = Depends(get_current_user)): +async def get_resources(user: Annotated[SessionUser, Depends(get_current_user)]): """Get all resources for the authenticated user.""" user_data = await get_user_data(user) return {"resources": user_data.app_metadata.get_all_resources()} @router.get("/resources/approved", response_model=Dict[str, List[Resource]]) -async def get_approved_resources(user: User = Depends(get_current_user)): +async def get_approved_resources(user: Annotated[SessionUser, Depends(get_current_user)]): """Get approved resources for the authenticated user.""" user_data = await get_user_data(user) return {"approved_resources": user_data.approved_resources} @router.get("/resources/pending", response_model=Dict[str, List[Resource]]) -async def get_pending_resources(user: User = Depends(get_current_user)): +async def get_pending_resources(user: Annotated[SessionUser, Depends(get_current_user)]): """Get pending resources for the authenticated user.""" user_data = await get_user_data(user) return {"pending_resources": user_data.pending_resources} @router.get("/all/pending", response_model=Dict[str, List[Any]]) -async def get_all_pending(user: User = Depends(get_current_user)): +async def get_all_pending(user: Annotated[SessionUser, Depends(get_current_user)]): """Get all pending services and resources.""" user_data = await get_user_data(user) return { @@ -130,8 +131,8 @@ async def get_all_pending(user: User = Depends(get_current_user)): }, ) async def request_service( - service_request: ServiceRequest, user: User = Depends(get_current_user), - settings: Settings = Depends(get_settings), + service_request: ServiceRequest, user: Annotated[SessionUser, Depends(get_current_user)], + settings: Annotated[Settings, Depends(get_settings)] ) -> Dict[str, Any]: """Submit a request for a service.""" if user.access_token.sub != service_request.user_id: @@ -184,7 +185,7 @@ async def request_resource( service_id: str, resource_id: str, resource_request: ResourceRequest, - user: User = Depends(get_current_user), + user: Annotated[SessionUser, Depends(get_current_user)], settings: Settings = Depends(get_settings), ) -> Dict[str, Any]: """Submit a request for a resource within a service.""" diff --git a/schemas/__init__.py b/schemas/__init__.py index 6e83334f..5a03bd77 100644 --- a/schemas/__init__.py +++ b/schemas/__init__.py @@ -1,4 +1,4 @@ -from .group import Group -from .service import Resource, Service +from schemas.group import Group +from schemas.service import Resource, Service __all__ = ["Service", "Resource", "Group"] diff --git a/schemas/biocommons.py b/schemas/biocommons.py new file mode 100644 index 00000000..080dfb44 --- /dev/null +++ b/schemas/biocommons.py @@ -0,0 +1,187 @@ +""" +Schemas for how we represent users in Auth0 for BioCommons. + +These are the core schemas we use for storing/representing users +and their metadata +""" +from datetime import datetime +from typing import List, Optional, Self + +from pydantic import BaseModel, EmailStr, Field, HttpUrl + +from schemas import Resource, Service +from schemas.bpa import BPARegistrationRequest +from schemas.galaxy import GalaxyRegistrationData +from schemas.service import Group, Identity + + +class BPAMetadata(BaseModel): + registration_reason: str + username: str + + +class BiocommonsUserMetadata(BaseModel): + """ + User metadata we use for user-changeable data + like preferred usernames + """ + bpa: Optional[BPAMetadata] = None + galaxy_username: Optional[str] = None + + +class BiocommonsAppMetadata(BaseModel): + """ + app_metadata we use to manage service/resource requests. + Note we expect all app_metadata from Auth0 to match this format + (if not empty). + """ + groups: List[Group] = Field(default_factory=list) + services: List[Service] = Field(default_factory=list) + + def get_pending_services(self) -> List[Service]: + """Get all pending services.""" + return [s for s in self.services if s.status == "pending"] + + def get_approved_services(self) -> List[Service]: + """Get all approved services.""" + return [s for s in self.services if s.status == "approved"] + + def get_all_resources(self) -> List[Resource]: + """Get all resources across services.""" + return [r for s in self.services for r in s.resources] + + def get_pending_resources(self) -> List[Resource]: + """Get all pending resources.""" + return [r for s in self.services for r in s.resources if r.status == "pending"] + + def get_approved_resources(self) -> List[Resource]: + """Get all approved resources.""" + return [r for s in self.services for r in s.resources if r.status == "approved"] + + def get_service_by_id(self, service_id: str) -> Optional[Service]: + """Get a service by its ID.""" + return next((s for s in self.services if s.id == service_id), None) + + def get_resource_by_id(self, service_id: str, resource_id: str) -> Optional[Resource]: + """Get a resource by its ID.""" + service = self.get_service_by_id(service_id) + if service: + return service.get_resource_by_id(resource_id) + else: + return None + + def approve_service(self, service_id: str, updated_by: str): + """Approve a service by its ID.""" + service = self.get_service_by_id(service_id) + if service: + service.approve(updated_by) + + def revoke_service(self, service_id: str, updated_by: str): + """Revoke a service by its ID.""" + service = self.get_service_by_id(service_id) + if service: + service.revoke(updated_by=updated_by) + + def approve_resource(self, service_id: str, resource_id: str): + """Approve a resource by its ID.""" + resource = self.get_resource_by_id(service_id=service_id, resource_id=resource_id) + if resource: + resource.approve() + return resource + else: + raise ValueError("Resource not found.") + + +class BiocommonsRegisterData(BaseModel): + """ + Data we send to the /api/v2/users endpoint to register a user + """ + email: EmailStr + email_verified: bool = False + password: str + connection: str = "Username-Password-Authentication" + username: str + name: Optional[str] = None + username: Optional[str] = None + user_metadata: BiocommonsUserMetadata + app_metadata: BiocommonsAppMetadata + + @classmethod + def from_bpa_registration( + cls, registration: BPARegistrationRequest, + bpa_service: Service) -> Self: + return cls( + email=registration.email, + password=registration.password, + username=registration.username, + name=registration.fullname, + user_metadata=BiocommonsUserMetadata( + bpa=BPAMetadata(registration_reason=registration.reason, + username=registration.username,), + ), + app_metadata=BiocommonsAppMetadata(services=[bpa_service]), + ) + + @classmethod + def from_galaxy_registration( + cls, + registration: GalaxyRegistrationData): + # Galaxy registration is approved automatically + galaxy_service = Service( + name="Galaxy Australia", + id="galaxy", + status="approved", + last_updated=datetime.now(), + updated_by="" + ) + return BiocommonsRegisterData( + email=registration.email, + user_metadata=BiocommonsUserMetadata(galaxy_username=registration.public_name), + password=registration.password, + email_verified=False, + connection="Username-Password-Authentication", + app_metadata=BiocommonsAppMetadata(services=[galaxy_service]), + ) + + +class BiocommonsAuth0User(BaseModel): + """ + Represents the user data we get back from Auth0 for Biocommons users + (with our user and app metadata, if defined). + """ + created_at: datetime + email: EmailStr + email_verified: bool + identities: List[Identity] + name: str + nickname: str + picture: HttpUrl + updated_at: datetime + user_id: str + # Auth0 will not include user/app metadata in the response when + # empty, so make it optional + user_metadata: Optional[BiocommonsUserMetadata] = None + app_metadata: Optional[BiocommonsAppMetadata] = None + last_ip: Optional[str] = None + last_login: Optional[datetime] = None + logins_count: Optional[int] = None + + @property + def pending_services(self) -> List[Service]: + """Get all services with pending status.""" + return self.app_metadata.get_pending_services() + + @property + def approved_services(self) -> List[Service]: + """Get all services with approved status.""" + return self.app_metadata.get_approved_services() + + @property + def pending_resources(self) -> List[Resource]: + """Get all resources with pending status across all services.""" + return self.app_metadata.get_pending_resources() + + @property + def approved_resources(self) -> List[Resource]: + """Get all resources with approved status across all services.""" + return self.app_metadata.get_approved_resources() diff --git a/schemas/bpa.py b/schemas/bpa.py index 8dd2f37d..5b9a5036 100644 --- a/schemas/bpa.py +++ b/schemas/bpa.py @@ -1,39 +1,12 @@ -from typing import Dict, List +from typing import Dict from pydantic import BaseModel, EmailStr -class BPAUserMetadata(BaseModel): - bpa: Dict[str, str] = {"registration_reason": ""} - - -class BPAAppMetadata(BaseModel): - groups: List[Dict] = [] - services: List[Dict] = [] - - -class BPARegisterData(BaseModel): +class BPARegistrationRequest(BaseModel): + username: str + fullname: str email: EmailStr + reason: str password: str - connection: str = "Username-Password-Authentication" - username: str - name: str - email_verified: bool = False - blocked: bool = False - verify_email: bool = True - user_metadata: BPAUserMetadata - app_metadata: BPAAppMetadata - - @classmethod - def from_registration(cls, registration, bpa_service): - """Create BPARegisterData from registration request and BPA service.""" - return cls( - email=registration.email, - password=registration.password, - username=registration.username, - name=registration.fullname, - user_metadata=BPAUserMetadata( - bpa={"registration_reason": registration.reason} - ), - app_metadata=BPAAppMetadata(services=[bpa_service.model_dump(mode="json")]), - ) + organizations: Dict[str, bool] diff --git a/schemas/galaxy.py b/schemas/galaxy.py index b286456a..b70f454d 100644 --- a/schemas/galaxy.py +++ b/schemas/galaxy.py @@ -14,29 +14,3 @@ def check_passwords_match(self) -> Self: if self.password != self.password_confirmation: raise ValueError('Passwords do not match') return self - - def to_auth0_create_user_data(self, - email_verified: bool=False, - connection: str = "Username-Password-Authentication") -> 'Auth0CreateUserData': - """ - Convert to the format expected by Auth0's create user endpoint - """ - return Auth0CreateUserData( - email=self.email, - user_metadata=Auth0UserMetadata(galaxy_username=self.public_name), - password=self.password, - email_verified=email_verified, - connection=connection, - ) - - -class Auth0UserMetadata(BaseModel): - galaxy_username: str - - -class Auth0CreateUserData(BaseModel): - email: EmailStr - user_metadata: Auth0UserMetadata - email_verified: bool = False - password: str - connection: str = 'Username-Password-Authentication' diff --git a/schemas/service.py b/schemas/service.py index 1139d8f8..398bd7f7 100644 --- a/schemas/service.py +++ b/schemas/service.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import List, Literal, Optional -from pydantic import BaseModel, Field, HttpUrl +from pydantic import BaseModel, Field class Resource(BaseModel): @@ -54,108 +54,8 @@ class Group(BaseModel): id: str -class AppMetadata(BaseModel): - """ - app_metadata we use to manage service/resource requests. - Note we expect all app_metadata from Auth0 to match this format - (if not empty). - """ - groups: List[Group] = Field(default_factory=list) - services: List[Service] = Field(default_factory=list) - - def get_pending_services(self) -> List[Service]: - """Get all pending services.""" - return [s for s in self.services if s.status == "pending"] - - def get_approved_services(self) -> List[Service]: - """Get all approved services.""" - return [s for s in self.services if s.status == "approved"] - - def get_all_resources(self) -> List[Resource]: - """Get all resources across services.""" - return [r for s in self.services for r in s.resources] - - def get_pending_resources(self) -> List[Resource]: - """Get all pending resources.""" - return [r for s in self.services for r in s.resources if r.status == "pending"] - - def get_approved_resources(self) -> List[Resource]: - """Get all approved resources.""" - return [r for s in self.services for r in s.resources if r.status == "approved"] - - def get_service_by_id(self, service_id: str) -> Optional[Service]: - """Get a service by its ID.""" - return next((s for s in self.services if s.id == service_id), None) - - def get_resource_by_id(self, service_id: str, resource_id: str) -> Optional[Resource]: - """Get a resource by its ID.""" - service = self.get_service_by_id(service_id) - if service: - return service.get_resource_by_id(resource_id) - else: - return None - - def approve_service(self, service_id: str, updated_by: str): - """Approve a service by its ID.""" - service = self.get_service_by_id(service_id) - if service: - service.approve(updated_by) - - def revoke_service(self, service_id: str, updated_by: str): - """Revoke a service by its ID.""" - service = self.get_service_by_id(service_id) - if service: - service.revoke(updated_by=updated_by) - - def approve_resource(self, service_id: str, resource_id: str): - """Approve a resource by its ID.""" - resource = self.get_resource_by_id(service_id=service_id, resource_id=resource_id) - if resource: - resource.approve() - return resource - else: - raise ValueError("Resource not found.") - - class Identity(BaseModel): connection: str provider: str user_id: str isSocial: bool - - -class Auth0User(BaseModel): - created_at: datetime - email: str - email_verified: bool - identities: List[Identity] - name: str - nickname: str - picture: HttpUrl - updated_at: datetime - user_id: str - user_metadata: dict = Field(default_factory=dict) - app_metadata: AppMetadata - last_ip: Optional[str] = None - last_login: Optional[datetime] = None - logins_count: Optional[int] = None - - @property - def pending_services(self) -> List[Service]: - """Get all services with pending status.""" - return self.app_metadata.get_pending_services() - - @property - def approved_services(self) -> List[Service]: - """Get all services with approved status.""" - return self.app_metadata.get_approved_services() - - @property - def pending_resources(self) -> List[Resource]: - """Get all resources with pending status across all services.""" - return self.app_metadata.get_pending_resources() - - @property - def approved_resources(self) -> List[Resource]: - """Get all resources with approved status across all services.""" - return self.app_metadata.get_approved_resources() diff --git a/schemas/user.py b/schemas/user.py index 68f24e3e..76d03fb6 100644 --- a/schemas/user.py +++ b/schemas/user.py @@ -5,11 +5,12 @@ from .tokens import AccessTokenPayload -class User(BaseModel): +class SessionUser(BaseModel): """ - Define our user model so we can implement any required - permissions checks here, instead of doing individual - checks in different places. + Represents the current user of the AAI Portal, and their session data (e.g. access token). + + NOTE: doesn't represent a user in the Auth0 database - see the schemas + in schemas.biocommons for that """ access_token: AccessTokenPayload diff --git a/tests/conftest.py b/tests/conftest.py index 4d9d8e96..077fe7d9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ from auth.management import get_management_token from auth.validator import get_current_user from main import app -from tests.datagen import AccessTokenPayloadFactory, UserFactory +from tests.datagen import AccessTokenPayloadFactory, SessionUserFactory @pytest.fixture(autouse=True) @@ -61,7 +61,7 @@ def as_admin_user(): """ def override_user(): token = AccessTokenPayloadFactory.build(biocommons_roles=["Admin"]) - return UserFactory.build(access_token=token) + return SessionUserFactory.build(access_token=token) app.dependency_overrides[get_current_user] = override_user app.dependency_overrides[get_management_token] = lambda: "mock_token" diff --git a/tests/datagen.py b/tests/datagen.py index 1bd01196..3aa158a2 100644 --- a/tests/datagen.py +++ b/tests/datagen.py @@ -3,30 +3,26 @@ from polyfactory.decorators import post_generated from polyfactory.factories.pydantic_factory import ModelFactory -from auth0.schemas import Auth0UserResponse -from routers.bpa_register import BPARegistrationRequest +from schemas.biocommons import BiocommonsAppMetadata, BiocommonsAuth0User +from schemas.bpa import BPARegistrationRequest from schemas.galaxy import GalaxyRegistrationData -from schemas.service import AppMetadata, Auth0User from schemas.tokens import AccessTokenPayload -from schemas.user import User +from schemas.user import SessionUser class AccessTokenPayloadFactory(ModelFactory[AccessTokenPayload]): ... -class Auth0UserResponseFactory(ModelFactory[Auth0UserResponse]): +class SessionUserFactory(ModelFactory[SessionUser]): ... + + +class BiocommonsAuth0UserFactory(ModelFactory[BiocommonsAuth0User]): @classmethod def user_id(cls) -> str: return "auth0|" + ''.join(random.choices('0123456789abcdef', k=24)) -class UserFactory(ModelFactory[User]): ... - - -class Auth0UserFactory(ModelFactory[Auth0User]): ... - - class GalaxyRegistrationDataFactory(ModelFactory[GalaxyRegistrationData]): @post_generated @@ -51,4 +47,4 @@ def get_default_organizations(cls) -> dict: } -class AppMetadataFactory(ModelFactory[AppMetadata]): ... +class AppMetadataFactory(ModelFactory[BiocommonsAppMetadata]): ... diff --git a/tests/test_admin.py b/tests/test_admin.py index e8311e07..aaf32e6d 100644 --- a/tests/test_admin.py +++ b/tests/test_admin.py @@ -12,8 +12,8 @@ from tests.datagen import ( AccessTokenPayloadFactory, AppMetadataFactory, - Auth0UserResponseFactory, - UserFactory, + BiocommonsAuth0UserFactory, + SessionUserFactory, ) FROZEN_TIME = datetime(2025, 1, 1, 12, 0, 0) @@ -46,7 +46,7 @@ def test_pagination_params_start_index(): def test_get_users_requires_admin_unauthorized(test_client, mocker): def get_nonadmin_user(): payload = AccessTokenPayloadFactory.build(biocommons_roles=["User"]) - return UserFactory.build(access_token=payload) + return SessionUserFactory.build(access_token=payload) app.dependency_overrides[get_current_user] = get_nonadmin_user mocker.patch("routers.admin.get_management_token", return_value="mock_token") @@ -58,19 +58,19 @@ def get_nonadmin_user(): def test_user_is_admin(mock_settings): payload = AccessTokenPayloadFactory.build(biocommons_roles=["Admin"]) - admin_user = UserFactory.build(access_token=payload) + admin_user = SessionUserFactory.build(access_token=payload) assert user_is_admin(current_user=admin_user, settings=mock_settings) def test_user_is_admin_nonadmin_user(mock_settings): payload = AccessTokenPayloadFactory.build(biocommons_roles=["User"]) - user = UserFactory.build(access_token=payload) + user = SessionUserFactory.build(access_token=payload) with pytest.raises(HTTPException, match="You must be an admin to access this endpoint."): user_is_admin(current_user=user, settings=mock_settings) def test_get_users(test_client, as_admin_user, mock_auth0_client): - users = Auth0UserResponseFactory.batch(3) + users = BiocommonsAuth0UserFactory.batch(3) mock_auth0_client.get_users.return_value = users resp = test_client.get("/admin/users") assert resp.status_code == 200 @@ -78,7 +78,7 @@ def test_get_users(test_client, as_admin_user, mock_auth0_client): def test_get_users_pagination_params(test_client, as_admin_user, mock_auth0_client): - users = Auth0UserResponseFactory.batch(3) + users = BiocommonsAuth0UserFactory.batch(3) mock_auth0_client.get_users.return_value = users resp = test_client.get("/admin/users?page=2&per_page=10") assert resp.status_code == 200 @@ -86,7 +86,7 @@ def test_get_users_pagination_params(test_client, as_admin_user, mock_auth0_clie def test_get_users_invalid_params(test_client, as_admin_user, mock_auth0_client): - users = Auth0UserResponseFactory.batch(3) + users = BiocommonsAuth0UserFactory.batch(3) mock_auth0_client.get_users.return_value = users resp = test_client.get("/admin/users?page=0&per_page=500") assert resp.status_code == 422 @@ -95,7 +95,7 @@ def test_get_users_invalid_params(test_client, as_admin_user, mock_auth0_client) def test_get_user(test_client, as_admin_user, mock_auth0_client): - user = Auth0UserResponseFactory.build() + user = BiocommonsAuth0UserFactory.build() mock_auth0_client.get_user.return_value = user resp = test_client.get(f"/admin/users/{user.user_id}") assert resp.status_code == 200 @@ -103,7 +103,7 @@ def test_get_user(test_client, as_admin_user, mock_auth0_client): def test_get_approved_users(test_client, as_admin_user, mock_auth0_client): - approved_users = Auth0UserResponseFactory.batch(3, app_metadata={"services": [{"name": "BPA", "status": "approved"}]}) + approved_users = BiocommonsAuth0UserFactory.batch(3, app_metadata={"services": [{"name": "BPA", "status": "approved"}]}) mock_auth0_client.get_approved_users.return_value = approved_users resp = test_client.get("/admin/users/approved") assert resp.status_code == 200 @@ -115,7 +115,7 @@ def test_get_approved_users(test_client, as_admin_user, mock_auth0_client): def test_get_pending_users(test_client, as_admin_user, mock_auth0_client): - pending_users = Auth0UserResponseFactory.batch(3, app_metadata={"services": [{"name": "BPA", "status": "pending"}]}) + pending_users = BiocommonsAuth0UserFactory.batch(3, app_metadata={"services": [{"name": "BPA", "status": "pending"}]}) mock_auth0_client.get_pending_users.return_value = pending_users resp = test_client.get("/admin/users/pending") assert resp.status_code == 200 @@ -127,7 +127,7 @@ def test_get_pending_users(test_client, as_admin_user, mock_auth0_client): def test_get_revoked(test_client, as_admin_user, mock_auth0_client): - revoked_users = Auth0UserResponseFactory.batch(3, app_metadata={"services": [{"name": "BPA", "status": "revoked"}]}) + revoked_users = BiocommonsAuth0UserFactory.batch(3, app_metadata={"services": [{"name": "BPA", "status": "revoked"}]}) mock_auth0_client.get_revoked_users.return_value = revoked_users resp = test_client.get("/admin/users/revoked") assert resp.status_code == 200 @@ -163,8 +163,8 @@ def test_approve_service(test_client, as_admin_user, mock_auth0_client, mocker): updated_by="" ) app_metadata = AppMetadataFactory.build(services=[service]) - user = Auth0UserResponseFactory.build(app_metadata=app_metadata.model_dump(mode="json")) - approving_user = Auth0UserResponseFactory.build() + user = BiocommonsAuth0UserFactory.build(app_metadata=app_metadata.model_dump(mode="json")) + approving_user = BiocommonsAuth0UserFactory.build() # Mock Auth0 client behavior mock_auth0_client.get_user.side_effect = [user, approving_user] @@ -213,8 +213,8 @@ def test_revoke_service(test_client, as_admin_user, mock_auth0_client, mocker): resources=[resource1, resource2] ) app_metadata = AppMetadataFactory.build(services=[service]) - user = Auth0UserResponseFactory.build(app_metadata=app_metadata.model_dump(mode="json")) - revoking_user = Auth0UserResponseFactory.build() + user = BiocommonsAuth0UserFactory.build(app_metadata=app_metadata.model_dump(mode="json")) + revoking_user = BiocommonsAuth0UserFactory.build() # Mock Auth0 client behavior mock_auth0_client.get_user.side_effect = [user, revoking_user] @@ -263,7 +263,7 @@ def test_approve_resource(test_client, as_admin_user, mock_auth0_client, mocker) updated_by="" ) app_metadata = AppMetadataFactory.build(services=[service]) - user = Auth0UserResponseFactory.build(app_metadata=app_metadata.model_dump(mode="json")) + user = BiocommonsAuth0UserFactory.build(app_metadata=app_metadata.model_dump(mode="json")) # Mock Auth0 client behavior mock_auth0_client.get_user.return_value = user diff --git a/tests/test_auth0_client.py b/tests/test_auth0_client.py index 1889d8c2..1a8cdb2b 100644 --- a/tests/test_auth0_client.py +++ b/tests/test_auth0_client.py @@ -3,7 +3,7 @@ from httpx import Response from auth0.client import Auth0Client -from tests.datagen import Auth0UserResponseFactory +from tests.datagen import BiocommonsAuth0UserFactory @pytest.fixture @@ -13,7 +13,7 @@ def auth0_client(): @respx.mock def test_get_users_no_pagination(auth0_client): - user = Auth0UserResponseFactory.build() + user = BiocommonsAuth0UserFactory.build() route = respx.get("https://example.auth0.com/api/v2/users").mock( return_value=Response(200, json=[user.model_dump(mode="json")]) ) @@ -26,7 +26,7 @@ def test_get_users_no_pagination(auth0_client): @respx.mock def test_get_users_with_pagination(auth0_client): - user = Auth0UserResponseFactory.build() + user = BiocommonsAuth0UserFactory.build() route = respx.get("https://example.auth0.com/api/v2/users").respond( 200, json=[user.model_dump(mode="json")] ) @@ -44,7 +44,7 @@ def test_get_users_with_pagination(auth0_client): @respx.mock def test_get_user_by_id(auth0_client): user_id = "auth0|789" - user = Auth0UserResponseFactory.build(user_id=user_id) + user = BiocommonsAuth0UserFactory.build(user_id=user_id) route = respx.get(f"https://example.auth0.com/api/v2/users/{user_id}").mock( return_value=Response(200, json=user.model_dump(mode="json")) ) @@ -65,7 +65,7 @@ def test_get_user_by_id(auth0_client): ) @respx.mock def test_search_users_methods(auth0_client, method, query): - user = Auth0UserResponseFactory.build() + user = BiocommonsAuth0UserFactory.build() route = respx.get("https://example.auth0.com/api/v2/users").respond( 200, json=[user.model_dump(mode="json")] ) diff --git a/tests/test_galaxy.py b/tests/test_galaxy.py index 11ce3898..035fa475 100644 --- a/tests/test_galaxy.py +++ b/tests/test_galaxy.py @@ -3,11 +3,13 @@ import pytest from fastapi import HTTPException +from freezegun import freeze_time from jose import jwt from pydantic import ValidationError import register from register.tokens import verify_registration_token +from schemas.biocommons import BiocommonsRegisterData from schemas.galaxy import GalaxyRegistrationData from tests.datagen import AccessTokenPayloadFactory, GalaxyRegistrationDataFactory @@ -58,7 +60,7 @@ def test_registration_token_invalid_purpose(mock_settings): verify_registration_token(token, mock_settings) -def test_to_auth0_create_user_data_valid(): +def test_to_biocommons_register_data(): """ Test we can convert GalaxyRegistrationData to the data expected by Auth0 """ @@ -69,7 +71,7 @@ def test_to_auth0_create_user_data_valid(): public_name="valid_username" ) - auth0_data = data.to_auth0_create_user_data() + auth0_data = BiocommonsRegisterData.from_galaxy_registration(data) assert auth0_data.email == "user@example.com" assert auth0_data.password == "securepassword" @@ -78,6 +80,7 @@ def test_to_auth0_create_user_data_valid(): assert auth0_data.user_metadata.galaxy_username == "valid_username" +@freeze_time("2025-01-01") def test_register(mocker, mock_auth_token, mock_settings, test_client): """ Try to test our register endpoint. Since we don't want to call @@ -101,9 +104,10 @@ def test_register(mocker, mock_auth_token, mock_settings, test_client): url = f"https://{mock_settings.auth0_domain}/api/v2/users" headers = {"Authorization": "Bearer mock_token"} + register_data = BiocommonsRegisterData.from_galaxy_registration(user_data) mock_post.assert_called_once_with( url, - json=user_data.to_auth0_create_user_data().model_dump(), + json=register_data.model_dump(), headers=headers ) diff --git a/tests/test_models.py b/tests/test_models.py index 632b2eec..80932f70 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -32,7 +32,7 @@ def test_approve_service(frozen_time): def test_approve_service_from_app_metadata(frozen_time): """ - Test we can approve a service by ID from AppMetadata. + Test we can approve a service by ID from BiocommonsAppMetadata. """ service = Service(name="Test Service", id="service1", status="pending", last_updated=FROZEN_TIME - timedelta(hours=1), updated_by="") @@ -60,7 +60,7 @@ def test_revoke_service(frozen_time): def test_revoke_service_from_app_metadata(frozen_time): """ - Test we can revoke a service by ID from AppMetadata. + Test we can revoke a service by ID from BiocommonsAppMetadata. """ service = Service(name="Test Service", id="service1", status="approved", last_updated=FROZEN_TIME - timedelta(hours=1), updated_by="") diff --git a/tests/test_user.py b/tests/test_user.py index 8cc22be6..2603146d 100644 --- a/tests/test_user.py +++ b/tests/test_user.py @@ -3,8 +3,9 @@ import pytest from fastapi import HTTPException -from schemas.service import AppMetadata, Group, Resource, Service -from tests.datagen import AccessTokenPayloadFactory, Auth0UserFactory +from schemas.biocommons import BiocommonsAppMetadata +from schemas.service import Group, Resource, Service +from tests.datagen import AccessTokenPayloadFactory, BiocommonsAuth0UserFactory # --- Test Fixtures --- @@ -29,8 +30,8 @@ def auth_headers(): @pytest.fixture def mock_user_data(): """Fixture to provide mock user data""" - return Auth0UserFactory.build( - app_metadata=AppMetadata( + return BiocommonsAuth0UserFactory.build( + app_metadata=BiocommonsAppMetadata( groups=[Group(name="Australian University", id="AU")], services=[ Service( @@ -172,8 +173,8 @@ def test_get_services_empty_metadata( mock_auth_token, auth_headers, mocker, test_client ): """Test handling of empty metadata""" - empty_user = Auth0UserFactory.build( - app_metadata=AppMetadata(services=[], groups=[]), + empty_user = BiocommonsAuth0UserFactory.build( + app_metadata=BiocommonsAppMetadata(services=[], groups=[]), ) mocker.patch("routers.user.get_user_data", return_value=empty_user) mocker.patch( @@ -190,7 +191,7 @@ def test_get_services_no_metadata( test_client ): """Test handling of missing metadata""" - no_metadata_user = Auth0UserFactory.build(app_metadata=AppMetadata()) + no_metadata_user = BiocommonsAuth0UserFactory.build(app_metadata=BiocommonsAppMetadata()) mocker.patch("routers.user.get_user_data", return_value=no_metadata_user) mocker.patch( "routers.user.get_management_token", return_value="mock_management_token" @@ -253,8 +254,8 @@ def test_get_resources_empty_metadata( mock_auth_token, auth_headers, mocker, test_client ): """Test handling of empty resource metadata""" - empty_user = Auth0UserFactory.build(app_metadata=AppMetadata(services=[], groups=[]), - ) + empty_user = BiocommonsAuth0UserFactory.build(app_metadata=BiocommonsAppMetadata(services=[], groups=[]), + ) mocker.patch("routers.user.get_user_data", return_value=empty_user) mocker.patch( "routers.user.get_management_token", return_value="mock_management_token" @@ -270,7 +271,7 @@ def test_get_resources_no_metadata( test_client ): """Test handling of missing resource metadata""" - no_metadata_user = Auth0UserFactory.build(app_metadata=AppMetadata()) + no_metadata_user = BiocommonsAuth0UserFactory.build(app_metadata=BiocommonsAppMetadata()) mocker.patch("routers.user.get_user_data", return_value=no_metadata_user) mocker.patch( "routers.user.get_management_token", return_value="mock_management_token" diff --git a/tests/test_user_schema.py b/tests/test_user_schema.py index 16047fa2..e818188f 100644 --- a/tests/test_user_schema.py +++ b/tests/test_user_schema.py @@ -1,26 +1,26 @@ -from schemas.user import User +from schemas.user import SessionUser from tests.datagen import AccessTokenPayloadFactory def test_is_admin_true(mock_settings): payload = AccessTokenPayloadFactory.build(biocommons_roles=["Admin"]) - user = User(access_token=payload) + user = SessionUser(access_token=payload) assert user.is_admin(settings=mock_settings) is True def test_is_admin_false(mock_settings): payload = AccessTokenPayloadFactory.build(biocommons_roles=["User"]) - user = User(access_token=payload) + user = SessionUser(access_token=payload) assert user.is_admin(settings=mock_settings) is False def test_is_admin_empty_roles(mock_settings): payload = AccessTokenPayloadFactory.build(biocommons_roles=[]) - user = User(access_token=payload) + user = SessionUser(access_token=payload) assert user.is_admin(settings=mock_settings) is False def test_is_admin_multiple_roles_with_admin(mock_settings): payload = AccessTokenPayloadFactory.build(biocommons_roles=["Admin", "Editor"]) - user = User(access_token=payload) + user = SessionUser(access_token=payload) assert user.is_admin(settings=mock_settings) is True