diff --git a/auth/validator.py b/auth/validator.py index 20c6e130..d3cac112 100644 --- a/auth/validator.py +++ b/auth/validator.py @@ -1,3 +1,4 @@ +from functools import lru_cache from typing import Annotated import httpx @@ -49,10 +50,15 @@ def verify_jwt(token: str, settings: Settings) -> AccessTokenPayload: return AccessTokenPayload(**payload) -def get_rsa_key(token: str, settings: Settings) -> jwk.RSAKey | None: # type: ignore - jwks_url = f"https://{settings.auth0_domain}/.well-known/jwks.json" +@lru_cache(maxsize=100) +def _fetch_rsa_keys(auth0_domain: str) -> dict: + jwks_url = f"https://{auth0_domain}/.well-known/jwks.json" response = httpx.get(jwks_url) - jwks = response.json() + return response.json() + + +def get_rsa_key(token: str, settings: Settings) -> jwk.RSAKey | None: # type: ignore + jwks = _fetch_rsa_keys(settings.auth0_domain) unverified_header = jwt.get_unverified_header(token) for key in jwks["keys"]: diff --git a/db/setup.py b/db/setup.py index 72e1923a..bd55539d 100644 --- a/db/setup.py +++ b/db/setup.py @@ -70,4 +70,7 @@ def create_db_and_tables(): def get_db_session(): engine = get_engine() with Session(engine) as session: - yield session + try: + yield session + finally: + session.close() diff --git a/routers/admin.py b/routers/admin.py index 67a21477..75de7d94 100644 --- a/routers/admin.py +++ b/routers/admin.py @@ -1,4 +1,3 @@ -import asyncio import logging from datetime import datetime from typing import Annotated @@ -6,10 +5,10 @@ from fastapi import APIRouter, Depends, HTTPException, Path from fastapi.params import Query from pydantic import BaseModel, Field, ValidationError -from sqlalchemy import func, or_ +from sqlalchemy import false, func, or_ from sqlmodel import Session, select -from auth.validator import get_current_user, user_is_admin +from auth.validator import user_is_admin from auth0.client import Auth0Client, get_auth0_client from db.models import ( BiocommonsGroup, @@ -19,10 +18,8 @@ PlatformMembership, ) from db.setup import get_db_session -from db.types import GroupEnum -from routers.user import update_user_metadata -from schemas.biocommons import Auth0UserData, Auth0UserDataWithMemberships -from schemas.user import SessionUser +from db.types import ApprovalStatusEnum, GroupEnum +from schemas.biocommons import Auth0UserDataWithMemberships logger = logging.getLogger('uvicorn.error') @@ -50,6 +47,7 @@ class BiocommonsUserResponse(BaseModel): id: str = Field(description="Auth0 user ID") email: str = Field(description="User email address") username: str = Field(description="User username") + email_verified: bool = Field(description="User email verification status") created_at: datetime = Field(description="User creation timestamp") @@ -166,47 +164,75 @@ def get_users(db_session: Annotated[Session, Depends(get_db_session)], # NOTE: This must appear before /users/{user_id} so it takes precedence -@router.get("/users/approved") -def get_approved_users(client: Annotated[Auth0Client, Depends(get_auth0_client)], +@router.get( + "/users/approved", + response_model=list[BiocommonsUserResponse]) +def get_approved_users(db_session: Annotated[Session, Depends(get_db_session)], pagination: Annotated[PaginationParams, Depends(get_pagination_params)]): - resp = client.get_approved_users(page=pagination.page, per_page=pagination.per_page) - return resp + platform_approved_query = ( + select(BiocommonsUser) + .join(PlatformMembership, BiocommonsUser.id == PlatformMembership.user_id) + .where(PlatformMembership.approval_status == ApprovalStatusEnum.APPROVED) + .distinct() + ) + user_query = platform_approved_query.offset(pagination.start_index).limit(pagination.per_page) + users = db_session.exec(user_query).all() + return users -@router.get("/users/pending") -def get_pending_users(client: Annotated[Auth0Client, Depends(get_auth0_client)], +@router.get("/users/pending", + response_model=list[BiocommonsUserResponse]) +def get_pending_users(db_session: Annotated[Session, Depends(get_db_session)], pagination: Annotated[PaginationParams, Depends(get_pagination_params)]): - resp = client.get_pending_users(page=pagination.page, per_page=pagination.per_page) - return resp + platform_pending_query = ( + select(BiocommonsUser) + .join(PlatformMembership, BiocommonsUser.id == PlatformMembership.user_id) + .where(PlatformMembership.approval_status == ApprovalStatusEnum.PENDING) + .distinct() + ) + user_query = platform_pending_query.offset(pagination.start_index).limit(pagination.per_page) + users = db_session.exec(user_query).all() + return users @router.get("/users/revoked") -def get_revoked_users(client: Annotated[Auth0Client, Depends(get_auth0_client)], +def get_revoked_users(db_session: Annotated[Session, Depends(get_db_session)], pagination: Annotated[PaginationParams, Depends(get_pagination_params)]): - resp = client.get_revoked_users(page=pagination.page, per_page=pagination.per_page) - return resp + platform_revoked_query = ( + select(BiocommonsUser) + .join(PlatformMembership, BiocommonsUser.id == PlatformMembership.user_id) + .where(PlatformMembership.approval_status == ApprovalStatusEnum.REVOKED) + .distinct() + ) + user_query = platform_revoked_query.offset(pagination.start_index).limit(pagination.per_page) + users = db_session.exec(user_query).all() + return users -@router.get("/users/unverified", response_model=list[Auth0UserData]) +@router.get("/users/unverified", response_model=list[BiocommonsUserResponse]) def get_unverified_users( - client: Annotated[Auth0Client, Depends(get_auth0_client)], + db_session: Annotated[Session, Depends(get_db_session)], pagination: Annotated[PaginationParams, Depends(get_pagination_params)], ): """ Return users whose email is not verified, using Auth0 search for efficiency. """ - return client.get_users( - page=pagination.page, - per_page=pagination.per_page, - q="email_verified:false", + query = ( + select(BiocommonsUser) + .where(BiocommonsUser.email_verified == false()) + .offset(pagination.start_index) + .limit(pagination.per_page) ) + users = db_session.exec(query).all() + return users @router.get("/users/{user_id}", - response_model=Auth0UserData) + response_model=BiocommonsUserResponse) def get_user(user_id: Annotated[str, UserIdParam], - client: Annotated[Auth0Client, Depends(get_auth0_client)]): - return client.get_user(user_id) + db_session: Annotated[Session, Depends(get_db_session)]): + user = db_session.get_one(BiocommonsUser, user_id) + return user @router.get("/users/{user_id}/details", @@ -233,72 +259,3 @@ def resend_verification_email(user_id: Annotated[str, UserIdParam], client: Annotated[Auth0Client, Depends(get_auth0_client)]): client.resend_verification_email(user_id) return {"message": "Verification email resent."} - - -@router.post("/users/{user_id}/services/{service_id}/approve") -def approve_service(user_id: Annotated[str, UserIdParam], - service_id: Annotated[str, ServiceIdParam], - client: Annotated[Auth0Client, Depends(get_auth0_client)], - approving_user: Annotated[SessionUser, Depends(get_current_user)]): - user = client.get_user(user_id=user_id) - # Need to fetch full user info currently to get email address, not in access token - approving_user_data = client.get_user(user_id=approving_user.access_token.sub) - logger.debug(f"Approving service {service_id} for user {user_id} by {approving_user_data.email}") - user.app_metadata.approve_service(service_id, updated_by=str(approving_user_data.email)) - logger.info("Sending updated metadata to Auth0 API") - # update_user_metadata is async, so run via asyncio - update = update_user_metadata( - user_id=user_id, - token=client.management_token, - metadata=user.app_metadata.model_dump(mode="json") - ) - resp = asyncio.run(update) - logger.info("Metadata updated successfully") - return resp - - -@router.post("/users/{user_id}/services/{service_id}/revoke") -def revoke_service(user_id: Annotated[str, UserIdParam], - service_id: Annotated[str, ServiceIdParam], - client: Annotated[Auth0Client, Depends(get_auth0_client)], - revoking_user: Annotated[SessionUser, Depends(get_current_user)]): - """ - Revoke a service and all associated resources for a user. - """ - user = client.get_user(user_id=user_id) - revoking_user_data = client.get_user(user_id=revoking_user.access_token.sub) - user.app_metadata.revoke_service(service_id=service_id, updated_by=str(revoking_user_data.email)) - service = user.app_metadata.get_service_by_id(service_id) - for resource in service.resources: - resource.revoke() - update = update_user_metadata( - user_id=user_id, - token=client.management_token, - metadata=user.app_metadata.model_dump(mode="json") - ) - resp = asyncio.run(update) - return resp - - -@router.post("/users/{user_id}/services/{service_id}/resources/{resource_id}/approve") -def approve_resource(user_id: Annotated[str, UserIdParam], - service_id: Annotated[str, ServiceIdParam], - resource_id: Annotated[str, ResourceIdParam], - client: Annotated[Auth0Client, Depends(get_auth0_client)], - approving_user: Annotated[SessionUser, Depends(get_current_user)]): - user = client.get_user(user_id=user_id) - approving_user_data = client.get_user(user_id=approving_user.access_token.sub) - - user.app_metadata.approve_resource( - service_id=service_id, - resource_id=resource_id, - updated_by=approving_user_data.email - ) - - update = update_user_metadata( - user_id=user_id, - token=client.management_token, - metadata=user.app_metadata.model_dump(mode="json") - ) - resp = asyncio.run(update) - return resp diff --git a/routers/bpa_register.py b/routers/bpa_register.py index 440fa4ba..f19c251d 100644 --- a/routers/bpa_register.py +++ b/routers/bpa_register.py @@ -1,5 +1,4 @@ import logging -from datetime import datetime, timezone from fastapi import APIRouter, Depends, HTTPException from httpx import HTTPStatusError @@ -7,14 +6,12 @@ from starlette.responses import JSONResponse from auth0.client import Auth0Client, get_auth0_client -from config import Settings, get_settings from db.models import BiocommonsUser, PlatformEnum from db.setup import get_db_session from routers.errors import RegistrationRoute from schemas.biocommons import Auth0UserData, BiocommonsRegisterData from schemas.bpa import BPARegistrationRequest from schemas.responses import RegistrationErrorResponse, RegistrationResponse -from schemas.service import Service logger = logging.getLogger(__name__) @@ -26,17 +23,6 @@ ) -def _get_bpa_service_request(registration: BPARegistrationRequest, settings: Settings, update_time: datetime) -> Service: - return Service( - name="Bioplatforms Australia Data Portal", - id="bpa", - initial_request_time=update_time, - status="pending", - last_updated=update_time, - updated_by="system", - ) - - @router.post( "/register", responses={ @@ -46,17 +32,13 @@ def _get_bpa_service_request(registration: BPARegistrationRequest, settings: Set ) async def register_bpa_user( registration: BPARegistrationRequest, - settings: Settings = Depends(get_settings), db_session: Session = Depends(get_db_session), auth0_client: Auth0Client = Depends(get_auth0_client) ): """Register a new BPA user.""" - now = datetime.now(timezone.utc) - bpa_service = _get_bpa_service_request(registration=registration, settings=settings, update_time=now) - # Create Auth0 user data user_data = BiocommonsRegisterData.from_bpa_registration( - registration=registration, bpa_service=bpa_service + registration=registration ) try: diff --git a/routers/user.py b/routers/user.py index 3a41e041..e4bad50e 100644 --- a/routers/user.py +++ b/routers/user.py @@ -1,15 +1,19 @@ -from datetime import datetime, timezone -from typing import Annotated, Any, Dict, List +from typing import Annotated, Any, Dict from fastapi import APIRouter, Depends, HTTPException from httpx import AsyncClient +from pydantic import AliasPath, Field +from pydantic import BaseModel as PydanticBaseModel +from sqlmodel import Session, select +from sqlmodel.sql._expression_select_cls import SelectOfScalar from auth.management import get_management_token from auth.validator import get_current_user from config import Settings, get_settings +from db.models import GroupMembership, PlatformMembership +from db.setup import get_db_session +from db.types import ApprovalStatusEnum from schemas.biocommons import Auth0UserData -from schemas.requests import ResourceRequest, ServiceRequest -from schemas.service import Resource, Service from schemas.user import SessionUser router = APIRouter( @@ -17,6 +21,28 @@ ) +class PlatformMembershipData(PydanticBaseModel): + platform_id: str + approval_status: str + + +class GroupMembershipData(PydanticBaseModel): + """ + Data model for group membership, when returned from the API. + Should be created automatically from GroupMembership when + setting a response_model on a route. + """ + group_id: str + approval_status: str + # Get group_name from the nested group object + group_name: str = Field(validation_alias=AliasPath("group", "name")) + + +class CombinedMembershipData(PydanticBaseModel): + platforms: list[PlatformMembershipData] + groups: list[GroupMembershipData] + + async def get_user_data( user: SessionUser, settings: Annotated[Settings, Depends(get_settings)] ) -> Auth0UserData: @@ -71,205 +97,116 @@ async def update_user_metadata( ) -@router.get("/services", response_model=Dict[str, List[Service]]) -async def get_services( - user: Annotated[SessionUser, Depends(get_current_user)], - settings: Annotated[Settings, Depends(get_settings)], -): - """Get all services for the authenticated user.""" - user_data = await get_user_data(user, settings) - return {"services": user_data.app_metadata.services} +def _get_user_platforms(user_id: str, + approval_status: ApprovalStatusEnum | None = None) -> SelectOfScalar[PlatformMembership]: + """Utility function to get platforms for a user.""" + query = (select(PlatformMembership) + .where(PlatformMembership.user_id == user_id)) + if approval_status is not None: + query = query.where(PlatformMembership.approval_status == approval_status) + return query -@router.get("/is-admin") -async def check_is_admin( - user: Annotated[SessionUser, Depends(get_current_user)], - settings: Annotated[Settings, Depends(get_settings)], -): - """Check if the current user has admin privileges.""" - return {"is_admin": user.is_admin(settings)} +def _get_user_groups(user_id: str, + approval_status: ApprovalStatusEnum | None = None) -> SelectOfScalar[GroupMembership]: + """Utility function to get groups for a user.""" + query = (select(GroupMembership) + .where(GroupMembership.user_id == user_id)) + if approval_status is not None: + query = query.where(GroupMembership.approval_status == approval_status) + return query -@router.get("/services/approved", response_model=Dict[str, List[Service]]) -async def get_approved_services( - user: Annotated[SessionUser, Depends(get_current_user)], - settings: Annotated[Settings, Depends(get_settings)], +@router.get("/platforms", + response_model=list[PlatformMembershipData],) +async def get_platforms( + user: Annotated[SessionUser, Depends(get_current_user)], + db_session: Annotated[Session, Depends(get_db_session)], ): - """Get approved services for the authenticated user.""" - user_data = await get_user_data(user, settings) - return {"approved_services": user_data.approved_services} + query = _get_user_platforms(user_id=user.access_token.sub) + return db_session.exec(query).all() -@router.get("/services/pending", response_model=Dict[str, List[Service]]) -async def get_pending_services( - user: Annotated[SessionUser, Depends(get_current_user)], - settings: Annotated[Settings, Depends(get_settings)], +@router.get( + "/platforms/approved", + response_model=list[PlatformMembershipData], +) +async def get_approved_platforms( + user: Annotated[SessionUser, Depends(get_current_user)], + db_session: Annotated[Session, Depends(get_db_session)], ): - """Get pending services for the authenticated user.""" - user_data = await get_user_data(user, settings) - return {"pending_services": user_data.pending_services} + """Get approved platforms for the current user.""" + query = _get_user_platforms(user_id=user.access_token.sub, + approval_status=ApprovalStatusEnum.APPROVED) + return db_session.exec(query).all() -@router.get("/resources", response_model=Dict[str, List[Resource]]) -async def get_resources( - user: Annotated[SessionUser, Depends(get_current_user)], - settings: Annotated[Settings, Depends(get_settings)], +@router.get( + "/platforms/pending", + response_model=list[PlatformMembershipData], +) +async def get_pending_platforms( + user: Annotated[SessionUser, Depends(get_current_user)], + db_session: Annotated[Session, Depends(get_db_session)], ): - """Get all resources for the authenticated user.""" - user_data = await get_user_data(user, settings) - return {"resources": user_data.app_metadata.get_all_resources()} + """Get pending platforms for the current user.""" + query = _get_user_platforms(user_id=user.access_token.sub, + approval_status=ApprovalStatusEnum.PENDING) + return db_session.exec(query).all() -@router.get("/resources/approved", response_model=Dict[str, List[Resource]]) -async def get_approved_resources( - user: Annotated[SessionUser, Depends(get_current_user)], - settings: Annotated[Settings, Depends(get_settings)], +@router.get("/groups", + response_model=list[GroupMembershipData],) +async def get_groups( + user: Annotated[SessionUser, Depends(get_current_user)], + db_session: Annotated[Session, Depends(get_db_session)], ): - """Get approved resources for the authenticated user.""" - user_data = await get_user_data(user, settings) - return {"approved_resources": user_data.approved_resources} + query = _get_user_groups(user_id=user.access_token.sub) + return db_session.exec(query).all() -@router.get("/resources/pending", response_model=Dict[str, List[Resource]]) -async def get_pending_resources( - user: Annotated[SessionUser, Depends(get_current_user)], - settings: Annotated[Settings, Depends(get_settings)], +@router.get("/groups/approved", + response_model=list[GroupMembershipData],) +async def get_approved_groups( + user: Annotated[SessionUser, Depends(get_current_user)], + db_session: Annotated[Session, Depends(get_db_session)], ): - """Get pending resources for the authenticated user.""" - user_data = await get_user_data(user, settings) - return {"pending_resources": user_data.pending_resources} + query = _get_user_groups(user_id=user.access_token.sub, + approval_status=ApprovalStatusEnum.APPROVED) + return db_session.exec(query).all() -@router.get("/all/pending", response_model=Dict[str, List[Any]]) -async def get_all_pending( - user: Annotated[SessionUser, Depends(get_current_user)], - settings: Annotated[Settings, Depends(get_settings)], +@router.get("/groups/pending", + response_model=list[GroupMembershipData],) +async def get_pending_groups( + user: Annotated[SessionUser, Depends(get_current_user)], + db_session: Annotated[Session, Depends(get_db_session)], ): - """Get all pending services and resources.""" - user_data = await get_user_data(user, settings) - return { - "pending_services": user_data.pending_services, - "pending_resources": user_data.pending_resources, - } + query = _get_user_groups(user_id=user.access_token.sub, + approval_status=ApprovalStatusEnum.PENDING) + return db_session.exec(query).all() -@router.post( - "/request/service", - response_model=Dict[str, Any], - responses={ - 400: {"description": "Bad Request - Service already exists"}, - 403: {"description": "Forbidden - User ID mismatch"}, - 500: {"description": "Internal server error"}, - }, -) -async def request_service( - service_request: ServiceRequest, +@router.get("/is-admin") +async def check_is_admin( user: Annotated[SessionUser, Depends(get_current_user)], settings: Annotated[Settings, Depends(get_settings)], -) -> Dict[str, Any]: - """Submit a request for a service.""" - if user.access_token.sub != service_request.user_id: - raise HTTPException( - status_code=403, - detail="User ID in request does not match authenticated user", - ) - - user_data = await get_user_data(user, settings=settings) - - if any(s.id == service_request.id for s in user_data.app_metadata.services): - raise HTTPException( - status_code=400, - detail=f"Service request with ID {service_request.id} already exists", - ) - - new_service = Service( - name=service_request.name, - id=service_request.id, - initial_request_time=datetime.now(timezone.utc), - status="pending", - last_updated=datetime.now(timezone.utc), - updated_by=user.access_token.sub, - resources=[], - ) - - user_data.app_metadata.services.append(new_service) - await update_user_metadata( - user.access_token.sub, - get_management_token(settings=settings), - user_data.app_metadata.model_dump(mode="json"), - ) - - return { - "message": "Service request submitted successfully", - "service": new_service.model_dump(mode="json"), - } +): + """Check if the current user has admin privileges.""" + return {"is_admin": user.is_admin(settings)} -@router.post( - "/request/{service_id}/{resource_id}", - response_model=Dict[str, Any], - responses={ - 400: {"description": "Bad Request"}, - 403: {"description": "Forbidden"}, - 404: {"description": "Service not found"}, - 500: {"description": "Internal server error"}, - }, -) -async def request_resource( - service_id: str, - resource_id: str, - resource_request: ResourceRequest, +@router.get("/all/pending", + response_model=CombinedMembershipData) +async def get_all_pending( user: Annotated[SessionUser, Depends(get_current_user)], - settings: Annotated[Settings, Depends(get_settings)], -) -> Dict[str, Any]: - """Submit a request for a resource within a service.""" - if user.access_token.sub != resource_request.user_id: - raise HTTPException( - status_code=403, - detail="User ID in request does not match authenticated user", - ) - - if service_id != resource_request.service_id: - raise HTTPException( - status_code=400, detail="Service ID in path does not match request body" - ) - - user_data = await get_user_data(user, settings) - service = user_data.app_metadata.get_service_by_id(service_id) - - if not service: - raise HTTPException( - status_code=404, detail=f"Service with ID {service_id} not found" - ) - - if service.status != "approved": - raise HTTPException( - status_code=400, - detail="Cannot request resources for a service that is not approved", - ) - - if any(r.id == resource_id for r in service.resources): - raise HTTPException( - status_code=400, - detail=f"Resource request with ID {resource_id} already exists", - ) - - new_resource = Resource( - name=resource_request.name, id=resource_id, status="pending" - ) - - service.resources.append(new_resource) - service.last_updated = datetime.now(timezone.utc) - service.updated_by = user.access_token.sub - - await update_user_metadata( - user.access_token.sub, - get_management_token(settings=settings), - user_data.app_metadata.model_dump(mode="json"), - ) - - return { - "message": "Resource request submitted successfully", - "service": service.model_dump(mode="json"), - "resource": new_resource.model_dump(mode="json"), - } + db_session: Annotated[Session, Depends(get_db_session)], +): + """Get all pending platforms and groups.""" + platforms_query = _get_user_platforms(user_id=user.access_token.sub, + approval_status=ApprovalStatusEnum.PENDING) + groups_query = _get_user_groups(user_id=user.access_token.sub, + approval_status=ApprovalStatusEnum.PENDING) + platforms = db_session.exec(platforms_query).all() + groups = db_session.exec(groups_query).all() + return {"platforms": platforms, "groups": groups} diff --git a/schemas/__init__.py b/schemas/__init__.py index 5a03bd77..e69de29b 100644 --- a/schemas/__init__.py +++ b/schemas/__init__.py @@ -1,4 +0,0 @@ -from schemas.group import Group -from schemas.service import Resource, Service - -__all__ = ["Service", "Resource", "Group"] diff --git a/schemas/biocommons.py b/schemas/biocommons.py index e74812c4..c3dbf8e7 100644 --- a/schemas/biocommons.py +++ b/schemas/biocommons.py @@ -6,7 +6,7 @@ """ import re -from datetime import datetime, timezone +from datetime import datetime from typing import Annotated, List, Literal, Optional, Self from pydantic import ( @@ -21,8 +21,6 @@ import db import schemas from db.types import GroupMembershipData, PlatformMembershipData -from schemas import Resource, Service -from schemas.service import Group, Identity # From Auth0 password settings ALLOWED_SPECIAL_CHARS = "!@#$%^&*" @@ -96,75 +94,15 @@ class BiocommonsUserMetadata(BaseModel): class BiocommonsAppMetadata(BaseModel): """ - app_metadata we use to manage service/resource requests. + app_metadata we use to store Auth0-specific info Note we expect all app_metadata from Auth0 to match this format (if not empty). """ - - groups: List[Group] = Field(default_factory=list) - services: List[Service] = Field(default_factory=list) registration_from: Optional[AppId] = None - def get_pending_services(self) -> List[Service]: - """Get all pending services.""" - return [s for s in self.services if s.status == "pending"] - - def get_approved_services(self) -> List[Service]: - """Get all approved services.""" - return [s for s in self.services if s.status == "approved"] - - def get_all_resources(self) -> List[Resource]: - """Get all resources across services.""" - return [r for s in self.services for r in s.resources] - - def get_pending_resources(self) -> List[Resource]: - """Get all pending resources.""" - return [r for s in self.services for r in s.resources if r.status == "pending"] - - def get_approved_resources(self) -> List[Resource]: - """Get all approved resources.""" - return [r for s in self.services for r in s.resources if r.status == "approved"] - - def get_service_by_id(self, service_id: str) -> Optional[Service]: - """Get a service by its ID.""" - return next((s for s in self.services if s.id == service_id), None) - - def get_resource_by_id( - self, service_id: str, resource_id: str - ) -> Optional[Resource]: - """Get a resource by its ID.""" - service = self.get_service_by_id(service_id) - if service: - return service.get_resource_by_id(resource_id) - else: - return None - - def approve_service(self, service_id: str, updated_by: str): - """Approve a service by its ID.""" - service = self.get_service_by_id(service_id) - if service: - service.approve(updated_by) - - def revoke_service(self, service_id: str, updated_by: str): - """Revoke a service by its ID.""" - service = self.get_service_by_id(service_id) - if service: - service.revoke(updated_by=updated_by) - - def approve_resource(self, service_id: str, resource_id: str, updated_by: str): - service = self.get_service_by_id(service_id) - if not service: - raise ValueError(f"Service '{service_id}' not found.") - - resource = service.get_resource_by_id(resource_id) - if not resource: - raise ValueError( - f"Resource '{resource_id}' not found in service '{service_id}'." - ) - - resource.status = "approved" - resource.last_updated = datetime.now(timezone.utc) - resource.updated_by = updated_by + model_config = { + "extra": "ignore" + } class BiocommonsRegisterData(BaseModel): @@ -190,7 +128,7 @@ def model_dump(self, **kwargs): @classmethod def from_bpa_registration( - cls, registration: "schemas.bpa.BPARegistrationRequest", bpa_service: Service + cls, registration: "schemas.bpa.BPARegistrationRequest" ) -> Self: return cls( email=registration.email, @@ -201,7 +139,7 @@ def from_bpa_registration( bpa=BPAMetadata(registration_reason=registration.reason), ), app_metadata=BiocommonsAppMetadata( - services=[bpa_service], registration_from="bpa" + registration_from="bpa" ), ) @@ -210,15 +148,6 @@ def from_galaxy_registration( cls, registration: "schemas.galaxy.GalaxyRegistrationData", ): - # Galaxy registration is approved automatically - galaxy_service = Service( - name="Galaxy Australia", - id="galaxy", - initial_request_time=datetime.now(), - status="approved", - last_updated=datetime.now(), - updated_by="", - ) return BiocommonsRegisterData( email=registration.email, username=registration.username, @@ -226,7 +155,7 @@ def from_galaxy_registration( email_verified=False, connection="Username-Password-Authentication", app_metadata=BiocommonsAppMetadata( - services=[galaxy_service], registration_from="galaxy" + registration_from="galaxy" ), ) @@ -248,6 +177,13 @@ def from_biocommons_registration( ) +class Auth0Identity(BaseModel): + connection: str + provider: str + user_id: str + isSocial: bool + + class Auth0UserData(BaseModel): """ Represents the user data we get back from Auth0 for Biocommons users @@ -258,7 +194,7 @@ class Auth0UserData(BaseModel): email: EmailStr username: Optional[BiocommonsUsername] = None email_verified: bool - identities: List[Identity] + identities: List[Auth0Identity] name: str nickname: str picture: HttpUrl @@ -272,26 +208,6 @@ class Auth0UserData(BaseModel): last_login: Optional[datetime] = None logins_count: Optional[int] = None - @property - def pending_services(self) -> List[Service]: - """Get all services with pending status.""" - return self.app_metadata.get_pending_services() - - @property - def approved_services(self) -> List[Service]: - """Get all services with approved status.""" - return self.app_metadata.get_approved_services() - - @property - def pending_resources(self) -> List[Resource]: - """Get all resources with pending status across all services.""" - return self.app_metadata.get_pending_resources() - - @property - def approved_resources(self) -> List[Resource]: - """Get all resources with approved status across all services.""" - return self.app_metadata.get_approved_resources() - class Auth0UserDataWithMemberships(Auth0UserData): """ diff --git a/schemas/requests.py b/schemas/requests.py deleted file mode 100644 index d929c9aa..00000000 --- a/schemas/requests.py +++ /dev/null @@ -1,14 +0,0 @@ -from pydantic import BaseModel - - -class ServiceRequest(BaseModel): - name: str - id: str - user_id: str - - -class ResourceRequest(BaseModel): - name: str - id: str - service_id: str - user_id: str diff --git a/schemas/service.py b/schemas/service.py deleted file mode 100644 index 4d24c97e..00000000 --- a/schemas/service.py +++ /dev/null @@ -1,65 +0,0 @@ -from datetime import datetime -from typing import List, Literal, Optional - -from pydantic import BaseModel, Field - - -class Resource(BaseModel): - name: str - status: Literal["approved", "revoked", "pending"] - id: str - last_updated: Optional[datetime] = None - initial_request_time: Optional[datetime] = None - updated_by: Optional[str] = None - - def approve(self): - self.status = "approved" - - def revoke(self): - self.status = "revoked" - - -class Service(BaseModel): - name: str - id: str - initial_request_time: Optional[datetime] = None - status: Literal["approved", "revoked", "pending"] - last_updated: datetime - updated_by: str - resources: List[Resource] = Field(default_factory=list) - - def approve(self, updated_by: str): - self.status = "approved" - self.updated_by = updated_by - self.last_updated = datetime.now() - - def revoke(self, updated_by: str): - self.status = "revoked" - self.updated_by = updated_by - self.last_updated = datetime.now() - - def approve_resource(self, resource_id: str): - if not self.status == "approved": - raise PermissionError("Service must be approved before approving a resource.") - resource = self.get_resource_by_id(resource_id) - if resource: - resource.approve() - self.last_updated = datetime.now() - return resource - else: - raise ValueError("Resource not found.") - - def get_resource_by_id(self, resource_id: str) -> Optional[Resource]: - return next((r for r in self.resources if r.id == resource_id), None) - - -class Group(BaseModel): - name: str - id: str - - -class Identity(BaseModel): - connection: str - provider: str - user_id: str - isSocial: bool diff --git a/tests/conftest.py b/tests/conftest.py index 2bc2ad5e..d7e9fe17 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -62,9 +62,11 @@ def test_db_session(session): def get_db_session_override(): yield session app.dependency_overrides[get_db_session] = get_db_session_override - yield session - app.dependency_overrides.clear() - session.close() + try: + yield session + finally: + app.dependency_overrides.clear() + session.close() @pytest.fixture(autouse=True) diff --git a/tests/test_admin.py b/tests/test_admin.py index 352964bb..53fbb46f 100644 --- a/tests/test_admin.py +++ b/tests/test_admin.py @@ -1,5 +1,5 @@ import asyncio -from datetime import datetime, timedelta +from datetime import datetime import pytest from fastapi import HTTPException @@ -8,12 +8,11 @@ from auth.management import get_management_token from auth.validator import get_current_user, user_is_admin from auth0.client import Auth0Client +from db.types import ApprovalStatusEnum, PlatformEnum from main import app from routers.admin import PaginationParams -from schemas import Resource, Service from tests.datagen import ( AccessTokenPayloadFactory, - AppMetadataFactory, Auth0UserDataFactory, EmailVerificationResponseFactory, SessionUserFactory, @@ -46,7 +45,7 @@ def test_pagination_params_start_index(): assert params.start_index == 10 -def test_get_users_requires_admin_unauthorized(test_client, mocker): +def test_get_users_requires_admin_unauthorized(test_client): def get_nonadmin_user(): payload = AccessTokenPayloadFactory.build(biocommons_roles=["User"]) return SessionUserFactory.build(access_token=payload) @@ -356,48 +355,48 @@ def test_get_filter_options(test_client, as_admin_user): assert option_dict["bpa_galaxy"] == "Bioplatforms Australia Data Portal & Galaxy Australia Bundle" -def test_get_user(test_client, as_admin_user, mock_auth0_client): - user = Auth0UserDataFactory.build() - mock_auth0_client.get_user.return_value = user - resp = test_client.get(f"/admin/users/{user.user_id}") +def test_get_user(test_client, test_db_session, as_admin_user, persistent_factories): + user = BiocommonsUserFactory.create_sync() + resp = test_client.get(f"/admin/users/{user.id}") assert resp.status_code == 200 assert resp.json() == user.model_dump(mode='json') -def test_get_approved_users(test_client, as_admin_user, mock_auth0_client): - approved_users = Auth0UserDataFactory.batch(3, app_metadata={"services": [{"name": "BPA", "status": "approved"}]}) - mock_auth0_client.get_approved_users.return_value = approved_users +def test_get_approved_users(test_client, test_db_session, as_admin_user, persistent_factories): + approved_users = BiocommonsUserFactory.create_batch_sync(3) + for u in approved_users: + u.add_platform_membership(platform=PlatformEnum.GALAXY, db_session=test_db_session, auto_approve=True) resp = test_client.get("/admin/users/approved") assert resp.status_code == 200 assert len(resp.json()) == 3 - approved_ids = set(u.user_id for u in approved_users) + approved_ids = set(u.id for u in approved_users) for returned_user in resp.json(): - assert returned_user["app_metadata"]["services"][0]["status"] == "approved" - assert returned_user["user_id"] in approved_ids + assert returned_user["id"] in approved_ids -def test_get_pending_users(test_client, as_admin_user, mock_auth0_client): - pending_users = Auth0UserDataFactory.batch(3, app_metadata={"services": [{"name": "BPA", "status": "pending"}]}) - mock_auth0_client.get_pending_users.return_value = pending_users +def test_get_pending_users(test_client, test_db_session, as_admin_user, persistent_factories): + pending_users = BiocommonsUserFactory.create_batch_sync(3) + for u in pending_users: + u.add_platform_membership(platform=PlatformEnum.GALAXY, db_session=test_db_session, auto_approve=False) resp = test_client.get("/admin/users/pending") assert resp.status_code == 200 assert len(resp.json()) == 3 - pending_ids = set(u.user_id for u in pending_users) + expected_ids = set(u.id for u in pending_users) for returned_user in resp.json(): - assert returned_user["app_metadata"]["services"][0]["status"] == "pending" - assert returned_user["user_id"] in pending_ids + assert returned_user["id"] in expected_ids -def test_get_revoked(test_client, as_admin_user, mock_auth0_client): - revoked_users = Auth0UserDataFactory.batch(3, app_metadata={"services": [{"name": "BPA", "status": "revoked"}]}) - mock_auth0_client.get_revoked_users.return_value = revoked_users +def test_get_revoked_users(test_client, test_db_session, as_admin_user, persistent_factories): + revoked_users = BiocommonsUserFactory.create_batch_sync(3) + for u in revoked_users: + PlatformMembershipFactory.create_sync(user=u, platform_id=PlatformEnum.GALAXY, approval_status=ApprovalStatusEnum.REVOKED) + test_db_session.commit() resp = test_client.get("/admin/users/revoked") assert resp.status_code == 200 assert len(resp.json()) == 3 - revoked_ids = set(u.user_id for u in revoked_users) + expected_ids = set(u.id for u in revoked_users) for returned_user in resp.json(): - assert returned_user["app_metadata"]["services"][0]["status"] == "revoked" - assert returned_user["user_id"] in revoked_ids + assert returned_user["id"] in expected_ids # Patch asyncio.run to work in the AnyIO worker thread @@ -410,153 +409,6 @@ def run_in_new_loop(coro): loop.close() -def test_approve_service(test_client, as_admin_user, mock_auth0_client, mocker): - """ - Test that our approved service endpoint tries to update the Auth0 user's metadata. - - Note this is currently pretty clunky due to the need to mock out asyncio.run. - """ - # Build test user and metadata - service = Service( - name="Test Service", - id="service1", - status="pending", - last_updated=FROZEN_TIME - timedelta(hours=1), - updated_by="" - ) - app_metadata = AppMetadataFactory.build(services=[service]) - user = Auth0UserDataFactory.build(app_metadata=app_metadata.model_dump(mode="json")) - approving_user = Auth0UserDataFactory.build() - - # Mock Auth0 client behavior - mock_auth0_client.get_user.side_effect = [user, approving_user] - - # Patch update_user_metadata as an AsyncMock - mock_update = mocker.patch( - "routers.admin.update_user_metadata", - new_callable=mocker.AsyncMock, - return_value={"status": "ok", "updated": True} - ) - - mocker.patch("routers.admin.asyncio.run", side_effect=run_in_new_loop) - - # Make the API call - resp = test_client.post(f"/admin/users/{user.user_id}/services/{service.id}/approve") - - # Validate HTTP response - assert resp.status_code == 200 - mock_update.assert_awaited_once() - - # Validate that update_user_metadata was called with correct data - args, kwargs = mock_update.call_args - assert kwargs["user_id"] == user.user_id - assert kwargs["token"] == mock_auth0_client.management_token - assert "services" in kwargs["metadata"] - service_data = kwargs["metadata"]["services"][0] - assert service_data["status"] == "approved" - assert service_data["id"] == service.id - assert service_data["updated_by"] == approving_user.email - - -def test_revoke_service(test_client, as_admin_user, mock_auth0_client, mocker): - """ - Test that our approved service endpoint tries to update the Auth0 user's metadata. - - Note this is currently pretty clunky due to the need to mock out asyncio.run. - """ - resource1 = Resource(name="Test Resource", id="resource1", status="approved") - resource2 = Resource(name="Test Resource", id="resource2", status="approved") - service = Service( - name="Test Service", - id="service1", - status="approved", - last_updated=FROZEN_TIME - timedelta(hours=1), - updated_by="", - resources=[resource1, resource2] - ) - app_metadata = AppMetadataFactory.build(services=[service]) - user = Auth0UserDataFactory.build(app_metadata=app_metadata.model_dump(mode="json")) - revoking_user = Auth0UserDataFactory.build() - - # Mock Auth0 client behavior - mock_auth0_client.get_user.side_effect = [user, revoking_user] - - # Patch update_user_metadata as an AsyncMock - mock_update = mocker.patch( - "routers.admin.update_user_metadata", - new_callable=mocker.AsyncMock, - return_value={"status": "ok", "updated": True} - ) - - mocker.patch("routers.admin.asyncio.run", side_effect=run_in_new_loop) - - # Make the API call - resp = test_client.post(f"/admin/users/{user.user_id}/services/{service.id}/revoke") - - # Validate HTTP response - assert resp.status_code == 200 - mock_update.assert_awaited_once() - - # Validate that update_user_metadata was called with correct data - args, kwargs = mock_update.call_args - assert kwargs["user_id"] == user.user_id - assert kwargs["token"] == mock_auth0_client.management_token - assert "services" in kwargs["metadata"] - service_data = kwargs["metadata"]["services"][0] - assert service_data["status"] == "revoked" - assert service_data["id"] == service.id - assert service_data["updated_by"] == revoking_user.email - for resource in service_data["resources"]: - assert resource["status"] == "revoked" - - -def test_approve_resource(test_client, as_admin_user, mock_auth0_client, mocker): - """ - Test that our approve resource endpoint tries to update the Auth0 user's metadata. - """ - # Build test user and metadata - resource = Resource(name="Test Resource", id="resource1", status="pending") - service = Service( - name="Test Service", - id="service1", - status="approved", - last_updated=FROZEN_TIME - timedelta(hours=1), - resources=[resource], - updated_by="" - ) - app_metadata = AppMetadataFactory.build(services=[service]) - user = Auth0UserDataFactory.build(app_metadata=app_metadata.model_dump(mode="json")) - - # Mock Auth0 client behavior - mock_auth0_client.get_user.return_value = user - - # Patch update_user_metadata as an AsyncMock - mock_update = mocker.patch( - "routers.admin.update_user_metadata", - new_callable=mocker.AsyncMock, - return_value={"status": "ok", "updated": True} - ) - - mocker.patch("routers.admin.asyncio.run", side_effect=run_in_new_loop) - - # Make the API call - resp = test_client.post(f"/admin/users/{user.user_id}/services/{service.id}/resources/{resource.id}/approve") - - # Validate HTTP response - assert resp.status_code == 200 - mock_update.assert_awaited_once() - - # Validate that update_user_metadata was called with correct data - args, kwargs = mock_update.call_args - assert kwargs["user_id"] == user.user_id - assert kwargs["token"] == mock_auth0_client.management_token - assert "services" in kwargs["metadata"] - service_data = kwargs["metadata"]["services"][0] - resource_data = service_data["resources"][0] - assert resource_data["status"] == "approved" - assert resource_data["id"] == resource.id - - def test_resend_verification_email(test_client, as_admin_user, mock_auth0_client): user = Auth0UserDataFactory.build() response_data = EmailVerificationResponseFactory.build() @@ -586,20 +438,13 @@ def test_get_user_details(test_client, test_db_session, as_admin_user, mock_auth assert platforms[0] == platform_membership_data -def test_get_unverified_users(test_client, as_admin_user, mock_auth0_client): - u1 = Auth0UserDataFactory.build(email_verified=False) - u2 = Auth0UserDataFactory.build(email_verified=False) - mock_auth0_client.get_users.return_value = [u1, u2] - - resp = test_client.get("/admin/users/unverified?page=2&per_page=10") +def test_get_unverified_users(test_client, test_db_session, as_admin_user, persistent_factories): + BiocommonsUserFactory.create_batch_sync(2, email_verified=True) + BiocommonsUserFactory.create_batch_sync(3, email_verified=False) + resp = test_client.get("/admin/users/unverified") assert resp.status_code == 200 - - mock_auth0_client.get_users.assert_called_once_with( - page=2, per_page=10, q="email_verified:false" - ) - data = resp.json() - assert len(data) == 2 + assert len(data) == 3 assert all(u["email_verified"] is False for u in data) diff --git a/tests/test_bpa_register.py b/tests/test_bpa_register.py index 7ce0426e..0278db60 100644 --- a/tests/test_bpa_register.py +++ b/tests/test_bpa_register.py @@ -1,5 +1,3 @@ -from datetime import UTC, datetime - import httpx import pytest from sqlmodel import select @@ -10,7 +8,6 @@ PlatformMembership, PlatformMembershipHistory, ) -from schemas import Service from schemas.biocommons import BiocommonsRegisterData from tests.datagen import ( Auth0UserDataFactory, @@ -33,15 +30,8 @@ def valid_registration_data(): def test_to_biocommons_register_data(valid_registration_data): bpa_data = BPARegistrationDataFactory.build() - bpa_service = Service( - name="Bioplatforms Australia", - id="bpa", - status="approved", - last_updated=datetime.now(UTC), - updated_by="system", - ) register_data = BiocommonsRegisterData.from_bpa_registration( - bpa_data, bpa_service=bpa_service + bpa_data ) assert register_data.username == bpa_data.username assert register_data.name == bpa_data.fullname @@ -83,43 +73,13 @@ def test_successful_registration( assert not called_data.email_verified app_metadata = called_data.app_metadata - assert len(app_metadata.services) == 1 - bpa_service = app_metadata.services[0] - assert bpa_service.name == "Bioplatforms Australia Data Portal" - assert bpa_service.status == "pending" - assert bpa_service.last_updated is not None - assert bpa_service.updated_by == "system" - assert len(bpa_service.resources) == 0 + assert app_metadata.registration_from == "bpa" assert ( called_data.user_metadata.bpa.registration_reason == valid_registration_data["reason"] ) -def test_service_and_resources_have_updated_by_system(): - service = Service( - name="Test Service", - id="svc1", - status="pending", - last_updated=datetime.now(UTC), - updated_by="system", - resources=[ - { - "id": "res1", - "name": "Test Resource", - "status": "pending", - "last_updated": datetime.now(UTC), - "updated_by": "system", - "initial_request_time": datetime.now(UTC), - } - ], - ) - assert service.updated_by == "system" - assert service.resources[0].updated_by == "system" - assert hasattr(service.resources[0], "initial_request_time") - assert isinstance(service.resources[0].initial_request_time, datetime) - - def test_registration_duplicate_user( test_client, valid_registration_data, mock_auth0_client ): diff --git a/tests/test_models.py b/tests/test_models.py deleted file mode 100644 index ae290273..00000000 --- a/tests/test_models.py +++ /dev/null @@ -1,129 +0,0 @@ -from datetime import datetime, timedelta - -import pytest -from freezegun import freeze_time - -from schemas.service import Resource, Service -from tests.datagen import AppMetadataFactory - -FROZEN_TIME = datetime(2025, 1, 1, 12, 0, 0) - - -@pytest.fixture -def frozen_time(): - """ - Freeze time so datetime.now() returns FROZEN_TIME. - """ - with freeze_time("2025-01-01 12:00:00"): - yield - - -def test_approve_service(frozen_time): - """Test we can approve a service and set metadata correctly.""" - service = Service(name="Test Service", id="service1", status="pending", - last_updated=FROZEN_TIME - timedelta(hours=1), updated_by="") - service.approve(updated_by="admin@example.com") - assert service.status == "approved" - assert service.updated_by == "admin@example.com" - assert service.last_updated == FROZEN_TIME - - -def test_approve_service_from_app_metadata(frozen_time): - """Test we can approve a service by ID from BiocommonsAppMetadata.""" - service = Service(name="Test Service", id="service1", status="pending", - last_updated=FROZEN_TIME - timedelta(hours=1), updated_by="") - other = Service(name="Other Service", id="service2", status="pending", - last_updated=FROZEN_TIME - timedelta(hours=1), updated_by="") - app_metadata = AppMetadataFactory.build(services=[service, other]) - app_metadata.approve_service(service_id="service1", updated_by="admin@example.com") - assert service.status == "approved" - assert service.updated_by == "admin@example.com" - assert service.last_updated == FROZEN_TIME - assert other.status == "pending" - - -def test_revoke_service(frozen_time): - """Test we can revoke a service and set metadata correctly.""" - service = Service(name="Test Service", id="service1", status="approved", - last_updated=FROZEN_TIME - timedelta(hours=1), updated_by="") - service.revoke(updated_by="admin@example.com") - assert service.status == "revoked" - assert service.updated_by == "admin@example.com" - assert service.last_updated == FROZEN_TIME - - -def test_revoke_service_from_app_metadata(frozen_time): - """Test we can revoke a service by ID from BiocommonsAppMetadata.""" - service = Service(name="Test Service", id="service1", status="approved", - last_updated=FROZEN_TIME - timedelta(hours=1), updated_by="") - other = Service(name="Other Service", id="service2", status="approved", - last_updated=FROZEN_TIME - timedelta(hours=1), updated_by="") - app_metadata = AppMetadataFactory.build(services=[service, other]) - app_metadata.revoke_service(service_id="service1", updated_by="admin@example.com") - assert service.status == "revoked" - assert service.updated_by == "admin@example.com" - assert service.last_updated == FROZEN_TIME - assert other.status == "approved" - - -def test_approve_resource(frozen_time): - resource = Resource( - name="Test Resource", - id="resource1", - status="pending", - initial_request_time=FROZEN_TIME - ) - resource.approve() - assert resource.status == "approved" - assert resource.initial_request_time == FROZEN_TIME - - -def test_approve_resource_from_service(frozen_time): - """Test that trying to approve a resource from a pending service raises an error.""" - resource = Resource( - name="Test Resource", - id="resource1", - status="pending", - initial_request_time=FROZEN_TIME - ) - service = Service(name="Test Service", id="service1", status="approved", - last_updated=FROZEN_TIME - timedelta(hours=1), updated_by="", - resources=[resource]) - service.approve_resource(resource_id="resource1") - assert resource.status == "approved" - assert resource.initial_request_time == FROZEN_TIME - - -def test_approve_resource_from_pending_service(frozen_time): - """Test that trying to approve a resource from a pending service raises an error.""" - resource = Resource( - name="Test Resource", - id="resource1", - status="pending", - initial_request_time=FROZEN_TIME - ) - - service = Service(name="Test Service", id="service1", status="pending", - last_updated=FROZEN_TIME - timedelta(hours=1), updated_by="", - resources=[resource]) - with pytest.raises(PermissionError, match="Service must be approved before approving a resource."): - service.approve_resource(resource_id="resource1") - assert resource.status == "pending" - assert resource.initial_request_time == FROZEN_TIME - - -def test_approve_resource_from_app_metadata(frozen_time): - resource = Resource( - name="Test Resource", - id="resource1", - status="pending", - initial_request_time=FROZEN_TIME - ) - service = Service(name="Test Service", id="service1", status="approved", - last_updated=FROZEN_TIME - timedelta(hours=1), updated_by="", - resources=[resource]) - app_metadata = AppMetadataFactory.build(services=[service]) - app_metadata.approve_resource(service_id="service1", resource_id="resource1", updated_by="admin@example.com") - - assert resource.status == "approved" - assert resource.initial_request_time == FROZEN_TIME diff --git a/tests/test_user.py b/tests/test_user.py index eba70332..23211deb 100644 --- a/tests/test_user.py +++ b/tests/test_user.py @@ -1,11 +1,18 @@ -from datetime import datetime - import pytest -from fastapi import HTTPException +from db.types import ApprovalStatusEnum, PlatformEnum from schemas.biocommons import BiocommonsAppMetadata -from schemas.service import Group, Resource, Service -from tests.datagen import AccessTokenPayloadFactory, Auth0UserDataFactory +from tests.datagen import ( + AccessTokenPayloadFactory, + Auth0UserDataFactory, + SessionUserFactory, +) +from tests.db.datagen import ( + BiocommonsGroupFactory, + BiocommonsUserFactory, + GroupMembershipFactory, + PlatformMembershipFactory, +) # --- Test Fixtures --- @@ -31,32 +38,7 @@ def auth_headers(): def mock_user_data(): """Fixture to provide mock user data""" return Auth0UserDataFactory.build( - app_metadata=BiocommonsAppMetadata( - groups=[Group(name="Australian University", id="AU")], - services=[ - Service( - id="service1", - name="Service 1", - status="approved", - last_updated=datetime.now(), - updated_by="test@example.com", - resources=[ - Resource(id="resource1", name="Resource 1", status="approved"), - Resource(id="resource2", name="Resource 2", status="pending"), - ], - ), - Service( - id="service2", - name="Service 2", - status="pending", - last_updated=datetime.now(), - updated_by="test@example.com", - resources=[ - Resource(id="resource3", name="Resource 3", status="pending") - ], - ), - ], - ), + app_metadata=BiocommonsAppMetadata(registration_from="biocommons"), ) @@ -64,12 +46,13 @@ def mock_user_data(): @pytest.mark.parametrize( "endpoint", [ - "/me/services", - "/me/services/approved", - "/me/services/pending", - "/me/resources", - "/me/resources/approved", - "/me/resources/pending", + "/me/is-admin", + "/me/platforms", + "/me/platforms/approved", + "/me/platforms/pending", + "/me/groups", + "/me/groups/approved", + "/me/groups/pending", "/me/all/pending", ], ) @@ -80,409 +63,6 @@ def test_endpoints_require_auth(endpoint, test_client): assert response.json() == {"detail": "Not authenticated"} -# --- Service Endpoints (GET) --- -def test_get_all_services( - mock_auth_token, auth_headers, mock_user_data, mocker, - test_client -): - """Test getting all services""" - mocker.patch( - "routers.user.get_user_data", # Changed from fetch_user_data - return_value=mock_user_data, - ) - mocker.patch( - "routers.user.get_management_token", return_value="mock_management_token" - ) - - response = test_client.get("/me/services", headers=auth_headers) - assert response.status_code == 200 - - expected_services = [ - s.model_dump(mode="json") for s in mock_user_data.app_metadata.services - ] - assert response.json() == {"services": expected_services} - - -def test_get_approved_services( - mock_auth_token, auth_headers, mock_user_data, mocker, - test_client -): - """Test getting approved services""" - mocker.patch( - "routers.user.get_user_data", # Changed from fetch_user_data - return_value=mock_user_data, - ) - mocker.patch( - "routers.user.get_management_token", return_value="mock_management_token" - ) - - response = test_client.get("/me/services/approved", headers=auth_headers) - assert response.status_code == 200 - - approved_services = [ - s.model_dump(mode="json") - for s in mock_user_data.app_metadata.services - if s.status == "approved" - ] - assert response.json() == {"approved_services": approved_services} - - -def test_get_pending_services( - mock_auth_token, auth_headers, mock_user_data, mocker, - test_client -): - """Test getting pending services""" - mocker.patch( - "routers.user.get_user_data", # Changed from fetch_user_data - return_value=mock_user_data, - ) - mocker.patch( - "routers.user.get_management_token", return_value="mock_management_token" - ) - - response = test_client.get("/me/services/pending", headers=auth_headers) - assert response.status_code == 200 - - pending_services = [ - s.model_dump(mode="json") - for s in mock_user_data.app_metadata.services - if s.status == "pending" - ] - assert response.json() == {"pending_services": pending_services} - - -def test_get_services_failed_fetch( - mock_auth_token, auth_headers, mocker, - test_client -): - """Test handling of failed API calls""" - mocker.patch( - "routers.user.get_user_data", # Changed from fetch_user_data - side_effect=HTTPException(status_code=403, detail="Failed to fetch user data"), - ) - mocker.patch( - "routers.user.get_management_token", return_value="mock_management_token" - ) - - response = test_client.get("/me/services", headers=auth_headers) - assert response.status_code == 403 - assert response.json() == {"detail": "Failed to fetch user data"} - - -def test_get_services_empty_metadata( - mock_auth_token, auth_headers, mocker, test_client -): - """Test handling of empty metadata""" - empty_user = Auth0UserDataFactory.build( - app_metadata=BiocommonsAppMetadata(services=[], groups=[]), - ) - mocker.patch("routers.user.get_user_data", return_value=empty_user) - mocker.patch( - "routers.user.get_management_token", return_value="mock_management_token" - ) - - response = test_client.get("/me/services", headers=auth_headers) - assert response.status_code == 200 - assert response.json() == {"services": []} - - -def test_get_services_no_metadata( - mock_auth_token, auth_headers, mocker, - test_client -): - """Test handling of missing metadata""" - no_metadata_user = Auth0UserDataFactory.build(app_metadata=BiocommonsAppMetadata()) - mocker.patch("routers.user.get_user_data", return_value=no_metadata_user) - mocker.patch( - "routers.user.get_management_token", return_value="mock_management_token" - ) - - response = test_client.get("/me/services", headers=auth_headers) - assert response.status_code == 200 - assert response.json() == {"services": []} - - -# --- Resource Endpoints (GET) --- -def test_get_all_resources( - mock_auth_token, auth_headers, mock_user_data, mocker, - test_client -): - """Test getting all resources""" - mocker.patch( - "routers.user.get_user_data", # Changed from fetch_user_data - return_value=mock_user_data, - ) - mocker.patch( - "routers.user.get_management_token", return_value="mock_management_token" - ) - - response = test_client.get("/me/resources", headers=auth_headers) - assert response.status_code == 200 - all_resources = [ - r.model_dump() - for s in mock_user_data.app_metadata.services - for r in s.resources - ] - assert response.json() == {"resources": all_resources} - - -def test_get_approved_resources( - mock_auth_token, auth_headers, mock_user_data, mocker, - test_client -): - """Test getting approved resources""" - mocker.patch( - "routers.user.get_user_data", # Changed from fetch_user_data - return_value=mock_user_data, - ) - mocker.patch( - "routers.user.get_management_token", return_value="mock_management_token" - ) - - response = test_client.get("/me/resources/approved", headers=auth_headers) - assert response.status_code == 200 - approved_resources = [ - r.model_dump() - for s in mock_user_data.app_metadata.services - for r in s.resources - if r.status == "approved" - ] - assert response.json() == {"approved_resources": approved_resources} - - -def test_get_resources_empty_metadata( - mock_auth_token, auth_headers, mocker, test_client -): - """Test handling of empty resource metadata""" - empty_user = Auth0UserDataFactory.build(app_metadata=BiocommonsAppMetadata(services=[], groups=[]), - ) - mocker.patch("routers.user.get_user_data", return_value=empty_user) - mocker.patch( - "routers.user.get_management_token", return_value="mock_management_token" - ) - - response = test_client.get("/me/resources", headers=auth_headers) - assert response.status_code == 200 - assert response.json() == {"resources": []} - - -def test_get_resources_no_metadata( - mock_auth_token, auth_headers, mocker, - test_client -): - """Test handling of missing resource metadata""" - no_metadata_user = Auth0UserDataFactory.build(app_metadata=BiocommonsAppMetadata()) - mocker.patch("routers.user.get_user_data", return_value=no_metadata_user) - mocker.patch( - "routers.user.get_management_token", return_value="mock_management_token" - ) - - response = test_client.get("/me/resources", headers=auth_headers) - assert response.status_code == 200 - assert response.json() == {"resources": []} - - -# --- Service Request Endpoints (POST) --- -def test_request_service_success( - mock_auth_token, auth_headers, mock_user_data, mocker, - test_client -): - """Test successful service request""" - mocker.patch("routers.user.get_user_data", return_value=mock_user_data) - mocker.patch( - "routers.user.get_management_token", return_value="mock_management_token" - ) - mocker.patch("routers.user.update_user_metadata", return_value={}) - - new_service = { - "name": "New Service", - "id": "service3", - "user_id": mock_auth_token.sub, - } - - response = test_client.post( - "/me/request/service", json=new_service, headers=auth_headers - ) - assert response.status_code == 200 - assert response.json()["message"] == "Service request submitted successfully" - assert response.json()["service"]["id"] == "service3" - - -def test_request_service_duplicate( - mock_auth_token, auth_headers, mock_user_data, mocker, - test_client -): - """Test duplicate service request""" - mocker.patch("routers.user.get_user_data", return_value=mock_user_data) - mocker.patch( - "routers.user.get_management_token", return_value="mock_management_token" - ) - - existing_service = { - "name": "Service 1", - "id": "service1", - "user_id": mock_auth_token.sub, - } - - response = test_client.post( - "/me/request/service", json=existing_service, headers=auth_headers - ) - assert response.status_code == 400 - assert ( - response.json()["detail"] == "Service request with ID service1 already exists" - ) - - -def test_request_service_user_mismatch( - mock_auth_token, auth_headers, mock_user_data, - test_client -): - """Test service request with mismatched user""" - request_payload = { - "name": "Service Mismatch", - "id": "svc-mismatch", - "user_id": "auth0|WRONG_USER", - } - - response = test_client.post( - "/me/request/service", json=request_payload, headers=auth_headers - ) - assert response.status_code == 403 - assert ( - response.json()["detail"] - == "User ID in request does not match authenticated user" - ) - - -# --- Resource Request Endpoints (POST) --- -def test_request_resource_success( - mock_auth_token, auth_headers, mock_user_data, mocker, - test_client -): - """Test successful resource request""" - mocker.patch("routers.user.get_user_data", return_value=mock_user_data) - mocker.patch( - "routers.user.get_management_token", return_value="mock_management_token" - ) - mocker.patch("routers.user.update_user_metadata", return_value={}) - - request_payload = { - "name": "New Resource", - "id": "resource-new", - "user_id": mock_auth_token.sub, - "service_id": "service1", - } - - response = test_client.post( - "/me/request/service1/resource-new", json=request_payload, headers=auth_headers - ) - assert response.status_code == 200 - assert response.json()["resource"]["id"] == "resource-new" - - -def test_request_resource_user_mismatch( - mock_auth_token, auth_headers, mock_user_data, - test_client -): - """Test resource request with mismatched user""" - request_payload = { - "name": "Invalid Resource", - "id": "res-invalid", - "user_id": "wrong-user", - "service_id": "service1", - } - - response = test_client.post( - "/me/request/service1/res-invalid", json=request_payload, headers=auth_headers - ) - assert response.status_code == 403 - assert ( - response.json()["detail"] - == "User ID in request does not match authenticated user" - ) - - -def test_request_resource_non_approved_service( - mock_auth_token, auth_headers, mock_user_data, mocker, - test_client -): - """Test resource request for non-approved service""" - mocker.patch("routers.user.get_user_data", return_value=mock_user_data) - mocker.patch( - "routers.user.get_management_token", return_value="mock_management_token" - ) - - request_payload = { - "name": "Blocked Resource", - "id": "blocked-resource", - "user_id": mock_auth_token.sub, - "service_id": "service2", - } - - response = test_client.post( - "/me/request/service2/blocked-resource", - json=request_payload, - headers=auth_headers, - ) - assert response.status_code == 400 - assert ( - response.json()["detail"] - == "Cannot request resources for a service that is not approved" - ) - - -def test_request_resource_duplicate( - mock_auth_token, auth_headers, mock_user_data, mocker, - test_client -): - """Test duplicate resource request""" - mocker.patch("routers.user.get_user_data", return_value=mock_user_data) - mocker.patch( - "routers.user.get_management_token", return_value="mock_management_token" - ) - - existing_resource = { - "name": "Resource 1", - "id": "resource1", - "user_id": mock_auth_token.sub, - "service_id": "service1", - } - - response = test_client.post( - "/me/request/service1/resource1", json=existing_resource, headers=auth_headers - ) - assert response.status_code == 400 - assert ( - response.json()["detail"] == "Resource request with ID resource1 already exists" - ) - - -def test_request_resource_invalid_service( - mock_auth_token, auth_headers, mock_user_data, mocker, - test_client -): - """Test resource request for non-existent service""" - mocker.patch("routers.user.get_user_data", return_value=mock_user_data) - mocker.patch( - "routers.user.get_management_token", return_value="mock_management_token" - ) - - request_payload = { - "name": "Invalid Service Resource", - "id": "resource-invalid", - "user_id": mock_auth_token.sub, - "service_id": "non-existent-service", - } - - response = test_client.post( - "/me/request/non-existent-service/resource-invalid", - json=request_payload, - headers=auth_headers, - ) - assert response.status_code == 404 - assert response.json()["detail"] == "Service with ID non-existent-service not found" - - def test_check_is_admin_with_admin_role(test_client, mock_settings, mocker): """Test that admin check returns True for users with admin role""" from tests.datagen import SessionUserFactory @@ -529,3 +109,133 @@ def test_check_is_admin_without_authentication(test_client): """Test that admin check requires authentication""" response = test_client.get("/me/is-admin") assert response.status_code == 401 + + +def _act_as_user(mocker, db_user): + """ + Set up mocks so that the test client authenticates as the given user + """ + access_token = AccessTokenPayloadFactory.build(sub=db_user.id) + auth0_user = SessionUserFactory.build(access_token=access_token) + mocker.patch("auth.validator.verify_jwt", return_value=access_token) + mocker.patch("auth.validator.get_current_user", return_value=auth0_user) + return auth0_user + + +def test_get_platforms(test_client, test_db_session, mocker, persistent_factories): + """Test that endpoint returns list of platforms""" + user = BiocommonsUserFactory.create_sync() + PlatformMembershipFactory.create_sync(user=user, platform_id=PlatformEnum.GALAXY, approval_status=ApprovalStatusEnum.APPROVED) + PlatformMembershipFactory.create_sync(user=user, platform_id=PlatformEnum.BPA_DATA_PORTAL, approval_status=ApprovalStatusEnum.PENDING) + test_db_session.flush() + _act_as_user(mocker, user) + response = test_client.get("/me/platforms", headers={"Authorization": "Bearer valid_token"}) + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + ids = [platform["platform_id"] for platform in data] + assert PlatformEnum.GALAXY in ids + assert PlatformEnum.BPA_DATA_PORTAL in ids + + +def test_get_approved_platforms(test_client, test_db_session, mocker, persistent_factories): + """Test that endpoint returns list of approved platforms""" + user = BiocommonsUserFactory.create_sync() + PlatformMembershipFactory.create_sync(user=user, platform_id=PlatformEnum.GALAXY, approval_status=ApprovalStatusEnum.APPROVED) + PlatformMembershipFactory.create_sync(user=user, platform_id=PlatformEnum.BPA_DATA_PORTAL, approval_status=ApprovalStatusEnum.PENDING) + test_db_session.flush() + _act_as_user(mocker, user) + response = test_client.get("/me/platforms/approved", headers={"Authorization": "Bearer valid_token"}) + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + ids = [platform["platform_id"] for platform in data] + assert PlatformEnum.GALAXY in ids + assert data[0]["approval_status"] == "approved" + + +def test_get_pending_platforms(test_client, test_db_session, mocker, persistent_factories): + """Test that endpoint returns list of pending platforms""" + user = BiocommonsUserFactory.create_sync() + PlatformMembershipFactory.create_sync(user=user, platform_id=PlatformEnum.GALAXY, approval_status=ApprovalStatusEnum.APPROVED) + PlatformMembershipFactory.create_sync(user=user, platform_id=PlatformEnum.BPA_DATA_PORTAL, approval_status=ApprovalStatusEnum.PENDING) + test_db_session.flush() + _act_as_user(mocker, user) + response = test_client.get("/me/platforms/pending", headers={"Authorization": "Bearer valid_token"}) + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + ids = [platform["platform_id"] for platform in data] + assert PlatformEnum.BPA_DATA_PORTAL in ids + assert data[0]["approval_status"] == "pending" + + +def test_get_groups(test_client, test_db_session, mocker, persistent_factories): + """Test that endpoint returns list of groups""" + user = BiocommonsUserFactory.create_sync() + groups = BiocommonsGroupFactory.create_batch_sync(size=2) + GroupMembershipFactory.create_sync(user=user, group=groups[0], approval_status=ApprovalStatusEnum.APPROVED) + GroupMembershipFactory.create_sync(user=user, group=groups[1], approval_status=ApprovalStatusEnum.PENDING) + test_db_session.flush() + _act_as_user(mocker, user) + response = test_client.get("/me/groups", headers={"Authorization": "Bearer valid_token"}) + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + ids = [group["group_id"] for group in data] + assert all(group.group_id in ids for group in groups) + names = [group["group_name"] for group in data] + assert all(group.name in names for group in groups) + + +def test_get_approved_groups(test_client, test_db_session, mocker, persistent_factories): + """Test that endpoint returns list of approved groups""" + user = BiocommonsUserFactory.create_sync() + groups = BiocommonsGroupFactory.create_batch_sync(size=2) + approved_group = GroupMembershipFactory.create_sync(user=user, group=groups[0], approval_status=ApprovalStatusEnum.APPROVED) + GroupMembershipFactory.create_sync(user=user, group=groups[1], approval_status=ApprovalStatusEnum.PENDING) + test_db_session.flush() + _act_as_user(mocker, user) + response = test_client.get("/me/groups/approved", headers={"Authorization": "Bearer valid_token"}) + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + ids = [group["group_id"] for group in data] + assert approved_group.group_id in ids + + +def test_get_pending_groups(test_client, test_db_session, mocker, persistent_factories): + """Test that endpoint returns list of pending groups""" + user = BiocommonsUserFactory.create_sync() + groups = BiocommonsGroupFactory.create_batch_sync(size=2) + GroupMembershipFactory.create_sync(user=user, group=groups[0], approval_status=ApprovalStatusEnum.APPROVED) + pending_group = GroupMembershipFactory.create_sync(user=user, group=groups[1], approval_status=ApprovalStatusEnum.PENDING) + test_db_session.flush() + _act_as_user(mocker, user) + response = test_client.get("/me/groups/pending", headers={"Authorization": "Bearer valid_token"}) + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + ids = [group["group_id"] for group in data] + assert pending_group.group_id in ids + + +def test_get_all_pending(test_client, test_db_session, mocker, persistent_factories): + """Test that endpoint returns combined list of pending groups and platforms""" + user = BiocommonsUserFactory.create_sync() + groups = BiocommonsGroupFactory.create_batch_sync(size=2) + GroupMembershipFactory.create_sync(user=user, group=groups[0], approval_status=ApprovalStatusEnum.APPROVED) + pending_group = GroupMembershipFactory.create_sync(user=user, group=groups[1], approval_status=ApprovalStatusEnum.PENDING) + PlatformMembershipFactory.create_sync(user=user, platform_id=PlatformEnum.GALAXY, approval_status=ApprovalStatusEnum.APPROVED) + pending_platform = PlatformMembershipFactory.create_sync(user=user, platform_id=PlatformEnum.BPA_DATA_PORTAL, approval_status=ApprovalStatusEnum.PENDING) + test_db_session.flush() + _act_as_user(mocker, user) + response = test_client.get("/me/all/pending", headers={"Authorization": "Bearer valid_token"}) + assert response.status_code == 200 + data = response.json() + assert len(data["groups"]) == 1 + group_ids = [group["group_id"] for group in data["groups"]] + assert pending_group.group_id in group_ids + assert len(data["platforms"]) == 1 + platform_ids = [platform["platform_id"] for platform in data["platforms"]] + assert pending_platform.platform_id in platform_ids