From bc8618d87a07797e1f14faa6f5d510e8b9ff66b0 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Tue, 12 Aug 2025 09:34:19 +1000 Subject: [PATCH 01/23] Separate organization parsing into a helper function --- routers/bpa_register.py | 43 ++++++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/routers/bpa_register.py b/routers/bpa_register.py index 8088fd9d..063befbf 100644 --- a/routers/bpa_register.py +++ b/routers/bpa_register.py @@ -34,6 +34,28 @@ def send_approval_email(registration: BPARegistrationRequest, bpa_resources: lis email_service.send(approver_email, subject, body_html) +def _get_bpa_resources(registration: BPARegistrationRequest, settings: Settings, update_time: datetime) -> list[Resource]: + bpa_resources = [] + for org_id, is_selected in registration.organizations.items(): + if not is_selected: + continue + if org_id not in settings.organizations: + raise HTTPException( + status_code=400, detail=f"Invalid organization ID: {org_id}" + ) + resource = Resource( + id=org_id, + name=settings.organizations[org_id], + status="pending", + last_updated=update_time, + initial_request_time=update_time, + updated_by="system", + ).model_dump(mode="json") + bpa_resources.append(resource) + return bpa_resources + + + @router.post( "/register", response_model=Dict[str, Any], @@ -55,26 +77,7 @@ async def register_bpa_user( now = datetime.now(timezone.utc) - # Create BPA resources - bpa_resources = [] - for org_id, is_selected in registration.organizations.items(): - if not is_selected: - continue - if org_id not in settings.organizations: - raise HTTPException( - status_code=400, detail=f"Invalid organization ID: {org_id}" - ) - resource = Resource( - id=org_id, - name=settings.organizations[org_id], - status="pending", - last_updated=now, - initial_request_time=now, - updated_by="system", - ).model_dump(mode="json") - bpa_resources.append(resource) - - # Create BPA service + bpa_resources = _get_bpa_resources(registration, settings, update_time=now) bpa_service = Service( name="Bioplatforms Australia Data Portal", id="bpa", From de3d7081374d1b45feb5e672434cf3bb3274f653 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Tue, 12 Aug 2025 11:32:50 +1000 Subject: [PATCH 02/23] create_user() endpoint for auth0 client --- auth0/client.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/auth0/client.py b/auth0/client.py index a704b187..202759d1 100644 --- a/auth0/client.py +++ b/auth0/client.py @@ -8,7 +8,7 @@ from auth.management import get_management_token from config import Settings, get_settings -from schemas.biocommons import Auth0UserData +from schemas.biocommons import Auth0UserData, BiocommonsRegisterData class RoleData(BaseModel): @@ -94,6 +94,12 @@ def get_user(self, user_id: str) -> Auth0UserData: resp = self._client.get(url) return Auth0UserData(**resp.json()) + def create_user(self, user: BiocommonsRegisterData) -> Auth0UserData: + url = f"https://{self.domain}/api/v2/users" + resp = self._client.post(url, json=user.model_dump(mode="json")) + resp.raise_for_status() + return Auth0UserData(**resp.json()) + def add_roles_to_user(self, user_id: str, role_id: str | list[str]): """ Add one or more roles to a user. The role(s) must already exist. From 3639f90e4fbef23a4d138c81c69dc38e04ff5727 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Tue, 12 Aug 2025 11:33:21 +1000 Subject: [PATCH 03/23] Create DB user from Auth0 data - separate creation from the API call --- db/models.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/db/models.py b/db/models.py index 741f608c..8a69e613 100644 --- a/db/models.py +++ b/db/models.py @@ -10,6 +10,7 @@ from auth0.client import Auth0Client from db.core import BaseModel +from schemas.biocommons import Auth0UserData from schemas.user import SessionUser @@ -44,14 +45,17 @@ class BiocommonsUser(BaseModel, table=True): ) @classmethod - def create_from_auth0(cls, auth0_id: str, auth0_client: Auth0Client): + def create_from_auth0(cls, auth0_id: str, auth0_client: Auth0Client) -> Self: user_data = auth0_client.get_user(user_id=auth0_id) - user = cls( - id=auth0_id, - email=user_data.email, - username=user_data.username + return cls.from_auth0_data(user_data) + + @classmethod + def from_auth0_data(cls, data: Auth0UserData) -> Self: + return cls( + id=data.user_id, + email=data.email, + username=data.username ) - return user @classmethod def get_or_create(cls, auth0_id: str, db_session: Session, auth0_client: Auth0Client) -> Self: From bf0586da6f5a7da574d06ab3f35ab799cfc121e4 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Tue, 12 Aug 2025 11:35:19 +1000 Subject: [PATCH 04/23] Docstrings for create user methods --- db/models.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/db/models.py b/db/models.py index 8a69e613..356da23b 100644 --- a/db/models.py +++ b/db/models.py @@ -46,11 +46,17 @@ class BiocommonsUser(BaseModel, table=True): @classmethod def create_from_auth0(cls, auth0_id: str, auth0_client: Auth0Client) -> Self: + """ + Get user data from Auth0 API and create a new BiocommonsUser object. + """ user_data = auth0_client.get_user(user_id=auth0_id) return cls.from_auth0_data(user_data) @classmethod def from_auth0_data(cls, data: Auth0UserData) -> Self: + """ + Create a new BiocommonsUser object from Auth0 user data (no API call). + """ return cls( id=data.user_id, email=data.email, From 60150f510c9b7ad08a0ab188f463e6354e7f388a Mon Sep 17 00:00:00 2001 From: marius-mather Date: Tue, 12 Aug 2025 11:39:06 +1000 Subject: [PATCH 05/23] Update BPA register function to add user to the DB --- routers/bpa_register.py | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/routers/bpa_register.py b/routers/bpa_register.py index 063befbf..098cb944 100644 --- a/routers/bpa_register.py +++ b/routers/bpa_register.py @@ -1,16 +1,21 @@ +import logging from datetime import datetime, timezone from typing import Any, Dict from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException -from httpx import AsyncClient +from sqlmodel import Session -from auth.management import get_management_token from auth.ses import EmailService +from auth0.client import Auth0Client, get_auth0_client from config import Settings, get_settings +from db.models import BiocommonsUser +from db.setup import get_db_session from schemas.biocommons import BiocommonsRegisterData from schemas.bpa import BPARegistrationRequest from schemas.service import Resource, Service +logger = logging.getLogger(__name__) + router = APIRouter(prefix="/bpa", tags=["bpa", "registration"]) @@ -68,13 +73,11 @@ def _get_bpa_resources(registration: BPARegistrationRequest, settings: Settings, async def register_bpa_user( registration: BPARegistrationRequest, background_tasks: BackgroundTasks, - settings: Settings = Depends(get_settings) + settings: Settings = Depends(get_settings), + db_session: Session = Depends(get_db_session), + auth0_client: Auth0Client = Depends(get_auth0_client) ) -> Dict[str, Any]: """Register a new BPA user with selected organization resources.""" - url = f"https://{settings.auth0_domain}/api/v2/users" - token = get_management_token(settings=settings) - headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} - now = datetime.now(timezone.utc) bpa_resources = _get_bpa_resources(registration, settings, update_time=now) @@ -94,20 +97,18 @@ async def register_bpa_user( ) try: - async with AsyncClient() as client: - response = await client.post( - url, headers=headers, json=user_data.model_dump(mode="json") - ) - if response.status_code != 201: - raise HTTPException( - status_code=400, - detail=f"Registration failed: {response.json()['message']}", - ) + logger.info("Registering user with Auth0") + auth0_user_data = auth0_client.create_user(user_data) + + logger.info("Adding user to DB") + db_user = BiocommonsUser.from_auth0_data(data=auth0_user_data) + db_session.add(db_user) + db_session.commit() if bpa_resources and settings.send_email: background_tasks.add_task(send_approval_email, registration, bpa_resources) - return {"message": "User registered successfully", "user": response.json()} + return {"message": "User registered successfully", "user": auth0_user_data} except HTTPException: raise From 33c9ed22b5eb7aed491e123f064a1093a2e96f5f Mon Sep 17 00:00:00 2001 From: marius-mather Date: Tue, 12 Aug 2025 11:39:59 +1000 Subject: [PATCH 06/23] Update registration test --- tests/test_bpa_register.py | 79 +++++++++++++++++++++++--------------- 1 file changed, 49 insertions(+), 30 deletions(-) diff --git a/tests/test_bpa_register.py b/tests/test_bpa_register.py index 6d04b794..4de6bf6a 100644 --- a/tests/test_bpa_register.py +++ b/tests/test_bpa_register.py @@ -3,9 +3,17 @@ import pytest +from auth0.client import get_auth0_client +from db.models import BiocommonsUser +from main import app from schemas import Service from schemas.biocommons import BiocommonsRegisterData -from tests.datagen import AccessTokenPayloadFactory, BPARegistrationDataFactory +from tests.datagen import ( + AccessTokenPayloadFactory, + Auth0UserDataFactory, + BPARegistrationDataFactory, + random_auth0_id, +) @pytest.fixture @@ -33,6 +41,17 @@ def mock_auth_token(mocker): return token +@pytest.fixture +def override_auth0_client(mocker): + def override_auth0_client(): + return mock_client + + mock_client = mocker.patch("routers.utils.Auth0Client")() + app.dependency_overrides[get_auth0_client] = override_auth0_client + yield mock_client + app.dependency_overrides.clear() + + def test_to_biocommons_register_data(valid_registration_data): bpa_data = BPARegistrationDataFactory.build() bpa_service = Service( @@ -51,16 +70,13 @@ def test_to_biocommons_register_data(valid_registration_data): def test_successful_registration( - test_client_with_email, mock_auth_token, mocker, valid_registration_data, + test_client_with_email, mocker, valid_registration_data, + override_auth0_client, mock_auth0_client, test_db_session ): """Test successful user registration with BPA service""" test_client = test_client_with_email - mock_response = MagicMock() - mock_response.status_code = 201 - mock_response.json.return_value = {"user_id": "auth0|123"} - - mock_post = mocker.patch("httpx.AsyncClient.post", return_value=mock_response) - + user_id = random_auth0_id() + mock_auth0_client.create_user.return_value = Auth0UserDataFactory.build(user_id=user_id) mock_email_cls = mocker.patch("routers.bpa_register.EmailService", autospec=True) mock_email_cls.return_value.send.return_value = None @@ -70,30 +86,33 @@ def test_successful_registration( assert response.json()["message"] == "User registered successfully" mock_email_cls.return_value.send.assert_called_once() - - called_data = mock_post.call_args[1]["json"] - assert called_data["email"] == valid_registration_data["email"] - assert called_data["username"] == valid_registration_data["username"] - assert called_data["name"] == valid_registration_data["fullname"] - - app_metadata = called_data["app_metadata"] - assert len(app_metadata["services"]) == 1 - bpa_service = app_metadata["services"][0] - assert bpa_service["name"] == "Bioplatforms Australia Data Portal" - assert bpa_service["status"] == "pending" - assert "last_updated" in bpa_service - assert "updated_by" in bpa_service - assert bpa_service["updated_by"] == "system" - assert len(bpa_service["resources"]) == 2 - - for resource in bpa_service["resources"]: - assert "last_updated" in resource - assert "updated_by" in resource - assert "initial_request_time" in resource - assert resource["updated_by"] == "system" + # Check user is created in the database + db_user = test_db_session.get(BiocommonsUser, user_id) + assert db_user is not None + assert db_user.id == user_id + + called_data = mock_auth0_client.create_user.call_args[0][0] + assert called_data.email == valid_registration_data["email"] + assert called_data.username == valid_registration_data["username"] + assert called_data.name == valid_registration_data["fullname"] + assert not called_data.email_verified + + app_metadata = called_data.app_metadata + assert len(app_metadata.services) == 1 + bpa_service = app_metadata.services[0] + assert bpa_service.name == "Bioplatforms Australia Data Portal" + assert bpa_service.status == "pending" + assert bpa_service.last_updated is not None + assert bpa_service.updated_by == "system" + assert len(bpa_service.resources) == 2 + + for resource in bpa_service.resources: + assert resource.last_updated is not None + assert resource.initial_request_time is not None + assert resource.updated_by == "system" assert ( - called_data["user_metadata"]["bpa"]["registration_reason"] + called_data.user_metadata.bpa.registration_reason == valid_registration_data["reason"] ) From a5ea2b6568926ab62c5f802cf71775256618b1d3 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Tue, 12 Aug 2025 13:00:49 +1000 Subject: [PATCH 07/23] Update error handling in BPA register endpoint --- routers/bpa_register.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/routers/bpa_register.py b/routers/bpa_register.py index 098cb944..96106e26 100644 --- a/routers/bpa_register.py +++ b/routers/bpa_register.py @@ -3,6 +3,7 @@ from typing import Any, Dict from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException +from httpx import HTTPStatusError from sqlmodel import Session from auth.ses import EmailService @@ -110,8 +111,8 @@ async def register_bpa_user( return {"message": "User registered successfully", "user": auth0_user_data} - except HTTPException: - raise + except HTTPStatusError as e: + raise HTTPException(status_code=e.response.status_code, detail=e.response.text) except Exception as e: raise HTTPException( status_code=500, detail=f"Failed to register user: {str(e)}" From e1e8af01f6cbba42a17af95a2e9738be58bcf119 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Tue, 12 Aug 2025 13:01:33 +1000 Subject: [PATCH 08/23] Always override get_management_token when using test_client --- tests/conftest.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 1e103a5e..137dd969 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -121,9 +121,10 @@ def test_client(mock_settings, mock_galaxy_settings): def override_settings(): return mock_settings - # Apply override + # Apply overrides app.dependency_overrides[get_settings] = override_settings app.dependency_overrides[get_galaxy_settings] = lambda: mock_galaxy_settings + app.dependency_overrides[get_management_token] = lambda: "mock_token" # Create client client = TestClient(app) From f446e65e7f62ab0e009409aefee2e9270a670cd0 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Tue, 12 Aug 2025 13:05:29 +1000 Subject: [PATCH 09/23] Update BPA registration tests --- tests/test_bpa_register.py | 98 ++++++++++++++++---------------------- 1 file changed, 40 insertions(+), 58 deletions(-) diff --git a/tests/test_bpa_register.py b/tests/test_bpa_register.py index 4de6bf6a..66a04d0c 100644 --- a/tests/test_bpa_register.py +++ b/tests/test_bpa_register.py @@ -1,6 +1,6 @@ from datetime import UTC, datetime -from unittest.mock import MagicMock +import httpx import pytest from auth0.client import get_auth0_client @@ -9,7 +9,6 @@ from schemas import Service from schemas.biocommons import BiocommonsRegisterData from tests.datagen import ( - AccessTokenPayloadFactory, Auth0UserDataFactory, BPARegistrationDataFactory, random_auth0_id, @@ -29,18 +28,6 @@ def valid_registration_data(): ).model_dump() -@pytest.fixture -def mock_auth_token(mocker): - token = AccessTokenPayloadFactory.build( - sub="auth0|123456789", - biocommons_roles=["acdc/indexd_admin"], - ) - mocker.patch("auth.validator.verify_jwt", return_value=token) - mocker.patch("auth.management.get_management_token", return_value="mock_token") - mocker.patch("routers.bpa_register.get_management_token", return_value="mock_token") - return token - - @pytest.fixture def override_auth0_client(mocker): def override_auth0_client(): @@ -139,31 +126,35 @@ def test_service_and_resources_have_updated_by_system(): assert hasattr(service.resources[0], "initial_request_time") assert isinstance(service.resources[0].initial_request_time, datetime) + def test_registration_duplicate_user( - test_client, mock_auth_token, mocker, valid_registration_data + test_client, valid_registration_data, override_auth0_client, mock_auth0_client ): """Test registration with duplicate user""" - mock_response = MagicMock() - mock_response.status_code = 409 - mock_response.json.return_value = {"message": "User already exists"} - - mocker.patch("httpx.AsyncClient.post", return_value=mock_response) + error = httpx.HTTPStatusError( + "User already exists", + request=httpx.Request("POST", "https://api.example.com/data"), + response=httpx.Response(409, text="Registration failed: User already exists"), + ) + mock_auth0_client.create_user.side_effect = error response = test_client.post("/bpa/register", json=valid_registration_data) - assert response.status_code == 400 + assert response.status_code == 409 assert response.json()["detail"] == "Registration failed: User already exists" def test_registration_auth0_error( - test_client, mock_auth_token, mocker, valid_registration_data + test_client, mock_auth0_client, valid_registration_data ): """Test registration with Auth0 API error""" - mock_response = MagicMock() - mock_response.status_code = 400 - mock_response.json.return_value = {"message": "Invalid request"} + error = httpx.HTTPStatusError( + "User already exists", + request=httpx.Request("POST", "https://api.example.com/data"), + response=httpx.Response(400, text="Registration failed: Invalid request"), + ) + mock_auth0_client.create_user.side_effect = error - mocker.patch("httpx.AsyncClient.post", return_value=mock_response) response = test_client.post("/bpa/register", json=valid_registration_data) @@ -172,7 +163,7 @@ def test_registration_auth0_error( def test_registration_with_invalid_organization( - test_client, mock_auth_token, mocker, valid_registration_data + test_client, valid_registration_data ): """Test registration with invalid organization ID""" data = valid_registration_data.copy() @@ -198,7 +189,7 @@ def test_registration_request_validation(test_client): def test_no_selected_organizations( - test_client, mock_auth_token, mocker, valid_registration_data + test_client, test_db_session, mock_auth0_client, valid_registration_data ): """Test registration with no organizations selected""" data = valid_registration_data.copy() @@ -207,40 +198,33 @@ def test_no_selected_organizations( "cipps": False, "ausarg": False, } - - mock_response = MagicMock() - mock_response.status_code = 201 - mock_response.json.return_value = {"user_id": "auth0|123"} - - mock_post = mocker.patch("httpx.AsyncClient.post", return_value=mock_response) + user_data = Auth0UserDataFactory.build() + mock_auth0_client.create_user.return_value = user_data response = test_client.post("/bpa/register", json=data) assert response.status_code == 200 - called_data = mock_post.call_args[1]["json"] - bpa_service = called_data["app_metadata"]["services"][0] - assert len(bpa_service["resources"]) == 0 + # Check user data sent to Auth0 + called_data = mock_auth0_client.create_user.call_args[0][0] + bpa_service = called_data.app_metadata.services[0] + assert len(bpa_service.resources) == 0 def test_empty_organizations_dict( - test_client, mock_auth_token, mocker, valid_registration_data + test_client, test_db_session, mock_auth0_client, valid_registration_data ): """Test registration with empty organizations dictionary""" data = valid_registration_data.copy() data["organizations"] = {} - - mock_response = MagicMock() - mock_response.status_code = 201 - mock_response.json.return_value = {"user_id": "auth0|123"} - - mock_post = mocker.patch("httpx.AsyncClient.post", return_value=mock_response) + user_data = Auth0UserDataFactory.build() + mock_auth0_client.create_user.return_value = user_data response = test_client.post("/bpa/register", json=data) assert response.status_code == 200 - called_data = mock_post.call_args[1]["json"] - bpa_service = called_data["app_metadata"]["services"][0] - assert len(bpa_service["resources"]) == 0 + called_data = mock_auth0_client.create_user.call_args[0][0] + bpa_service = called_data.app_metadata.services[0] + assert len(bpa_service.resources) == 0 def test_registration_email_format(test_client, valid_registration_data): @@ -256,29 +240,27 @@ def test_registration_email_format(test_client, valid_registration_data): def test_all_organizations_selected( test_client_with_email, - mock_auth_token, mock_settings, - mocker, + mocker, + override_auth0_client, + mock_auth0_client, valid_registration_data, ): """Test registration with all organizations selected""" - test_client = test_client_with_email data = valid_registration_data.copy() data["organizations"] = {k: True for k in mock_settings.organizations.keys()} - mock_response = MagicMock() - mock_response.status_code = 201 - mock_response.json.return_value = {"user_id": "auth0|123"} - mock_post = mocker.patch("httpx.AsyncClient.post", return_value=mock_response) + user_data = Auth0UserDataFactory.build() + mock_auth0_client.create_user.return_value = user_data email_service_cls = mocker.patch("routers.bpa_register.EmailService", autospec=True) email_service_cls.return_value.send.return_value = True - response = test_client.post("/bpa/register", json=data) + response = test_client_with_email.post("/bpa/register", json=data) assert response.status_code == 200 - called_data = mock_post.call_args[1]["json"] - bpa_service = called_data["app_metadata"]["services"][0] - assert len(bpa_service["resources"]) == len(mock_settings.organizations) + called_data = mock_auth0_client.create_user.call_args[0][0] + bpa_service = called_data.app_metadata.services[0] + assert len(bpa_service.resources) == len(mock_settings.organizations) email_service_cls.return_value.send.assert_called_once() From 622d6f9f5840b2cfe87af07053c78caf90c6cd5c Mon Sep 17 00:00:00 2001 From: marius-mather Date: Tue, 12 Aug 2025 16:20:06 +1000 Subject: [PATCH 10/23] Rework Galaxy registration to add a user record in the DB --- routers/galaxy_register.py | 35 ++++++++++++++++------------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/routers/galaxy_register.py b/routers/galaxy_register.py index cc0eeb09..3fb3d88b 100644 --- a/routers/galaxy_register.py +++ b/routers/galaxy_register.py @@ -4,9 +4,12 @@ import httpx from fastapi import APIRouter, Header, HTTPException from fastapi.params import Depends +from sqlmodel import Session -from auth.management import get_management_token +from auth0.client import Auth0Client, get_auth0_client from config import Settings, get_settings +from db.models import BiocommonsUser +from db.setup import get_db_session from galaxy.client import GalaxyClient, get_galaxy_client from register.tokens import create_registration_token, verify_registration_token from schemas.biocommons import BiocommonsRegisterData @@ -29,6 +32,8 @@ def register( registration_data: GalaxyRegistrationData, settings: Annotated[Settings, Depends(get_settings)], galaxy_client: Annotated[GalaxyClient, Depends(get_galaxy_client)], + auth0_client: Annotated[Auth0Client, Depends(get_auth0_client)], + db_session: Annotated[Session, Depends(get_db_session)], registration_token: Optional[str] = Header(None), ): if not registration_token: @@ -47,21 +52,13 @@ def register( except httpx.HTTPError as e: logger.warning(f"Failed to check username in Galaxy: {e}") - url = f"https://{settings.auth0_domain}/api/v2/users" - logger.debug("Getting management token.") - management_token = get_management_token(settings=settings) - headers = {"Authorization": f"Bearer {management_token}"} - logger.debug("Registering with Auth0 management API") - resp = httpx.post( - url, - # Use exclude_none so we don't include username/name fields - # when not specified, Auth0 doesn't like this - json=user_data.model_dump( - mode="json", - exclude_none=True - ), - headers=headers - ) - if resp.status_code != 201: - raise HTTPException(status_code=400, detail=f'Registration failed: {resp.json()["message"]}') - return {"message": "User registered successfully", "user": resp.json()} + try: + logger.info("Registering user with Auth0") + auth0_user_data = auth0_client.create_user(user_data) + except httpx.HTTPStatusError as e: + raise HTTPException(status_code=e.response.status_code, detail=f'Registration failed: {e}') + logger.info("Adding user to DB") + db_user = BiocommonsUser.from_auth0_data(data=auth0_user_data) + db_session.add(db_user) + db_session.commit() + return {"message": "User registered successfully", "user": auth0_user_data.model_dump(mode="json")} From 124396d329660597ea72215488273be152a00108 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Tue, 12 Aug 2025 16:21:23 +1000 Subject: [PATCH 11/23] Improved test fixtures: better naming of Auth0 clients, freezegun compatibility --- tests/biocommons/test_api.py | 34 +++++++++++++-------------------- tests/biocommons/test_groups.py | 8 ++++---- tests/conftest.py | 28 +++++++++++++++++++++++++-- 3 files changed, 43 insertions(+), 27 deletions(-) diff --git a/tests/biocommons/test_api.py b/tests/biocommons/test_api.py index c19827d0..bdc24015 100644 --- a/tests/biocommons/test_api.py +++ b/tests/biocommons/test_api.py @@ -9,7 +9,6 @@ from sqlmodel import select from auth.validator import get_current_user -from auth0.client import get_auth0_client from db.models import ( Auth0Role, BiocommonsGroup, @@ -31,18 +30,11 @@ ) -@pytest.fixture -def override_auth0_client(auth0_client): - app.dependency_overrides[get_auth0_client] = lambda: auth0_client - yield - app.dependency_overrides.clear() - - @respx.mock -def test_create_group(test_client, as_admin_user, override_auth0_client, test_db_session, persistent_factories): +def test_create_group(test_client, as_admin_user, test_auth0_client, test_db_session, persistent_factories): # Mock Auth0 response to check group exists mock_group = RoleDataFactory.build(name="biocommons/group/tsi") - route = respx.get("https://auth0.example.com/api/v2/roles", params={"name_filter": ANY}).mock( + route = respx.get(f"https://{test_auth0_client.domain}/api/v2/roles", params={"name_filter": ANY}).mock( return_value=Response(200, json=[mock_group.model_dump(mode="json")]) ) @@ -66,14 +58,14 @@ def test_create_group(test_client, as_admin_user, override_auth0_client, test_db @pytest.mark.parametrize("role_name", ["biocommons/role/tsi/admin", "biocommons/group/tsi"]) @respx.mock -def test_create_role(role_name, test_client, as_admin_user, override_auth0_client, test_db_session, mocker): +def test_create_role(role_name, test_client, as_admin_user, test_auth0_client, test_db_session, mocker): """ Test we can create Auth0 roles using either the format for roles or groups. """ mock_resp = RoleDataFactory.build(name=role_name) # Patch check of existing role mocker.patch("auth0.client.Auth0Client.get_role_by_name", side_effect=ValueError) - route = respx.post("https://auth0.example.com/api/v2/roles").mock( + route = respx.post(f"https://{test_auth0_client.domain}/api/v2/roles").mock( return_value=Response(200, json=mock_resp.model_dump(mode="json")) ) resp = test_client.post( @@ -90,7 +82,7 @@ def test_create_role(role_name, test_client, as_admin_user, override_auth0_clien @pytest.mark.parametrize("role_name", ["biocommons/role/tsi/admin", "biocommons/group/tsi"]) -def test_create_role_already_exists(role_name, test_client, as_admin_user, override_auth0_client, test_db_session, mocker): +def test_create_role_already_exists(role_name, test_client, test_auth0_client, as_admin_user, test_db_session, mocker): """ Test we can add existing Auth0 roles to the DB """ @@ -111,7 +103,7 @@ def test_create_role_already_exists(role_name, test_client, as_admin_user, overr @respx.mock -def test_request_group_membership(test_client_with_email, normal_user, as_normal_user, override_auth0_client, test_db_session, persistent_factories, mock_email_service, mocker): +def test_request_group_membership(test_client_with_email, normal_user, as_normal_user, mock_auth0_client, test_db_session, persistent_factories, mock_email_service, mocker): """ Test the full process of requesting group membership - request membership for a user and send approval email to the relevant admins. @@ -122,7 +114,7 @@ def test_request_group_membership(test_client_with_email, normal_user, as_normal user = BiocommonsUserFactory.create_sync(group_memberships=[], id=normal_user.access_token.sub) # Mock an admin that has the required admin role (to send approval email to) admin_info = Auth0UserDataFactory.build(email="admin@example.com") - mocker.patch("db.models.Auth0Client.get_all_role_users", return_value=[admin_info]) + mock_auth0_client.get_all_role_users.return_value = [admin_info] # Request membership resp = test_client.post( "/biocommons/groups/request", @@ -147,7 +139,7 @@ def test_request_group_membership(test_client_with_email, normal_user, as_normal @respx.mock -def test_approve_group_membership(test_client, test_db_session, persistent_factories, override_auth0_client): +def test_approve_group_membership(test_client, test_db_session, persistent_factories, test_auth0_client): """ Test the full approval process, including: * Checking that the group admin has the required role @@ -165,13 +157,13 @@ def test_approve_group_membership(test_client, test_db_session, persistent_facto app.dependency_overrides[get_current_user] = lambda: group_admin # Mock auth0 route for adding roles respx.get( - "https://auth0.example.com/api/v2/roles", + f"https://{test_auth0_client.domain}/api/v2/roles", params={"name_filter": group.group_id} ).respond( 200, json=[RoleDataFactory.build(name=group.group_id).model_dump(mode="json")] ) - route = respx.post(f"https://auth0.example.com/api/v2/users/{user.id}/roles").respond(204) + route = respx.post(f"https://{test_auth0_client.domain}/api/v2/users/{user.id}/roles").respond(204) # Call our group approval endpoint resp = test_client.post( "/biocommons/groups/approve", @@ -190,15 +182,15 @@ def test_approve_group_membership(test_client, test_db_session, persistent_facto assert membership_request.updated_by.email == group_admin.access_token.email -def test_approve_group_membership_invalid_role(test_client, test_db_session, persistent_factories, override_auth0_client): +def test_approve_group_membership_invalid_role(test_client, test_db_session, persistent_factories, test_auth0_client): admin_role = Auth0RoleFactory.create_sync(name="biocommons/role/tsi/admin") group = BiocommonsGroupFactory.create_sync(group_id="biocommons/group/tsi", admin_roles=[admin_role]) access_token = AccessTokenPayloadFactory.build(biocommons_roles=["biocommons/role/biocommons/sysadmin"]) - unauth_admin = SessionUserFactory.build(access_token=access_token) + unauthorized_admin = SessionUserFactory.build(access_token=access_token) user = Auth0UserDataFactory.build() GroupMembershipFactory.create_sync(group=group, user_id=user.user_id, approval_status="pending") # Override get_current_user to return the group admin - app.dependency_overrides[get_current_user] = lambda: unauth_admin + app.dependency_overrides[get_current_user] = lambda: unauthorized_admin resp = test_client.post( "/biocommons/groups/approve", json={ diff --git a/tests/biocommons/test_groups.py b/tests/biocommons/test_groups.py index e48bd19f..b291ea9b 100644 --- a/tests/biocommons/test_groups.py +++ b/tests/biocommons/test_groups.py @@ -51,7 +51,7 @@ def test_biocommons_group_create(): assert group.group_id == "biocommons/group/tsi" -def test_biocommons_group_create_save(test_db_session, auth0_client): +def test_biocommons_group_create_save(test_db_session, test_auth0_client): """ Test saving BiocommonsGroupCreate object to the DB """ @@ -63,7 +63,7 @@ def test_biocommons_group_create_save(test_db_session, auth0_client): name="Threatened Species Initiative", admin_roles=[tsi_admin, sysadmin] ) - group.save(test_db_session, auth0_client=auth0_client) + group.save(test_db_session, auth0_client=test_auth0_client) group_from_db = test_db_session.exec( select(BiocommonsGroup).where(BiocommonsGroup.group_id == group.group_id) ).one() @@ -71,7 +71,7 @@ def test_biocommons_group_create_save(test_db_session, auth0_client): @respx.mock -def test_biocommons_group_save_get_roles(test_db_session, auth0_client, mocker): +def test_biocommons_group_save_get_roles(test_db_session, test_auth0_client): """ Test saving BiocommonsGroupCreate to the DB when roles have to be fetched from Auth0. @@ -85,7 +85,7 @@ def test_biocommons_group_save_get_roles(test_db_session, auth0_client, mocker): name="Threatened Species Initiative", admin_roles=["biocommons/role/tsi/admin"] ) - group.save(test_db_session, auth0_client=auth0_client) + group.save(test_db_session, auth0_client=test_auth0_client) group_from_db = test_db_session.exec( select(BiocommonsGroup).where(BiocommonsGroup.group_id == group.group_id) ).one() diff --git a/tests/conftest.py b/tests/conftest.py index 137dd969..476298f1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,12 @@ import os +from datetime import datetime from unittest.mock import MagicMock import pytest from fastapi.testclient import TestClient from moto import mock_aws from moto.core import patch_client +from polyfactory import BaseFactory from sqlmodel import Session, StaticPool, create_engine from auth.management import get_management_token @@ -207,12 +209,23 @@ def mock_galaxy_client(): @pytest.fixture -def auth0_client(): - return Auth0Client(domain="auth0.example.com", management_token="dummy-token") +def test_auth0_client(): + """ + Don't mock the Auth0Client, just return a dummy one. You will need to + mock/patch the actual calls to Auth0. + """ + auth0_client = Auth0Client(domain="auth0.example.com", management_token="dummy-token") + app.dependency_overrides[get_auth0_client] = lambda: auth0_client + yield auth0_client + app.dependency_overrides.clear() @pytest.fixture def mock_auth0_client(mocker): + """ + Fully mocked Auth0Client - use when we want to just patch the results + of Auth0 calls + """ mock_client = mocker.patch("auth0.client.Auth0Client") app.dependency_overrides[get_auth0_client] = lambda: mock_client yield mock_client @@ -255,3 +268,14 @@ def mock_email_service(aws_credentials): app.dependency_overrides[get_email_service] = lambda: email_service yield email_service app.dependency_overrides.clear() + + +def now_freeze_aware(tz=None): + from datetime import datetime # local import to ensure freezegun patches are seen + return datetime.now(tz) if tz else datetime.now() + + +@pytest.fixture(autouse=True, scope="session") +def freezegun_polyfactory_compat(): + # Use frozen time when freezegun is active, otherwise real time + BaseFactory.add_provider(datetime, now_freeze_aware) From ebe7f6d5c4c82e2537b685a11234671e1ee2bf12 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Tue, 12 Aug 2025 16:23:27 +1000 Subject: [PATCH 12/23] Clean up auth0 and other fixtures in tests --- tests/db/test_models.py | 18 +++++++-------- tests/test_auth0_client.py | 28 ++++++++++++------------ tests/test_bpa_register.py | 20 +++-------------- tests/test_galaxy.py | 45 +++++++------------------------------- 4 files changed, 34 insertions(+), 77 deletions(-) diff --git a/tests/db/test_models.py b/tests/db/test_models.py index 36fc4775..dfeb2077 100644 --- a/tests/db/test_models.py +++ b/tests/db/test_models.py @@ -225,18 +225,18 @@ def test_create_auth0_role(test_db_session): @respx.mock -def test_create_auth0_role_by_name(test_db_session, auth0_client): +def test_create_auth0_role_by_name(test_db_session, test_auth0_client): """ Test when can create an auth0 role by name, looking up the role in Auth0 first """ role_data = RoleDataFactory.build(name="biocommons/role/tsi/admin") - respx.get("https://auth0.example.com/api/v2/roles", params={"name_filter": ANY}).mock( + respx.get(f"https://{test_auth0_client.domain}/api/v2/roles", params={"name_filter": ANY}).mock( return_value=Response(200, json=[role_data.model_dump(mode="json")]) ) Auth0Role.get_or_create_by_name( name=role_data.name, session=test_db_session, - auth0_client=auth0_client + auth0_client=test_auth0_client ) role_from_db = test_db_session.exec( select(Auth0Role).where(Auth0Role.id == role_data.id) @@ -256,7 +256,7 @@ def test_get_or_create_auth0_role_existing(test_db_session, mock_auth0_client, p @respx.mock -def test_create_auth0_role_by_id(test_db_session, auth0_client): +def test_create_auth0_role_by_id(test_db_session, test_auth0_client): """ Test when can create an auth0 role by id, looking up the role in Auth0 first """ @@ -267,7 +267,7 @@ def test_create_auth0_role_by_id(test_db_session, auth0_client): Auth0Role.get_or_create_by_id( auth0_id=role_data.id, session=test_db_session, - auth0_client=auth0_client + auth0_client=test_auth0_client ) role_from_db = test_db_session.exec( select(Auth0Role).where(Auth0Role.id == role_data.id) @@ -296,7 +296,7 @@ def test_create_biocommons_group(test_db_session, persistent_factories): @respx.mock -def test_group_membership_grant_auth0_role(auth0_client, persistent_factories): +def test_group_membership_grant_auth0_role(test_auth0_client, persistent_factories): group = BiocommonsGroupFactory.create_sync(group_id="biocommons/group/tsi", admin_roles=[]) user = BiocommonsUserFactory.create_sync(group_memberships=[]) role_data = RoleDataFactory.build(name="biocommons/group/tsi") @@ -307,19 +307,19 @@ def test_group_membership_grant_auth0_role(auth0_client, persistent_factories): params={"name_filter": group.group_id} ).respond(status_code=200, json=[role_data.model_dump(mode="json")]) route = respx.post(f"https://auth0.example.com/api/v2/users/{user.id}/roles").respond(status_code=200) - result = membership_request.grant_auth0_role(auth0_client) + result = membership_request.grant_auth0_role(test_auth0_client) assert result assert route.called @pytest.mark.parametrize("status", ["pending", "revoked"]) @respx.mock -def test_group_membership_grant_auth0_role_not_approved(status, auth0_client, persistent_factories): +def test_group_membership_grant_auth0_role_not_approved(status, test_auth0_client, persistent_factories): group = BiocommonsGroupFactory.create_sync(group_id="biocommons/group/tsi", admin_roles=[]) user = Auth0UserDataFactory.build() membership_request = GroupMembershipFactory.create_sync(group=group, user_id=user.user_id, approval_status=status) with pytest.raises(ValueError): - membership_request.grant_auth0_role(auth0_client) + membership_request.grant_auth0_role(test_auth0_client) def test_group_membership_save_with_history(test_db_session): diff --git a/tests/test_auth0_client.py b/tests/test_auth0_client.py index a97dbe2d..c1d1e20d 100644 --- a/tests/test_auth0_client.py +++ b/tests/test_auth0_client.py @@ -11,26 +11,26 @@ @respx.mock -def test_get_users_no_pagination(auth0_client): +def test_get_users_no_pagination(test_auth0_client): user = Auth0UserDataFactory.build() route = respx.get("https://auth0.example.com/api/v2/users").mock( return_value=Response(200, json=[user.model_dump(mode="json")]) ) - result = auth0_client.get_users() + result = test_auth0_client.get_users() assert route.called assert result[0].model_dump(mode="json") == user.model_dump(mode="json") @respx.mock -def test_get_users_with_pagination(auth0_client): +def test_get_users_with_pagination(test_auth0_client): user = Auth0UserDataFactory.build() route = respx.get("https://auth0.example.com/api/v2/users").respond( 200, json=[user.model_dump(mode="json")] ) - result = auth0_client.get_users(page=2, per_page=25) + result = test_auth0_client.get_users(page=2, per_page=25) # Validate the actual request request = route.calls[0].request @@ -41,14 +41,14 @@ def test_get_users_with_pagination(auth0_client): @respx.mock -def test_get_user_by_id(auth0_client): +def test_get_user_by_id(test_auth0_client): user_id = "auth0|789" user = Auth0UserDataFactory.build(user_id=user_id) route = respx.get(f"https://auth0.example.com/api/v2/users/{user_id}").mock( return_value=Response(200, json=user.model_dump(mode="json")) ) - result = auth0_client.get_user(user_id) + result = test_auth0_client.get_user(user_id) assert route.called assert result.model_dump(mode="json") == user.model_dump(mode="json") @@ -63,13 +63,13 @@ def test_get_user_by_id(auth0_client): ] ) @respx.mock -def test_search_users_methods(auth0_client, method, query): +def test_search_users_methods(test_auth0_client, method, query): user = Auth0UserDataFactory.build() route = respx.get("https://auth0.example.com/api/v2/users").respond( 200, json=[user.model_dump(mode="json")] ) - result = getattr(auth0_client, method)(page=3, per_page=50) + result = getattr(test_auth0_client, method)(page=3, per_page=50) assert route.called request = route.calls[0].request @@ -81,7 +81,7 @@ def test_search_users_methods(auth0_client, method, query): @respx.mock -def test_get_role_users(auth0_client): +def test_get_role_users(test_auth0_client): """ Test we can get users for a role from Auth0 API """ @@ -90,7 +90,7 @@ def test_get_role_users(auth0_client): route = respx.get(f"https://auth0.example.com/api/v2/roles/{role_id}/users").respond( 200, json=[u.model_dump(mode="json") for u in users] ) - result = auth0_client.get_role_users(role_id) + result = test_auth0_client.get_role_users(role_id) assert route.called assert len(result) == 3 for user in result: @@ -99,7 +99,7 @@ def test_get_role_users(auth0_client): @respx.mock -def test_get_all_role_users(auth0_client): +def test_get_all_role_users(test_auth0_client): """ Test we can get all users for a role from Auth0 API, automatically running through multiple pages if necessary. @@ -112,21 +112,21 @@ def test_get_all_role_users(auth0_client): side_effect=[Response(200, json=batch1.model_dump(mode="json")), Response(200, json=batch2.model_dump(mode="json"))] ) - result = auth0_client.get_all_role_users(role_id) + result = test_auth0_client.get_all_role_users(role_id) assert route.called assert route.call_count == 2 assert len(result) == 150 @respx.mock -def test_add_roles_to_user(auth0_client): +def test_add_roles_to_user(test_auth0_client): """ Test we can add roles to a user in Auth0 API """ user_id = random_auth0_id() role_id = random_auth0_role_id() route = respx.post(f"https://auth0.example.com/api/v2/users/{user_id}/roles").respond(204) - auth0_client.add_roles_to_user(user_id, role_id) + test_auth0_client.add_roles_to_user(user_id, role_id) assert route.called call_data = route.calls[0].request.content # Check role_id is passed as a list diff --git a/tests/test_bpa_register.py b/tests/test_bpa_register.py index 66a04d0c..d2ec31b2 100644 --- a/tests/test_bpa_register.py +++ b/tests/test_bpa_register.py @@ -3,9 +3,7 @@ import httpx import pytest -from auth0.client import get_auth0_client from db.models import BiocommonsUser -from main import app from schemas import Service from schemas.biocommons import BiocommonsRegisterData from tests.datagen import ( @@ -28,17 +26,6 @@ def valid_registration_data(): ).model_dump() -@pytest.fixture -def override_auth0_client(mocker): - def override_auth0_client(): - return mock_client - - mock_client = mocker.patch("routers.utils.Auth0Client")() - app.dependency_overrides[get_auth0_client] = override_auth0_client - yield mock_client - app.dependency_overrides.clear() - - def test_to_biocommons_register_data(valid_registration_data): bpa_data = BPARegistrationDataFactory.build() bpa_service = Service( @@ -58,7 +45,7 @@ def test_to_biocommons_register_data(valid_registration_data): def test_successful_registration( test_client_with_email, mocker, valid_registration_data, - override_auth0_client, mock_auth0_client, test_db_session + mock_auth0_client, test_db_session ): """Test successful user registration with BPA service""" test_client = test_client_with_email @@ -128,7 +115,7 @@ def test_service_and_resources_have_updated_by_system(): def test_registration_duplicate_user( - test_client, valid_registration_data, override_auth0_client, mock_auth0_client + test_client, valid_registration_data, mock_auth0_client ): """Test registration with duplicate user""" error = httpx.HTTPStatusError( @@ -241,8 +228,7 @@ def test_registration_email_format(test_client, valid_registration_data): def test_all_organizations_selected( test_client_with_email, mock_settings, - mocker, - override_auth0_client, + mocker, mock_auth0_client, valid_registration_data, ): diff --git a/tests/test_galaxy.py b/tests/test_galaxy.py index fa202a41..b7b43697 100644 --- a/tests/test_galaxy.py +++ b/tests/test_galaxy.py @@ -1,10 +1,8 @@ from datetime import UTC, datetime, timedelta -from unittest.mock import MagicMock import pytest from fastapi import HTTPException from freezegun import freeze_time -from httpx import Response from jose import jwt from pydantic import ValidationError @@ -13,25 +11,11 @@ from schemas.biocommons import BiocommonsRegisterData from schemas.galaxy import GalaxyRegistrationData from tests.datagen import ( - AccessTokenPayloadFactory, Auth0UserDataFactory, GalaxyRegistrationDataFactory, ) -@pytest.fixture -def mock_auth_token(mocker): - """Fixture to mock authentication token""" - token = AccessTokenPayloadFactory.build( - sub="auth0|123456789", - biocommons_roles=["acdc/indexd_admin"], - ) - mocker.patch("auth.validator.verify_jwt", return_value=token) - mocker.patch("auth.management.get_management_token", return_value="mock_token") - mocker.patch("routers.galaxy_register.get_management_token", return_value="mock_token") - return token - - def test_galaxy_registration_data_password_match(): with pytest.raises(ValidationError, match="Passwords do not match"): GalaxyRegistrationData(email="user@example.com", @@ -107,7 +91,7 @@ def test_to_biocommons_register_data_empty_fields(): @freeze_time("2025-01-01") -def test_register(mocker, mock_auth_token, mock_settings, test_client): +def test_register(mock_settings, test_client, mock_auth0_client): """ Try to test our register endpoint. Since we don't want to call an actual Auth0 API, test that: @@ -115,43 +99,30 @@ def test_register(mocker, mock_auth_token, mock_settings, test_client): * The post request is made with the correct data * The response from our endpoint looks like we expect """ - mock_resp = MagicMock() - # Dummy user data: doesn't currently resemble response from Auth0 - mock_resp.json.return_value = {"user_id": "abc123"} - mock_resp.status_code = 201 - mock_post = mocker.patch("httpx.post", return_value=mock_resp) + auth0_user_data = Auth0UserDataFactory.build() + mock_auth0_client.create_user.return_value = auth0_user_data user_data = GalaxyRegistrationDataFactory.build() token_resp = test_client.get("/galaxy/register/get-registration-token") headers = {"registration-token": token_resp.json()["token"]} resp = test_client.post("/galaxy/register", json=user_data.model_dump(), headers=headers) assert resp.status_code == 200 assert resp.json()["message"] == "User registered successfully" - assert resp.json()["user"] == {"user_id": "abc123"} + assert resp.json()["user"] == auth0_user_data.model_dump(mode="json") - url = f"https://{mock_settings.auth0_domain}/api/v2/users" - headers = {"Authorization": "Bearer mock_token"} register_data = BiocommonsRegisterData.from_galaxy_registration(user_data) - mock_post.assert_called_once_with( - url, - json=register_data.model_dump(mode="json", exclude_none=True), - headers=headers - ) + mock_auth0_client.create_user.assert_called_once_with(register_data) @pytest.mark.respx(base_url="https://mock-domain") -def test_register_json_types(respx_mock, mock_auth_token, mock_settings, test_client, mock_galaxy_client): +def test_register_json_types(mock_auth0_client, mock_settings, test_client, mock_galaxy_client): """ Test how we handle datetimes in the response data: if we don't use model_dump(mode="json") when providing json data, we can get errors """ - url = f"https://{mock_settings.auth0_domain}/api/v2/users" # Generate user data to be returned in the response # (doesn't have to match the registration data for now) - user = Auth0UserDataFactory.build(created_at=datetime.now(UTC)) - respx_mock.post(url).mock(return_value=Response( - status_code=201, - json=user.model_dump(mode="json")) - ) + auth0_user_data = Auth0UserDataFactory.build() + mock_auth0_client.create_user.return_value = auth0_user_data user_data = GalaxyRegistrationDataFactory.build() token_resp = test_client.get("/galaxy/register/get-registration-token") headers = {"registration-token": token_resp.json()["token"]} From 0fe8bee659373d543e91ebbd07d5f95d9da218d6 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Wed, 13 Aug 2025 09:52:38 +1000 Subject: [PATCH 13/23] Make sure test_db_session is used for tests that use the DB --- tests/test_bpa_register.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_bpa_register.py b/tests/test_bpa_register.py index d2ec31b2..53061720 100644 --- a/tests/test_bpa_register.py +++ b/tests/test_bpa_register.py @@ -227,6 +227,7 @@ def test_registration_email_format(test_client, valid_registration_data): def test_all_organizations_selected( test_client_with_email, + test_db_session, mock_settings, mocker, mock_auth0_client, From 4178ab8c717c19a1685e758695339a0ea758e47a Mon Sep 17 00:00:00 2001 From: marius-mather Date: Wed, 13 Aug 2025 10:22:37 +1000 Subject: [PATCH 14/23] Rework DB setup to make sure we don't accidentally use it in tests --- db/admin.py | 4 ++-- db/setup.py | 17 +++++++++++++---- main.py | 2 +- tests/conftest.py | 12 +++++++++--- 4 files changed, 25 insertions(+), 10 deletions(-) diff --git a/db/admin.py b/db/admin.py index 219af091..b0731221 100644 --- a/db/admin.py +++ b/db/admin.py @@ -16,7 +16,7 @@ GroupMembership, GroupMembershipHistory, ) -from db.setup import engine +from db.setup import get_engine def setup_oauth(): @@ -114,7 +114,7 @@ def __init__(self, app: FastAPI, secret_key: str): self.auth0_client = setup_oauth() self.admin = Admin( app, - engine=engine, + engine=get_engine(), base_url="/db_admin", authentication_backend=AdminAuth(secret_key=secret_key, auth0_client=self.auth0_client), title="AAI Backend Admin" diff --git a/db/setup.py b/db/setup.py index a3a17cf8..605650f8 100644 --- a/db/setup.py +++ b/db/setup.py @@ -9,6 +9,17 @@ log = logging.getLogger('uvicorn.error') +# Set engine as None initially so it's not created on import +_engine = None + + +def get_engine(): + global _engine + if _engine is None: + db_url, db_connect_args = get_db_config() + _engine = create_engine(db_url, connect_args=db_connect_args) + return _engine + def get_db_config() -> Tuple[str, dict]: """ @@ -40,19 +51,17 @@ def get_db_config() -> Tuple[str, dict]: return db_url, connect_args -DB_URL, db_connect_args = get_db_config() -engine = create_engine(DB_URL, connect_args=db_connect_args) - - def create_db_and_tables(): # NOTE: we only do this in dev (with sqlite). # For production, we manage the DB schema with alembic db_url, connect_args = get_db_config() if db_url.startswith("sqlite://"): + engine = get_engine() log.info("Automatically creating DB tables for sqlite") BaseModel.metadata.create_all(engine) def get_db_session(): + engine = get_engine() with Session(engine) as session: yield session diff --git a/main.py b/main.py index c06c2b4e..016ce1ff 100644 --- a/main.py +++ b/main.py @@ -8,7 +8,6 @@ # This has to be imported even if unused from db import models # noqa: F401 from db.admin import DatabaseAdmin -from db.setup import create_db_and_tables from routers import admin, biocommons_groups, bpa_register, galaxy_register, user, utils # Load .env to get CORS_ALLOWED_ORIGINS. @@ -26,6 +25,7 @@ async def lifespan(app: FastAPI): # NOTE: we only create the database and tables automatically in development: # we assume that if the DB is an sqlite DB, we are in dev. + from db.setup import create_db_and_tables create_db_and_tables() DatabaseAdmin.setup(app=app, secret_key=SECRET_KEY) yield diff --git a/tests/conftest.py b/tests/conftest.py index 476298f1..92812692 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,8 +14,6 @@ from auth.validator import get_current_user from auth0.client import Auth0Client, get_auth0_client from config import Settings, get_settings -from db.core import BaseModel -from db.setup import get_db_session from galaxy.client import GalaxyClient, get_galaxy_client from galaxy.config import GalaxySettings, get_galaxy_settings from main import app @@ -31,6 +29,7 @@ @pytest.fixture(scope="function") def test_db_engine(): from db import models # noqa: F401 + from db.core import BaseModel engine = create_engine( # Use in-memory DB by default "sqlite://", @@ -57,6 +56,8 @@ def test_db_session(session): """ Override the get_db_session dependency to return the test DB. """ + from db.setup import get_db_session + def get_db_session_override(): yield session app.dependency_overrides[get_db_session] = get_db_session_override @@ -79,7 +80,12 @@ def get_galaxy_settings_no_env_file(): app.dependency_overrides[get_galaxy_settings] = get_galaxy_settings_no_env_file # Make sure we always use in-memory DB for test DB os.environ.pop("DB_HOST", None) - os.environ["DB_URL"] = "sqlite://" + os.environ["DB_URL"] = "sqlite:///file:dummy_db?mode=memory&uri=true" + + +@pytest.fixture(autouse=True) +def ignore_db_config(mocker): + mocker.patch("db.setup.get_db_config", return_value=("sqlite:///file:dummy_db?mode=memory&uri=true", {})) @pytest.fixture(autouse=True) From 223072a4b3bceb709774715497b84fa9dcedf3f1 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Wed, 13 Aug 2025 10:27:08 +1000 Subject: [PATCH 15/23] Add test_db_session to tests that need it --- tests/test_galaxy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_galaxy.py b/tests/test_galaxy.py index b7b43697..578841c4 100644 --- a/tests/test_galaxy.py +++ b/tests/test_galaxy.py @@ -91,7 +91,7 @@ def test_to_biocommons_register_data_empty_fields(): @freeze_time("2025-01-01") -def test_register(mock_settings, test_client, mock_auth0_client): +def test_register(mock_settings, test_client, mock_auth0_client, test_db_session): """ Try to test our register endpoint. Since we don't want to call an actual Auth0 API, test that: @@ -114,7 +114,7 @@ def test_register(mock_settings, test_client, mock_auth0_client): @pytest.mark.respx(base_url="https://mock-domain") -def test_register_json_types(mock_auth0_client, mock_settings, test_client, mock_galaxy_client): +def test_register_json_types(mock_auth0_client, mock_settings, test_client, mock_galaxy_client, test_db_session): """ Test how we handle datetimes in the response data: if we don't use model_dump(mode="json") when providing json data, we can get errors From a838cd7498e2b28adbe63d7e5ed0ade44cddea73 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Wed, 13 Aug 2025 13:48:49 +1000 Subject: [PATCH 16/23] Method to add platform membership for user --- db/models.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/db/models.py b/db/models.py index 356da23b..9d4111f0 100644 --- a/db/models.py +++ b/db/models.py @@ -75,6 +75,16 @@ def get_or_create(cls, auth0_id: str, db_session: Session, auth0_client: Auth0Cl db_session.commit() return user + def add_platform_membership(self, platform: PlatformEnum, auto_approve: bool = False) -> "PlatformMembership": + membership = PlatformMembership( + platform_id=platform, + user=self, + approval_status=ApprovalStatusEnum.APPROVED if auto_approve else ApprovalStatusEnum.PENDING, + updated_by=None, + ) + self.platform_memberships.append(membership) + return membership + class PlatformMembership(BaseModel, table=True): __table_args__ = ( From fd34f508d5725f7e75e0fffa9476e182e6cb6fa7 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Wed, 13 Aug 2025 13:57:57 +1000 Subject: [PATCH 17/23] Add galaxy membership during registration and test it --- routers/galaxy_register.py | 9 ++++++++- tests/test_galaxy.py | 26 +++++++++++++++++++++++++- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/routers/galaxy_register.py b/routers/galaxy_register.py index 3fb3d88b..60aa227b 100644 --- a/routers/galaxy_register.py +++ b/routers/galaxy_register.py @@ -8,7 +8,7 @@ from auth0.client import Auth0Client, get_auth0_client from config import Settings, get_settings -from db.models import BiocommonsUser +from db.models import BiocommonsUser, PlatformEnum from db.setup import get_db_session from galaxy.client import GalaxyClient, get_galaxy_client from register.tokens import create_registration_token, verify_registration_token @@ -57,8 +57,15 @@ def register( auth0_user_data = auth0_client.create_user(user_data) except httpx.HTTPStatusError as e: raise HTTPException(status_code=e.response.status_code, detail=f'Registration failed: {e}') + # Add to database and record Galaxy membership logger.info("Adding user to DB") db_user = BiocommonsUser.from_auth0_data(data=auth0_user_data) + galaxy_membership = db_user.add_platform_membership( + platform=PlatformEnum.GALAXY, + db_session=db_session, + auto_approve=True + ) db_session.add(db_user) + db_session.add(galaxy_membership) db_session.commit() return {"message": "User registered successfully", "user": auth0_user_data.model_dump(mode="json")} diff --git a/tests/test_galaxy.py b/tests/test_galaxy.py index 578841c4..bbece56d 100644 --- a/tests/test_galaxy.py +++ b/tests/test_galaxy.py @@ -5,8 +5,15 @@ from freezegun import freeze_time from jose import jwt from pydantic import ValidationError +from sqlmodel import select import register +from db.models import ( + BiocommonsUser, + PlatformEnum, + PlatformMembership, + PlatformMembershipHistory, +) from register.tokens import verify_registration_token from schemas.biocommons import BiocommonsRegisterData from schemas.galaxy import GalaxyRegistrationData @@ -98,6 +105,8 @@ def test_register(mock_settings, test_client, mock_auth0_client, test_db_session * The post request is made with the correct data * The response from our endpoint looks like we expect + * A user is created in the DB + * A PlatformMembership record is created for Galaxy """ auth0_user_data = Auth0UserDataFactory.build() mock_auth0_client.create_user.return_value = auth0_user_data @@ -108,9 +117,24 @@ def test_register(mock_settings, test_client, mock_auth0_client, test_db_session assert resp.status_code == 200 assert resp.json()["message"] == "User registered successfully" assert resp.json()["user"] == auth0_user_data.model_dump(mode="json") - + # Check data used to register is correct register_data = BiocommonsRegisterData.from_galaxy_registration(user_data) mock_auth0_client.create_user.assert_called_once_with(register_data) + # Check user is created in the database with membership + db_user = test_db_session.get(BiocommonsUser, auth0_user_data.user_id) + assert db_user is not None + assert db_user.id == auth0_user_data.user_id + galaxy_membership = test_db_session.exec(select(PlatformMembership).where( + PlatformMembership.user_id == db_user.id, + PlatformMembership.platform_id == PlatformEnum.GALAXY.value + )).one() + assert galaxy_membership.approval_status == "approved" + membership_history = test_db_session.exec(select(PlatformMembershipHistory).where( + PlatformMembershipHistory.user_id == db_user.id, + PlatformMembershipHistory.platform_id == PlatformEnum.GALAXY.value + )).one() + assert membership_history.approval_status == "approved" + @pytest.mark.respx(base_url="https://mock-domain") From 0db00c6e3cd2cfbcfa3324cff9b4fedc7fd164d1 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Wed, 13 Aug 2025 13:58:28 +1000 Subject: [PATCH 18/23] Save history when adding platform membership to user --- db/models.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/db/models.py b/db/models.py index 9d4111f0..fe08fd25 100644 --- a/db/models.py +++ b/db/models.py @@ -75,13 +75,15 @@ def get_or_create(cls, auth0_id: str, db_session: Session, auth0_client: Auth0Cl db_session.commit() return user - def add_platform_membership(self, platform: PlatformEnum, auto_approve: bool = False) -> "PlatformMembership": + def add_platform_membership(self, platform: PlatformEnum, db_session: Session, auto_approve: bool = False) -> "PlatformMembership": membership = PlatformMembership( platform_id=platform, user=self, approval_status=ApprovalStatusEnum.APPROVED if auto_approve else ApprovalStatusEnum.PENDING, updated_by=None, ) + db_session.add(membership) + membership.save_history(db_session) self.platform_memberships.append(membership) return membership @@ -101,6 +103,17 @@ class PlatformMembership(BaseModel, table=True): updated_by_id: str | None = Field(foreign_key="biocommons_user.id", nullable=True) updated_by: "BiocommonsUser" = Relationship(sa_relationship_kwargs={"foreign_keys": "PlatformMembership.updated_by_id",}) + def save_history(self, session: Session) -> 'PlatformMembershipHistory': + history = PlatformMembershipHistory( + platform_id=self.platform_id, + user=self.user, + approval_status=self.approval_status, + updated_at=self.updated_at, + updated_by=self.updated_by, + ) + session.add(history) + return history + From fa3da6990bc25058cb1d40ed315796d334b8fdf1 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Wed, 13 Aug 2025 13:59:56 +1000 Subject: [PATCH 19/23] Update comment on test --- tests/test_galaxy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_galaxy.py b/tests/test_galaxy.py index bbece56d..ed4a5c1f 100644 --- a/tests/test_galaxy.py +++ b/tests/test_galaxy.py @@ -120,7 +120,7 @@ def test_register(mock_settings, test_client, mock_auth0_client, test_db_session # Check data used to register is correct register_data = BiocommonsRegisterData.from_galaxy_registration(user_data) mock_auth0_client.create_user.assert_called_once_with(register_data) - # Check user is created in the database with membership + # Check user is created in the database with membership and history db_user = test_db_session.get(BiocommonsUser, auth0_user_data.user_id) assert db_user is not None assert db_user.id == auth0_user_data.user_id From ac0ac002483aab97bd5bfbf3ff5bc1fb2f469d2b Mon Sep 17 00:00:00 2001 From: marius-mather Date: Wed, 13 Aug 2025 15:23:56 +1000 Subject: [PATCH 20/23] Clean up user record creation in Galaxy registration --- routers/galaxy_register.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/routers/galaxy_register.py b/routers/galaxy_register.py index 60aa227b..be05f707 100644 --- a/routers/galaxy_register.py +++ b/routers/galaxy_register.py @@ -12,7 +12,7 @@ from db.setup import get_db_session from galaxy.client import GalaxyClient, get_galaxy_client from register.tokens import create_registration_token, verify_registration_token -from schemas.biocommons import BiocommonsRegisterData +from schemas.biocommons import Auth0UserData, BiocommonsRegisterData from schemas.galaxy import GalaxyRegistrationData logger = logging.getLogger(__name__) @@ -59,13 +59,18 @@ def register( raise HTTPException(status_code=e.response.status_code, detail=f'Registration failed: {e}') # Add to database and record Galaxy membership logger.info("Adding user to DB") + _create_galaxy_user_record(auth0_user_data, db_session) + return {"message": "User registered successfully", "user": auth0_user_data.model_dump(mode="json")} + + +def _create_galaxy_user_record(auth0_user_data: Auth0UserData, session: Session) -> BiocommonsUser: db_user = BiocommonsUser.from_auth0_data(data=auth0_user_data) galaxy_membership = db_user.add_platform_membership( platform=PlatformEnum.GALAXY, - db_session=db_session, + db_session=session, auto_approve=True ) - db_session.add(db_user) - db_session.add(galaxy_membership) - db_session.commit() - return {"message": "User registered successfully", "user": auth0_user_data.model_dump(mode="json")} + session.add(db_user) + session.add(galaxy_membership) + session.commit() + return db_user From fbb61039450af414ac25a9ea2a65e68b6d811bc6 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Wed, 13 Aug 2025 15:32:33 +1000 Subject: [PATCH 21/23] Create platform membership for BPA when creating user + test --- routers/bpa_register.py | 21 ++++++++++++++++----- tests/test_bpa_register.py | 19 ++++++++++++++++++- 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/routers/bpa_register.py b/routers/bpa_register.py index 96106e26..338aa991 100644 --- a/routers/bpa_register.py +++ b/routers/bpa_register.py @@ -9,9 +9,9 @@ from auth.ses import EmailService from auth0.client import Auth0Client, get_auth0_client from config import Settings, get_settings -from db.models import BiocommonsUser +from db.models import BiocommonsUser, PlatformEnum from db.setup import get_db_session -from schemas.biocommons import BiocommonsRegisterData +from schemas.biocommons import Auth0UserData, BiocommonsRegisterData from schemas.bpa import BPARegistrationRequest from schemas.service import Resource, Service @@ -102,9 +102,7 @@ async def register_bpa_user( auth0_user_data = auth0_client.create_user(user_data) logger.info("Adding user to DB") - db_user = BiocommonsUser.from_auth0_data(data=auth0_user_data) - db_session.add(db_user) - db_session.commit() + _create_bpa_user_record(auth0_user_data, db_session) if bpa_resources and settings.send_email: background_tasks.add_task(send_approval_email, registration, bpa_resources) @@ -117,3 +115,16 @@ async def register_bpa_user( raise HTTPException( status_code=500, detail=f"Failed to register user: {str(e)}" ) + + +def _create_bpa_user_record(auth0_user_data: Auth0UserData, session: Session) -> BiocommonsUser: + db_user = BiocommonsUser.from_auth0_data(data=auth0_user_data) + bpa_membership = db_user.add_platform_membership( + platform=PlatformEnum.BPA_DATA_PORTAL, + db_session=session, + auto_approve=True + ) + session.add(db_user) + session.add(bpa_membership) + session.commit() + return db_user diff --git a/tests/test_bpa_register.py b/tests/test_bpa_register.py index 53061720..cec5b1df 100644 --- a/tests/test_bpa_register.py +++ b/tests/test_bpa_register.py @@ -2,8 +2,14 @@ import httpx import pytest +from sqlmodel import select -from db.models import BiocommonsUser +from db.models import ( + BiocommonsUser, + PlatformEnum, + PlatformMembership, + PlatformMembershipHistory, +) from schemas import Service from schemas.biocommons import BiocommonsRegisterData from tests.datagen import ( @@ -64,6 +70,17 @@ def test_successful_registration( db_user = test_db_session.get(BiocommonsUser, user_id) assert db_user is not None assert db_user.id == user_id + # Check platform membership and history is created + bpa_membership = test_db_session.exec(select(PlatformMembership).where( + PlatformMembership.user_id == db_user.id, + PlatformMembership.platform_id == PlatformEnum.BPA_DATA_PORTAL.value + )).one() + assert bpa_membership.approval_status == "approved" + membership_history = test_db_session.exec(select(PlatformMembershipHistory).where( + PlatformMembershipHistory.user_id == db_user.id, + PlatformMembershipHistory.platform_id == PlatformEnum.BPA_DATA_PORTAL.value + )).one() + assert membership_history.approval_status == "approved" called_data = mock_auth0_client.create_user.call_args[0][0] assert called_data.email == valid_registration_data["email"] From a7e603e0f3a2d25621638fbc54612c4aacb653d0 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Wed, 13 Aug 2025 15:56:36 +1000 Subject: [PATCH 22/23] Make sure we dump to JSON when returning user data --- routers/bpa_register.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/routers/bpa_register.py b/routers/bpa_register.py index 338aa991..1d6eacfb 100644 --- a/routers/bpa_register.py +++ b/routers/bpa_register.py @@ -107,7 +107,7 @@ async def register_bpa_user( if bpa_resources and settings.send_email: background_tasks.add_task(send_approval_email, registration, bpa_resources) - return {"message": "User registered successfully", "user": auth0_user_data} + return {"message": "User registered successfully", "user": auth0_user_data.model_dump(mode="json")} except HTTPStatusError as e: raise HTTPException(status_code=e.response.status_code, detail=e.response.text) From 176964cbe3c08effd46d41333f25c0a992b73375 Mon Sep 17 00:00:00 2001 From: marius-mather Date: Wed, 13 Aug 2025 16:00:57 +1000 Subject: [PATCH 23/23] Don't double up on adding platform membership to user --- db/models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/db/models.py b/db/models.py index fe08fd25..77b6c2ae 100644 --- a/db/models.py +++ b/db/models.py @@ -84,7 +84,6 @@ def add_platform_membership(self, platform: PlatformEnum, db_session: Session, a ) db_session.add(membership) membership.save_history(db_session) - self.platform_memberships.append(membership) return membership